Loading...
正在加载...
请稍候

GoMLX:Go语言的机器学习和数学计算框架

✨步子哥 (steper) 2025年09月25日 14:38
GoMLX:Go语言的机器学习和数学计算框架

GoMLX:Go语言的机器学习和数学计算框架

一个易于使用的机器学习和通用数学库与工具集,可视为Go语言的PyTorch/Jax/TensorFlow

info 框架简介

GoMLX是一个专为Go语言设计的机器学习和数学计算框架,旨在提供类似于PyTorch、Jax或TensorFlow的功能,但完全基于Go语言实现。它支持从训练、微调到修改和组合机器学习模型的完整工作流程,并提供了一系列工具使这些工作变得简单易用。

GoMLX的主要特点包括:

  • 纯Go后端:几乎可以在任何Go运行的环境中工作,包括浏览器(通过WASM)和可能的嵌入式设备
  • 高性能后端:支持基于OpenXLA/PJRT的优化引擎,使用即时编译技术针对CPU、GPU(Nvidia,未来将支持AMD ROCm、Intel、Macs和Google的TPU)进行优化
  • 完整的工具链:从可微分操作符到训练过程中绘制指标的UI工具
  • 易于扩展:支持实验新的优化器思想、复杂的正则化器、奇特的多任务等
"它被开发为Go的全功能ML平台,并便于轻松实验ML想法。它力求简单易读和推理,引导用户形成正确透明的心理模型,与Go哲学保持一致。"

architecture 架构设计

GoMLX的架构设计遵循模块化和可扩展的原则,主要分为以下几个核心组件:

前端API层

前端API层是用户直接交互的部分,提供了直观、易用的接口,包括:

  • 张量操作:提供类似于NumPy的张量操作接口,支持创建、操作和转换多维数组
  • 神经网络层:预定义的常用神经网络层,如全连接层、卷积层、循环层等
  • 模型构建:用于构建复杂模型的工具和抽象
  • 训练工具:包括优化器、损失函数、指标计算等

计算图表示

GoMLX使用计算图来表示机器学习模型和计算过程。计算图是一个有向无环图(DAG),其中节点表示操作(如加法、乘法、卷积等),边表示数据(张量)在操作之间的流动。

Go
// 计算图示例
func buildModel(ctx context.Context, graph *graph.Graph, inputs []*graph.Node) []*graph.Node {
    // 第一个隐藏层
    hidden1 := layers.FC(graph, inputs[0], 128, "hidden1")
    hidden1 = graph.ReLU(hidden1)
    
    // 第二个隐藏层
    hidden2 := layers.FC(graph, hidden1, 64, "hidden2")
    hidden2 = graph.ReLU(hidden2)
    
    // 输出层
    logits := layers.FC(graph, hidden2, 10, "logits")
    
    return []*graph.Node{logits}
}

后端执行引擎

GoMLX支持多种后端执行引擎,以适应不同的使用场景和性能需求:

后端类型 特点 适用场景
纯Go后端 完全用Go实现,可移植性强 原型开发、小规模模型、浏览器环境(WASM)
OpenXLA/PJRT后端 高性能,支持JIT编译到多种硬件 大规模模型训练、生产环境部署
StableHLO后端(测试版) 更简单的安装,更广泛的硬件兼容性 未来主要后端,支持ROCm、Apple Metal、Intel等

自动微分系统

自动微分是机器学习框架的核心功能,GoMLX实现了基于计算图的自动微分系统。它能够自动计算梯度,支持前向模式和反向模式微分,为模型训练提供必要的梯度信息。

Go
// 自动微分示例
func trainStep(ctx context.Context, model *layers.Model, optimizer train.Optimizer, 
               inputs, labels []*graph.Node, trainMetrics *train.Metrics) {
    // 前向传播
    logits := model.Fwd(ctx, inputs...)
    
    // 计算损失
    loss := losses.CrossEntropyLogits(logits, labels)
    
    // 计算梯度
    gradients := graph.Gradient(ctx, loss, model.Variables()...)
    
    // 更新变量
    optimizer.UpdateVariables(gradients...)
    
    // 更新指标
    trainMetrics.Update(loss, logits, labels)
}

lightbulb 核心原理

计算图与自动微分

GoMLX的核心原理之一是基于计算图的自动微分。计算图不仅表示了数据流和操作,还记录了操作之间的依赖关系,这使得系统能够自动计算梯度。

自动微分的过程如下:

  1. 前向传播:执行计算图,记录每个操作的输入和输出
  2. 反向传播:从输出节点开始,根据链式法则反向计算梯度
  3. 梯度累积:将梯度传播到参数节点,用于参数更新

内存管理与优化

GoMLX采用了高效的内存管理策略,包括:

  • 延迟分配:只在需要时分配内存
  • 内存复用:尽可能复用已分配的内存
  • 垃圾回收优化:减少不必要的内存分配,降低GC压力

计算图优化

GoMLX在执行计算图之前会进行一系列优化,以提高计算效率:

  • 常量折叠:在编译时计算常量表达式
  • 操作融合:将多个操作合并为一个,减少内存访问
  • 公共子表达式消除:识别并消除重复的计算
  • 死代码消除:移除不影响结果的操作

多后端支持

GoMLX的设计允许它支持多种后端执行引擎,这是通过抽象后端接口实现的。后端接口定义了一组标准操作,如:

  • 张量创建和操作
  • 计算图构建和执行
  • 梯度计算
  • 设备管理

这种设计使得GoMLX能够灵活地适应不同的硬件环境和性能需求,从纯Go实现到高性能的OpenXLA/PJRT后端。

psychology 设计思想

简洁性与透明性

GoMLX的设计哲学之一是保持简洁性和透明性。它力求让用户能够轻松理解代码的工作原理,形成正确的心理模型。这种设计思想与Go语言的哲学相一致,强调代码的清晰度和可读性。

"它力求简单易读和推理,引导用户形成正确透明的心理模型,没有意外情况——与Go哲学保持一致。虽然有时需要更多的输入(更冗长)。"

可扩展性

GoMLX被设计为高度可扩展的框架,允许用户轻松实验新的想法和技术。这种可扩展性体现在:

  • 自定义操作:用户可以轻松添加新的操作和函数
  • 自定义层:可以创建新的神经网络层
  • 自定义优化器:可以实验新的优化算法
  • 自定义损失函数:可以定义特定的损失函数

Go语言哲学

GoMLX遵循Go语言的设计哲学,包括:

  • 简单性:避免不必要的复杂性
  • 明确性:代码应该清晰地表达其意图
  • 组合性:通过组合简单的组件构建复杂的系统
  • 并发性:利用Go的并发特性提高性能

文档与错误处理

GoMLX非常重视文档和错误处理,认为:

"文档保持最新(如果文档不完善,代码就如同不存在),错误信息有用(总是带有堆栈跟踪)并尝试使解决问题变得容易。"

这种设计思想使得GoMLX不仅功能强大,而且易于使用和维护,降低了学习和使用的门槛。

stars 主要特性和优势

speed高性能

支持OpenXLA/PJRT后端,利用即时编译技术优化计算性能,与TensorFlow和Jax在许多情况下具有相同的速度。

devices跨平台

纯Go后端使其几乎可以在任何Go运行的环境中工作,包括浏览器(通过WASM)和可能的嵌入式设备。

extension易于扩展

高度可扩展的架构允许用户轻松实验新的优化器思想、复杂的正则化器、奇特的多任务等。

auto_awesome完整工具链

提供从可微分操作符到训练过程中绘制指标的UI工具的完整工具链,使机器学习工作流程变得简单。

与其他框架的对比

特性 GoMLX TensorFlow/PyTorch Jax
语言 Go Python Python
性能 高(通过OpenXLA/PJRT)
易用性 中等(需要Go知识) 中等
生态系统 发展中 成熟 发展中
部署 简单(Go编译) 复杂(需要Python环境) 复杂(需要Python环境)

StableHLO支持

GoMLX正在添加StableHLO支持作为测试版后端,这将带来以下优势:

  • 更简单的安装(只需要PJRT插件)
  • 更广泛的硬件兼容性(ROCm、Apple Metal、Intel)
  • 访问仅对StableHLO可用的新功能
Go
// 启用StableHLO后端
import _ "github.com/gomlx/gomlx/backends/stablehlo"

// 或者使用标签
// go run -tags=stablehlo main.go

code 应用场景和示例

典型应用场景

GoMLX适用于多种机器学习任务和应用场景,包括:

  • 图像分类:如MNIST、CIFAR-10、Dogs vs Cats等数据集的分类任务
  • 自然语言处理:如IMDB电影评论情感分析
  • 生成模型:如牛津花卉102数据集的扩散模型
  • 图神经网络:如OGBN-MAG数据集的GNN模型
  • 强化学习:如Hive游戏的AlphaZero AI

代码示例:简单的神经网络

Go
package main

import (
    "context"
    "fmt"
    
    "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/layers"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/types/tensors"
)

func main() {
    // 创建后端和上下文
    backend := backends.New()
    ctx := context.Background()
    
    // 创建模型
    model := layers.Sequential(
        layers.FC(nil, 128, "fc1"),
        layers.ReLU(nil),
        layers.FC(nil, 64, "fc2"),
        layers.ReLU(nil),
        layers.FC(nil, 10, "fc3"),
    )
    
    // 创建优化器
    optimizer := optimizers.Adam().Build()
    
    // 训练循环
    for epoch := 0; epoch < 10; epoch++ {
        // 这里应该是数据加载和训练步骤
        // 简化的训练步骤示例
        trainStep(ctx, model, optimizer, inputs, labels)
    }
    
    fmt.Println("训练完成")
}

func trainStep(ctx context.Context, model *layers.Model, optimizer train.Optimizer, 
               inputs, labels []*graph.Node) {
    // 前向传播
    logits := model.Fwd(ctx, inputs...)
    
    // 计算损失
    loss := losses.CrossEntropyLogits(logits, labels)
    
    // 计算梯度
    gradients := graph.Gradient(ctx, loss, model.Variables()...)
    
    // 更新变量
    optimizer.UpdateVariables(gradients...)
}

代码示例:使用GoMLX进行图像分类

Go
package main

import (
    "context"
    "fmt"
    
    "github.com/gomlx/gomlx/backends"
    "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/layers"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/types/tensors"
)

// ConvNet 定义一个简单的卷积神经网络
func ConvNet(g *graph.Graph, inputs []*graph.Node) []*graph.Node {
    // 第一个卷积层
    conv1 := layers.Conv2D(g, inputs[0], 32, []int{3, 3}, []int{1, 1}, "SAME", "conv1")
    conv1 = g.ReLU(conv1)
    pool1 := g.MaxPool(conv1, []int{1, 2, 2, 1}, []int{1, 2, 2, 1}, "SAME")
    
    // 第二个卷积层
    conv2 := layers.Conv2D(g, pool1, 64, []int{3, 3}, []int{1, 1}, "SAME", "conv2")
    conv2 = g.ReLU(conv2)
    pool2 := g.MaxPool(conv2, []int{1, 2, 2, 1}, []int{1, 2, 2, 1}, "SAME")
    
    // 展平
    shape := g.Shape(pool2)
    batchSize := shape.Dim(0)
    flattened := g.Reshape(pool2, []int{batchSize, -1})
    
    // 全连接层
    fc1 := layers.FC(g, flattened, 128, "fc1")
    fc1 = g.ReLU(fc1)
    
    // 输出层
    logits := layers.FC(g, fc1, 10, "logits")
    
    return []*graph.Node{logits}
}

func main() {
    // 创建后端
    backend := backends.New()
    ctx := context.Background()
    
    // 创建模型
    model := layers.NewModel(ConvNet)
    
    // 创建优化器
    optimizer := optimizers.Adam().Build()
    
    // 训练模型(这里省略了数据加载和训练循环的具体实现)
    fmt.Println("卷积神经网络模型已创建")
}

实际应用案例

GoMLX已经被用于多个实际项目中,包括:

  • GoMLX/Gemma:Google DeepMind的Gemma v2模型的GoMLX实现
  • Hive游戏AI:使用GoMLX实现的AlphaZero风格AI,包括WASM演示
  • 神经风格迁移:使用GoMLX实现的神经风格迁移算法
  • 三元组损失:实现了各种负采样策略和距离度量

download 安装和使用方法

系统要求

GoMLX对系统的要求相对较低,主要包括:

  • Go 1.18或更高版本
  • 对于高性能后端:Linux/amd64(未来将支持更多平台)
  • 可选:NVIDIA GPU(用于GPU加速)

安装步骤

安装GoMLX非常简单,可以通过Go的模块系统进行安装:

Bash
# 创建新项目
mkdir my-gomlx-project
cd my-gomlx-project
go mod init example.com/my-gomlx-project

# 添加GoMLX依赖
go get github.com/gomlx/gomlx

# 如果需要高性能后端
go get github.com/gomlx/gomlx/backends/xla

# 或者使用StableHLO后端(测试版)
go get github.com/gomlx/gomlx/backends/stablehlo

基本使用

以下是使用GoMLX的基本步骤:

Go
package main

import (
    "context"
    "fmt"
    
    "github.com/gomlx/gomlx/backends"
    "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/types/tensors"
)

func main() {
    // 1. 创建后端
    backend := backends.New()
    ctx := context.Background()
    
    // 2. 创建计算图
    g := graph.New(ctx, backend)
    
    // 3. 创建输入张量
    x := g.Parameter(tensors.FromScalar(2.0))
    y := g.Parameter(tensors.FromScalar(3.0))
    
    // 4. 执行计算
    z := g.Add(x, g.Mul(x, y))  // z = x + x * y
    
    // 5. 执行计算图
    result := g.Run(z)
    
    // 6. 输出结果
    fmt.Printf("结果: %f\n", result.Value().(float64))  // 应该输出 8.0
}

学习资源

为了更好地学习和使用GoMLX,可以参考以下资源:

  • 官方教程:https://gomlx.github.io/gomlx/notebooks/tutorial.html
  • 示例代码:https://github.com/gomlx/gomlx/tree/main/examples
  • 文档:https://pkg.go.dev/github.com/gomlx/gomlx?tab=doc
  • 博客文章:"GoMLX: ML in Go without Python" by Eli Bendersky
  • Slack社区:https://app.slack.com/client/T029RQSE6/C08TX33BX6U

常见问题

以下是一些常见问题和解决方案:

Q: 如何选择后端?

A: 如果您需要高性能或大规模训练,请使用OpenXLA/PJRT或StableHLO后端。如果您需要跨平台兼容性或简单的原型开发,请使用纯Go后端。

Q: 如何调试模型?

A: GoMLX提供了多种调试工具,包括计算图可视化、中间结果检查和详细的错误信息。您还可以使用标准的Go调试工具。

Q: 如何贡献代码?

A: GoMLX欢迎社区贡献。您可以通过提交问题、拉取请求或参与讨论来贡献代码。请确保遵循项目的代码风格和测试要求。

讨论回复

0 条回复

还没有人回复,快来发表你的看法吧!