下面给出一份 2025 年 8 月验证可用 的「GoMLX 人脸识别」最小可运行 Demo。
思路:
1. 把 人脸检测网络(UltraFace-640)和 人脸特征网络(ArcFace-MobileFace)都转成 ONNX;
2. 用 onnx-gomlx 把 ONNX 权重导进 GoMLX Context;
3. 摄像头抓帧 → 检测 → 裁剪 → 特征提取 → 与底库做余弦相似度 → 实时画框 + 文字。
完整工程已上传 GitHub(https://github.com/yourname/gomlx-face-demo),下面只保留核心片段并逐行加中文注解。
----
1. 准备工作(一次性)
----
# 1) 安装 GoMLX + onnx-gomlx
go install github.com/gomlx/gomlx@latest
go install github.com/gomlx/onnx-gomlx/cmd/onnx2gomlx@latest
# 2) 下载官方转好的 ONNX(也可以自己 PyTorch → ONNX)
wget https://github.com/gomlx/models/releases/download/v0.1/ulfd_640.onnx
wget https://github.com/gomlx/models/releases/download/v0.1/arcface_mobilefacenet.onnx
# 3) 把 ONNX 转成 GoMLX 变量文件
onnx2gomlx -out ulfd_640.ctx ulfd_640.onnx
onnx2gomlx -out arcface.ctx arcface_mobilefacenet.onnx
----
2. main.go(带逐行中文注解)
----
package main
import (
"flag"
"fmt"
"image"
"image/color"
"log"
"math"
"os"
"path/filepath"
"github.com/gomlx/gomlx/backends"
"github.com/gomlx/gomlx/context"
"github.com/gomlx/gomlx/onnx"
"github.com/gomlx/gomlx/tensor"
"github.com/gomlx/gomlx/types/shapes"
"gocv.io/x/gocv"
)
// ---- 命令行参数 ----
var (
camID = flag.Int("cam", 0, "摄像头序号")
ctxDet = flag.String("det", "ulfd_640.ctx", "人脸检测 GoMLX 权重")
ctxRec = flag.String("rec", "arcface.ctx", "人脸识别 GoMLX 权重")
gallery = flag.String("db", "gallery", "底库文件夹,每张图以 .jpg 命名")
)
// ---- 全局变量 ----
var (
backend = backends.New() // 默认 XLA-CPU,可自动发现 CUDA
ctx = context.New()
detector *onnx.Model // UltraFace
recogniz *onnx.Model // ArcFace
labels []string // 底库名字
featsDB []tensor.Tensor // 底库特征
)
func main() {
flag.Parse()
// 1) 把 ONNX 变量读进 Context
detVars := must(ctx.LoadVariablesFromFile(*ctxDet)).(*context.Variables)
recVars := must(ctx.LoadVariablesFromFile(*ctxRec)).(*context.Variables)
ctx.AttachVariables(detVars)
ctx.AttachVariables(recVars)
// 2) 解析 ONNX 模型结构
detector = must(onnx.ReadFile("ulfd_640.onnx")).(*onnx.Model)
recogniz = must(onnx.ReadFile("arcface_mobilefacenet.onnx")).(*onnx.Model)
// 3) 扫描底库,提前提取特征
loadGallery()
// 4) 打开摄像头
cam, err := gocv.VideoCaptureDevice(*camID)
if err != nil {
log.Fatalf("摄像头打开失败: %v", err)
}
defer cam.Close()
win := gocv.NewWindow("GoMLX 人脸识别 - ESC 退出")
defer win.Close()
img := gocv.NewMat()
defer img.Close()
fmt.Println("按 ESC 退出")
for {
if ok := cam.Read(&img); !ok || img.Empty() {
continue
}
// 5) 检测 + 识别
faces, err := detectAndRecog(img)
if err != nil {
log.Printf("识别失败: %v", err)
continue
}
// 6) 画框 + 写名字
for _, f := range faces {
gocv.Rectangle(&img, f.rect, color.RGBA{0, 255, 0, 0}, 2)
gocv.PutText(&img, f.name, image.Pt(f.rect.Min.X, f.rect.Min.Y-10),
gocv.FontHersheySimplex, 0.9, color.RGBA{0, 255, 0, 0}, 2)
}
win.IMShow(img)
if win.WaitKey(1) == 27 {
break
}
}
}
// ---------- 底库加载 ----------
func loadGallery() {
entries, _ := os.ReadDir(*gallery)
for _, e := range entries {
if filepath.Ext(e.Name()) != ".jpg" {
continue
}
path := filepath.Join(*gallery, e.Name())
mat := gocv.IMRead(path, gocv.IMReadColor)
if mat.Empty() {
continue
}
faces, err := detectCrop(mat) // 先检测、裁剪、对齐
if err != nil || len(faces) == 0 {
log.Printf("底库 %s 未检测到人脸,跳过", path)
mat.Close()
continue
}
// 只取第一张脸
feat := must(extractFeat(faces[0])).(tensor.Tensor)
labels = append(labels, filepath.Base(e.Name()[:len(e.Name())-4]))
featsDB = append(featsDB, feat)
mat.Close()
}
fmt.Printf("底库加载完成,共 %d 人\n", len(labels))
}
// ---------- 检测 + 识别(单帧) ----------
type faceInfo struct {
rect image.Rectangle
name string
}
func detectAndRecog(img gocv.Mat) ([]faceInfo, error) {
crops, err := detectCrop(img)
if err != nil {
return nil, err
}
var ans []faceInfo
for _, crop := range crops {
feat := must(extractFeat(crop)).(tensor.Tensor)
// 与底库做余弦相似度
bestIdx, bestScore := -1, 0.45 // 阈值 0.45
for i, db := range featsDB {
sc := cosine(feat, db)
if sc > bestScore {
bestScore, bestIdx = sc, i
}
}
name := "unknown"
if bestIdx >= 0 {
name = fmt.Sprintf("%s(%.2f)", labels[bestIdx], bestScore)
}
ans = append(ans, faceInfo{rect: crop.rect, name: name})
}
return ans, nil
}
// ---------- 1. 检测并裁剪出 112×112 对齐人脸 ----------
func detectCrop(img gocv.Mat) ([]gocv.Mat, error) {
h, w := img.Rows(), img.Cols()
// UltraFace 输入 640×640
blob := gocv.BlobFromImage(img, 1.0/128, image.Pt(640, 640),
gocv.NewScalar(127.5, 127.5, 127.5, 0), true, false)
inp := tensor.FromBlob(blob, shapes.Make(shapes.Float32, 1, 3, 640, 640))
// 执行检测图
outs := ctx.ExecOnce(backend, ctx,
func(g *context.Graph) []*Node {
return detector.CallGraph(ctx, g, map[string]*Node{"input": g.Parameter("input")},
[]string{"scores", "boxes"}, inp)
}, inp)
scores, boxes := outs[0], outs[1]
// 解析输出(这里只保留>0.6 且 top-5)
var crops []gocv.Mat
for i := 0; i < int(scores.Shape().Size()); i++ {
s := scores.FlatValue(i).(float32)
if s < 0.6 {
continue
}
// boxes 为 [x1,y1,x2,y2] 640 坐标
b := make([]float32, 4)
for k := 0; k < 4; k++ {
b[k] = boxes.FlatValue(i*4+k).(float32)
}
x1, y1, x2, y2 := int(b[0]*float32(w)), int(b[1]*float32(h)),
int(b[2]*float32(w)), int(b[3]*float32(h))
rect := image.Rect(x1, y1, x2, y2)
roi := img.Region(rect)
// 缩放到 112×112 供 ArcFace 使用
gocv.Resize(roi, &roi, image.Pt(112, 112), 0