GoMLX:Go语言的机器学习和数学计算框架
一个易于使用的机器学习和通用数学库与工具集,可视为Go语言的PyTorch/Jax/TensorFlow
info 框架简介
GoMLX是一个专为Go语言设计的机器学习和数学计算框架,旨在提供类似于PyTorch、Jax或TensorFlow的功能,但完全基于Go语言实现。它支持从训练、微调到修改和组合机器学习模型的完整工作流程,并提供了一系列工具使这些工作变得简单易用。
GoMLX的主要特点包括:
- 纯Go后端:几乎可以在任何Go运行的环境中工作,包括浏览器(通过WASM)和可能的嵌入式设备
- 高性能后端:支持基于OpenXLA/PJRT的优化引擎,使用即时编译技术针对CPU、GPU(Nvidia,未来将支持AMD ROCm、Intel、Macs和Google的TPU)进行优化
- 完整的工具链:从可微分操作符到训练过程中绘制指标的UI工具
- 易于扩展:支持实验新的优化器思想、复杂的正则化器、奇特的多任务等
architecture 架构设计
GoMLX的架构设计遵循模块化和可扩展的原则,主要分为以下几个核心组件:
前端API层
前端API层是用户直接交互的部分,提供了直观、易用的接口,包括:
- 张量操作:提供类似于NumPy的张量操作接口,支持创建、操作和转换多维数组
- 神经网络层:预定义的常用神经网络层,如全连接层、卷积层、循环层等
- 模型构建:用于构建复杂模型的工具和抽象
- 训练工具:包括优化器、损失函数、指标计算等
计算图表示
GoMLX使用计算图来表示机器学习模型和计算过程。计算图是一个有向无环图(DAG),其中节点表示操作(如加法、乘法、卷积等),边表示数据(张量)在操作之间的流动。
// 计算图示例
func buildModel(ctx context.Context, graph *graph.Graph, inputs []*graph.Node) []*graph.Node {
// 第一个隐藏层
hidden1 := layers.FC(graph, inputs[0], 128, "hidden1")
hidden1 = graph.ReLU(hidden1)
// 第二个隐藏层
hidden2 := layers.FC(graph, hidden1, 64, "hidden2")
hidden2 = graph.ReLU(hidden2)
// 输出层
logits := layers.FC(graph, hidden2, 10, "logits")
return []*graph.Node{logits}
}
后端执行引擎
GoMLX支持多种后端执行引擎,以适应不同的使用场景和性能需求:
| 后端类型 | 特点 | 适用场景 |
|---|---|---|
| 纯Go后端 | 完全用Go实现,可移植性强 | 原型开发、小规模模型、浏览器环境(WASM) |
| OpenXLA/PJRT后端 | 高性能,支持JIT编译到多种硬件 | 大规模模型训练、生产环境部署 |
| StableHLO后端(测试版) | 更简单的安装,更广泛的硬件兼容性 | 未来主要后端,支持ROCm、Apple Metal、Intel等 |
自动微分系统
自动微分是机器学习框架的核心功能,GoMLX实现了基于计算图的自动微分系统。它能够自动计算梯度,支持前向模式和反向模式微分,为模型训练提供必要的梯度信息。
// 自动微分示例
func trainStep(ctx context.Context, model *layers.Model, optimizer train.Optimizer,
inputs, labels []*graph.Node, trainMetrics *train.Metrics) {
// 前向传播
logits := model.Fwd(ctx, inputs...)
// 计算损失
loss := losses.CrossEntropyLogits(logits, labels)
// 计算梯度
gradients := graph.Gradient(ctx, loss, model.Variables()...)
// 更新变量
optimizer.UpdateVariables(gradients...)
// 更新指标
trainMetrics.Update(loss, logits, labels)
}
lightbulb 核心原理
计算图与自动微分
GoMLX的核心原理之一是基于计算图的自动微分。计算图不仅表示了数据流和操作,还记录了操作之间的依赖关系,这使得系统能够自动计算梯度。
自动微分的过程如下:
- 前向传播:执行计算图,记录每个操作的输入和输出
- 反向传播:从输出节点开始,根据链式法则反向计算梯度
- 梯度累积:将梯度传播到参数节点,用于参数更新
内存管理与优化
GoMLX采用了高效的内存管理策略,包括:
- 延迟分配:只在需要时分配内存
- 内存复用:尽可能复用已分配的内存
- 垃圾回收优化:减少不必要的内存分配,降低GC压力
计算图优化
GoMLX在执行计算图之前会进行一系列优化,以提高计算效率:
- 常量折叠:在编译时计算常量表达式
- 操作融合:将多个操作合并为一个,减少内存访问
- 公共子表达式消除:识别并消除重复的计算
- 死代码消除:移除不影响结果的操作
多后端支持
GoMLX的设计允许它支持多种后端执行引擎,这是通过抽象后端接口实现的。后端接口定义了一组标准操作,如:
- 张量创建和操作
- 计算图构建和执行
- 梯度计算
- 设备管理
这种设计使得GoMLX能够灵活地适应不同的硬件环境和性能需求,从纯Go实现到高性能的OpenXLA/PJRT后端。
psychology 设计思想
简洁性与透明性
GoMLX的设计哲学之一是保持简洁性和透明性。它力求让用户能够轻松理解代码的工作原理,形成正确的心理模型。这种设计思想与Go语言的哲学相一致,强调代码的清晰度和可读性。
可扩展性
GoMLX被设计为高度可扩展的框架,允许用户轻松实验新的想法和技术。这种可扩展性体现在:
- 自定义操作:用户可以轻松添加新的操作和函数
- 自定义层:可以创建新的神经网络层
- 自定义优化器:可以实验新的优化算法
- 自定义损失函数:可以定义特定的损失函数
Go语言哲学
GoMLX遵循Go语言的设计哲学,包括:
- 简单性:避免不必要的复杂性
- 明确性:代码应该清晰地表达其意图
- 组合性:通过组合简单的组件构建复杂的系统
- 并发性:利用Go的并发特性提高性能
文档与错误处理
GoMLX非常重视文档和错误处理,认为:
这种设计思想使得GoMLX不仅功能强大,而且易于使用和维护,降低了学习和使用的门槛。
stars 主要特性和优势
speed高性能
支持OpenXLA/PJRT后端,利用即时编译技术优化计算性能,与TensorFlow和Jax在许多情况下具有相同的速度。
devices跨平台
纯Go后端使其几乎可以在任何Go运行的环境中工作,包括浏览器(通过WASM)和可能的嵌入式设备。
extension易于扩展
高度可扩展的架构允许用户轻松实验新的优化器思想、复杂的正则化器、奇特的多任务等。
auto_awesome完整工具链
提供从可微分操作符到训练过程中绘制指标的UI工具的完整工具链,使机器学习工作流程变得简单。
与其他框架的对比
| 特性 | GoMLX | TensorFlow/PyTorch | Jax |
|---|---|---|---|
| 语言 | Go | Python | Python |
| 性能 | 高(通过OpenXLA/PJRT) | 高 | 高 |
| 易用性 | 中等(需要Go知识) | 高 | 中等 |
| 生态系统 | 发展中 | 成熟 | 发展中 |
| 部署 | 简单(Go编译) | 复杂(需要Python环境) | 复杂(需要Python环境) |
StableHLO支持
GoMLX正在添加StableHLO支持作为测试版后端,这将带来以下优势:
- 更简单的安装(只需要PJRT插件)
- 更广泛的硬件兼容性(ROCm、Apple Metal、Intel)
- 访问仅对StableHLO可用的新功能
// 启用StableHLO后端 import _ "github.com/gomlx/gomlx/backends/stablehlo" // 或者使用标签 // go run -tags=stablehlo main.go
code 应用场景和示例
典型应用场景
GoMLX适用于多种机器学习任务和应用场景,包括:
- 图像分类:如MNIST、CIFAR-10、Dogs vs Cats等数据集的分类任务
- 自然语言处理:如IMDB电影评论情感分析
- 生成模型:如牛津花卉102数据集的扩散模型
- 图神经网络:如OGBN-MAG数据集的GNN模型
- 强化学习:如Hive游戏的AlphaZero AI
代码示例:简单的神经网络
package main
import (
"context"
"fmt"
"github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/layers"
"github.com/gomlx/gomlx/ml/train"
"github.com/gomlx/gomlx/ml/train/optimizers"
"github.com/gomlx/gomlx/types/tensors"
)
func main() {
// 创建后端和上下文
backend := backends.New()
ctx := context.Background()
// 创建模型
model := layers.Sequential(
layers.FC(nil, 128, "fc1"),
layers.ReLU(nil),
layers.FC(nil, 64, "fc2"),
layers.ReLU(nil),
layers.FC(nil, 10, "fc3"),
)
// 创建优化器
optimizer := optimizers.Adam().Build()
// 训练循环
for epoch := 0; epoch < 10; epoch++ {
// 这里应该是数据加载和训练步骤
// 简化的训练步骤示例
trainStep(ctx, model, optimizer, inputs, labels)
}
fmt.Println("训练完成")
}
func trainStep(ctx context.Context, model *layers.Model, optimizer train.Optimizer,
inputs, labels []*graph.Node) {
// 前向传播
logits := model.Fwd(ctx, inputs...)
// 计算损失
loss := losses.CrossEntropyLogits(logits, labels)
// 计算梯度
gradients := graph.Gradient(ctx, loss, model.Variables()...)
// 更新变量
optimizer.UpdateVariables(gradients...)
}
代码示例:使用GoMLX进行图像分类
package main
import (
"context"
"fmt"
"github.com/gomlx/gomlx/backends"
"github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/layers"
"github.com/gomlx/gomlx/ml/train"
"github.com/gomlx/gomlx/ml/train/optimizers"
"github.com/gomlx/gomlx/types/tensors"
)
// ConvNet 定义一个简单的卷积神经网络
func ConvNet(g *graph.Graph, inputs []*graph.Node) []*graph.Node {
// 第一个卷积层
conv1 := layers.Conv2D(g, inputs[0], 32, []int{3, 3}, []int{1, 1}, "SAME", "conv1")
conv1 = g.ReLU(conv1)
pool1 := g.MaxPool(conv1, []int{1, 2, 2, 1}, []int{1, 2, 2, 1}, "SAME")
// 第二个卷积层
conv2 := layers.Conv2D(g, pool1, 64, []int{3, 3}, []int{1, 1}, "SAME", "conv2")
conv2 = g.ReLU(conv2)
pool2 := g.MaxPool(conv2, []int{1, 2, 2, 1}, []int{1, 2, 2, 1}, "SAME")
// 展平
shape := g.Shape(pool2)
batchSize := shape.Dim(0)
flattened := g.Reshape(pool2, []int{batchSize, -1})
// 全连接层
fc1 := layers.FC(g, flattened, 128, "fc1")
fc1 = g.ReLU(fc1)
// 输出层
logits := layers.FC(g, fc1, 10, "logits")
return []*graph.Node{logits}
}
func main() {
// 创建后端
backend := backends.New()
ctx := context.Background()
// 创建模型
model := layers.NewModel(ConvNet)
// 创建优化器
optimizer := optimizers.Adam().Build()
// 训练模型(这里省略了数据加载和训练循环的具体实现)
fmt.Println("卷积神经网络模型已创建")
}
实际应用案例
GoMLX已经被用于多个实际项目中,包括:
- GoMLX/Gemma:Google DeepMind的Gemma v2模型的GoMLX实现
- Hive游戏AI:使用GoMLX实现的AlphaZero风格AI,包括WASM演示
- 神经风格迁移:使用GoMLX实现的神经风格迁移算法
- 三元组损失:实现了各种负采样策略和距离度量
download 安装和使用方法
系统要求
GoMLX对系统的要求相对较低,主要包括:
- Go 1.18或更高版本
- 对于高性能后端:Linux/amd64(未来将支持更多平台)
- 可选:NVIDIA GPU(用于GPU加速)
安装步骤
安装GoMLX非常简单,可以通过Go的模块系统进行安装:
# 创建新项目 mkdir my-gomlx-project cd my-gomlx-project go mod init example.com/my-gomlx-project # 添加GoMLX依赖 go get github.com/gomlx/gomlx # 如果需要高性能后端 go get github.com/gomlx/gomlx/backends/xla # 或者使用StableHLO后端(测试版) go get github.com/gomlx/gomlx/backends/stablehlo
基本使用
以下是使用GoMLX的基本步骤:
package main
import (
"context"
"fmt"
"github.com/gomlx/gomlx/backends"
"github.com/gomlx/gomlx/graph"
"github.com/gomlx/gomlx/types/tensors"
)
func main() {
// 1. 创建后端
backend := backends.New()
ctx := context.Background()
// 2. 创建计算图
g := graph.New(ctx, backend)
// 3. 创建输入张量
x := g.Parameter(tensors.FromScalar(2.0))
y := g.Parameter(tensors.FromScalar(3.0))
// 4. 执行计算
z := g.Add(x, g.Mul(x, y)) // z = x + x * y
// 5. 执行计算图
result := g.Run(z)
// 6. 输出结果
fmt.Printf("结果: %f\n", result.Value().(float64)) // 应该输出 8.0
}
学习资源
为了更好地学习和使用GoMLX,可以参考以下资源:
- 官方教程:https://gomlx.github.io/gomlx/notebooks/tutorial.html
- 示例代码:https://github.com/gomlx/gomlx/tree/main/examples
- 文档:https://pkg.go.dev/github.com/gomlx/gomlx?tab=doc
- 博客文章:"GoMLX: ML in Go without Python" by Eli Bendersky
- Slack社区:https://app.slack.com/client/T029RQSE6/C08TX33BX6U
常见问题
以下是一些常见问题和解决方案:
Q: 如何选择后端?
A: 如果您需要高性能或大规模训练,请使用OpenXLA/PJRT或StableHLO后端。如果您需要跨平台兼容性或简单的原型开发,请使用纯Go后端。
Q: 如何调试模型?
A: GoMLX提供了多种调试工具,包括计算图可视化、中间结果检查和详细的错误信息。您还可以使用标准的Go调试工具。
Q: 如何贡献代码?
A: GoMLX欢迎社区贡献。您可以通过提交问题、拉取请求或参与讨论来贡献代码。请确保遵循项目的代码风格和测试要求。
讨论回复
0 条回复还没有人回复,快来发表你的看法吧!