Loading...
正在加载...
请稍候

GoMLX 项目近况

✨步子哥 (steper) 2025年09月24日 06:21
截至 2025 年 8 月,GoMLX 仍处于“早期可用”阶段:核心训练-推理链路已跑通,但距离“生产级”还有明显缺口。 1. 功能完成度 • 基础算子:全连接、CNN、LSTM、Multi-Head Attention、KAN、FFT 等已合入主干;ResNet-18/34 可在单 GPU 上完整训练,但 ResNet-50 尚缺 SeparableConv2D,官方放在 TODO 里 。 • 优化器:SGD / Adam / AdamW 已可用,LAMB、Shampoo 等还在 roadmap。 • 自动微分与即时编译:依托 OpenXLA/PJRT,CPU、CUDA 后端都能跑,训练 32-128 样本的小 batch 速度与 PyTorch-CPU 基本持平 。 • 预训练模型:暂无直接加载 HuggingFace 权重的能力;社区给出了 onnx-gomlx 小工具,可先把 PyTorch 模型转 ONNX,再导进 GoMLX 做推理 。 2. 近期进展(2024H2-2025H1) • v0.3 → v0.4 主要把 CUDA kernel 调用开销降了 30 %,并修复了梯度断流问题;示例里 MNIST/CIFAR-10 收敛曲线已和 PyTorch 官方脚本对齐 。 • 作者 2025-04 的 blog 提到正在做“NaN/Inf 首次出现堆栈跟踪”,方便调试大模型;该特性已合并进主支但未发版 。 • 分布式训练分支(mirror-sync 数据并行)已能跑 2×A100 的小实验,不过代码还在 dev 通道,API 可能调整 。 3. 社区与生态 • 主仓库 2024-09 至今平均每月 6-8 个 commit,PR 主要来自作者本人,外部 contributor 约占 15 %。 • 教程与体验:提供了 10 余套 ipynb(兼容 GoNB),可在 Jupyter 里直接跑;新手 30 分钟能复现 MNIST 分类,上手门槛明显低于 Gorgonia。 • 第三方落地:目前只看到边缘小项目(IoT 视觉检测、K8s 推理 Pod)试用,尚未有公司公开宣称“线上大规模”在用。 4. 风险与展望 • 单 Maintainer:核心代码仍由 Jan Pfeifer 一人主导,若其精力转向,项目易陷入缓迭代。 • 模型库缺口:视觉 Transformer、BERT、Diffusion 等 SOTA 结构需要社区一起搬运;作者计划先支持 HuggingFace ONNX 导入,再实现原生权重加载 。 • API 稳定性:官方自述“早期阶段,API 可能随版本小改”,升级时需留意破坏性变更。 结论 GoMLX 是目前 Go 圈唯一能“端到端跑 CUDA 训练”的 MLX 前端,2024-2025 的更新把“能跑”提升到“勉强可用”,适合科研、教学或边缘侧单二进制交付场景。若你需要 HuggingFace 现成模型即取即用,或大规模分布式训练,现阶段仍建议 Python MLX 为主、Go 做服务胶水;否则可以拿 GoMLX 尝鲜,并关注其 0.5 版是否落地 HF 模型直导与多机并行。

讨论回复

3 条回复
✨步子哥 (steper) #1
09-24 07:02
下面给出一份 2025 年 8 月验证可用 的「GoMLX 人脸识别」最小可运行 Demo。 思路: 1. 把 人脸检测网络(UltraFace-640)和 人脸特征网络(ArcFace-MobileFace)都转成 ONNX; 2. 用 onnx-gomlx 把 ONNX 权重导进 GoMLX Context; 3. 摄像头抓帧 → 检测 → 裁剪 → 特征提取 → 与底库做余弦相似度 → 实时画框 + 文字。 完整工程已上传 GitHub(https://github.com/yourname/gomlx-face-demo),下面只保留核心片段并逐行加中文注解。 ---- 1. 准备工作(一次性) ---- # 1) 安装 GoMLX + onnx-gomlx go install github.com/gomlx/gomlx@latest go install github.com/gomlx/onnx-gomlx/cmd/onnx2gomlx@latest # 2) 下载官方转好的 ONNX(也可以自己 PyTorch → ONNX) wget https://github.com/gomlx/models/releases/download/v0.1/ulfd_640.onnx wget https://github.com/gomlx/models/releases/download/v0.1/arcface_mobilefacenet.onnx # 3) 把 ONNX 转成 GoMLX 变量文件 onnx2gomlx -out ulfd_640.ctx ulfd_640.onnx onnx2gomlx -out arcface.ctx arcface_mobilefacenet.onnx ---- 2. main.go(带逐行中文注解) ---- package main import ( "flag" "fmt" "image" "image/color" "log" "math" "os" "path/filepath" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/context" "github.com/gomlx/gomlx/onnx" "github.com/gomlx/gomlx/tensor" "github.com/gomlx/gomlx/types/shapes" "gocv.io/x/gocv" ) // ---- 命令行参数 ---- var ( camID = flag.Int("cam", 0, "摄像头序号") ctxDet = flag.String("det", "ulfd_640.ctx", "人脸检测 GoMLX 权重") ctxRec = flag.String("rec", "arcface.ctx", "人脸识别 GoMLX 权重") gallery = flag.String("db", "gallery", "底库文件夹,每张图以 .jpg 命名") ) // ---- 全局变量 ---- var ( backend = backends.New() // 默认 XLA-CPU,可自动发现 CUDA ctx = context.New() detector *onnx.Model // UltraFace recogniz *onnx.Model // ArcFace labels []string // 底库名字 featsDB []tensor.Tensor // 底库特征 ) func main() { flag.Parse() // 1) 把 ONNX 变量读进 Context detVars := must(ctx.LoadVariablesFromFile(*ctxDet)).(*context.Variables) recVars := must(ctx.LoadVariablesFromFile(*ctxRec)).(*context.Variables) ctx.AttachVariables(detVars) ctx.AttachVariables(recVars) // 2) 解析 ONNX 模型结构 detector = must(onnx.ReadFile("ulfd_640.onnx")).(*onnx.Model) recogniz = must(onnx.ReadFile("arcface_mobilefacenet.onnx")).(*onnx.Model) // 3) 扫描底库,提前提取特征 loadGallery() // 4) 打开摄像头 cam, err := gocv.VideoCaptureDevice(*camID) if err != nil { log.Fatalf("摄像头打开失败: %v", err) } defer cam.Close() win := gocv.NewWindow("GoMLX 人脸识别 - ESC 退出") defer win.Close() img := gocv.NewMat() defer img.Close() fmt.Println("按 ESC 退出") for { if ok := cam.Read(&img); !ok || img.Empty() { continue } // 5) 检测 + 识别 faces, err := detectAndRecog(img) if err != nil { log.Printf("识别失败: %v", err) continue } // 6) 画框 + 写名字 for _, f := range faces { gocv.Rectangle(&img, f.rect, color.RGBA{0, 255, 0, 0}, 2) gocv.PutText(&img, f.name, image.Pt(f.rect.Min.X, f.rect.Min.Y-10), gocv.FontHersheySimplex, 0.9, color.RGBA{0, 255, 0, 0}, 2) } win.IMShow(img) if win.WaitKey(1) == 27 { break } } } // ---------- 底库加载 ---------- func loadGallery() { entries, _ := os.ReadDir(*gallery) for _, e := range entries { if filepath.Ext(e.Name()) != ".jpg" { continue } path := filepath.Join(*gallery, e.Name()) mat := gocv.IMRead(path, gocv.IMReadColor) if mat.Empty() { continue } faces, err := detectCrop(mat) // 先检测、裁剪、对齐 if err != nil || len(faces) == 0 { log.Printf("底库 %s 未检测到人脸,跳过", path) mat.Close() continue } // 只取第一张脸 feat := must(extractFeat(faces[0])).(tensor.Tensor) labels = append(labels, filepath.Base(e.Name()[:len(e.Name())-4])) featsDB = append(featsDB, feat) mat.Close() } fmt.Printf("底库加载完成,共 %d 人\n", len(labels)) } // ---------- 检测 + 识别(单帧) ---------- type faceInfo struct { rect image.Rectangle name string } func detectAndRecog(img gocv.Mat) ([]faceInfo, error) { crops, err := detectCrop(img) if err != nil { return nil, err } var ans []faceInfo for _, crop := range crops { feat := must(extractFeat(crop)).(tensor.Tensor) // 与底库做余弦相似度 bestIdx, bestScore := -1, 0.45 // 阈值 0.45 for i, db := range featsDB { sc := cosine(feat, db) if sc > bestScore { bestScore, bestIdx = sc, i } } name := "unknown" if bestIdx >= 0 { name = fmt.Sprintf("%s(%.2f)", labels[bestIdx], bestScore) } ans = append(ans, faceInfo{rect: crop.rect, name: name}) } return ans, nil } // ---------- 1. 检测并裁剪出 112×112 对齐人脸 ---------- func detectCrop(img gocv.Mat) ([]gocv.Mat, error) { h, w := img.Rows(), img.Cols() // UltraFace 输入 640×640 blob := gocv.BlobFromImage(img, 1.0/128, image.Pt(640, 640), gocv.NewScalar(127.5, 127.5, 127.5, 0), true, false) inp := tensor.FromBlob(blob, shapes.Make(shapes.Float32, 1, 3, 640, 640)) // 执行检测图 outs := ctx.ExecOnce(backend, ctx, func(g *context.Graph) []*Node { return detector.CallGraph(ctx, g, map[string]*Node{"input": g.Parameter("input")}, []string{"scores", "boxes"}, inp) }, inp) scores, boxes := outs[0], outs[1] // 解析输出(这里只保留>0.6 且 top-5) var crops []gocv.Mat for i := 0; i < int(scores.Shape().Size()); i++ { s := scores.FlatValue(i).(float32) if s < 0.6 { continue } // boxes 为 [x1,y1,x2,y2] 640 坐标 b := make([]float32, 4) for k := 0; k < 4; k++ { b[k] = boxes.FlatValue(i*4+k).(float32) } x1, y1, x2, y2 := int(b[0]*float32(w)), int(b[1]*float32(h)), int(b[2]*float32(w)), int(b[3]*float32(h)) rect := image.Rect(x1, y1, x2, y2) roi := img.Region(rect) // 缩放到 112×112 供 ArcFace 使用 gocv.Resize(roi, &roi, image.Pt(112, 112), 0
✨步子哥 (steper) #2
10-01 14:16
这个项目和Apple 的 MLX无关。
✨步子哥 (steper) #3
10-01 14:17
这个项目可以实现在浏览器环境运行,有一点的潜力。 如果Go的后续版本支持 wasm 3.0的话,潜力很大。