.jpg 命名")
)
// ---- Global state ----
// Package-level state shared by main, loadGallery and detectAndRecog.
// NOTE(review): `context` here is presumably the GoMLX context package, not the
// standard library's — confirm against the import block.
var (
backend = backends.New() // defaults to XLA-CPU; can auto-discover CUDA
ctx = context.New() // holds the model variables attached in main
detector onnx.Model // UltraFace face detector, parsed in main
recogniz onnx.Model // ArcFace face recognizer, parsed in main
labels []string // gallery names; labels[i] belongs to featsDB[i] (appended together in loadGallery)
featsDB []tensor.Tensor // gallery features, compared by cosine similarity in detectAndRecog
)
func main() {
flag.Parse()
// 1) 把 ONNX 变量读进 Context
detVars := must(ctx.LoadVariablesFromFile(ctxDet)).(context.Variables)
recVars := must(ctx.LoadVariablesFromFile(ctxRec)).(context.Variables)
ctx.AttachVariables(detVars)
ctx.AttachVariables(recVars)
// 2) 解析 ONNX 模型结构
detector = must(onnx.ReadFile("ulfd640.onnx")).(onnx.Model)
recogniz = must(onnx.ReadFile("arcfacemobilefacenet.onnx")).(onnx.Model)
// 3) 扫描底库,提前提取特征
loadGallery()
// 4) 打开摄像头
cam, err := gocv.VideoCaptureDevice(camID)
if err != nil {
log.Fatalf("摄像头打开失败: %v", err)
}
defer cam.Close()
win := gocv.NewWindow("GoMLX 人脸识别 - ESC 退出")
defer win.Close()
img := gocv.NewMat()
defer img.Close()
fmt.Println("按 ESC 退出")
for {
if ok := cam.Read(&img); !ok || img.Empty() {
continue
}
// 5) 检测 + 识别
faces, err := detectAndRecog(img)
if err != nil {
log.Printf("识别失败: %v", err)
continue
}
// 6) 画框 + 写名字
for , f := range faces {
gocv.Rectangle(&img, f.rect, color.RGBA{0, 255, 0, 0}, 2)
gocv.PutText(&img, f.name, image.Pt(f.rect.Min.X, f.rect.Min.Y-10),
gocv.FontHersheySimplex, 0.9, color.RGBA{0, 255, 0, 0}, 2)
}
win.IMShow(img)
if win.WaitKey(1) == 27 {
break
}
}
}
// ---------- 底库加载 ----------
func loadGallery() {
entries, := os.ReadDir(gallery)
for , e := range entries {
if filepath.Ext(e.Name()) != ".jpg" {
continue
}
path := filepath.Join(gallery, e.Name())
mat := gocv.IMRead(path, gocv.IMReadColor)
if mat.Empty() {
continue
}
faces, err := detectCrop(mat) // 先检测、裁剪、对齐
if err != nil || len(faces) == 0 {
log.Printf("底库 %s 未检测到人脸,跳过", path)
mat.Close()
continue
}
// 只取第一张脸
feat := must(extractFeat(faces[0])).(tensor.Tensor)
labels = append(labels, filepath.Base(e.Name()[:len(e.Name())-4]))
featsDB = append(featsDB, feat)
mat.Close()
}
fmt.Printf("底库加载完成,共 %d 人\n", len(labels))
}
// ---------- Detection + recognition (single frame) ----------
// faceInfo is one face found in a frame by detectAndRecog.
type faceInfo struct {
rect image.Rectangle // box drawn around the face in main (name text is placed just above it)
name string // "label(score)" for a gallery match, or "unknown"
}
func detectAndRecog(img gocv.Mat) ([]faceInfo, error) {
crops, err := detectCrop(img)
if err != nil {
return nil, err
}
var ans []faceInfo
for , crop := range crops {
feat := must(extractFeat(crop)).(tensor.Tensor)
// 与底库做余弦相似度
bestIdx, bestScore := -1, 0.45 // 阈值 0.45
for i, db := range featsDB {
sc := cosine(feat, db)
if sc > bestScore {
bestScore, bestIdx = sc, i
}
}
name := "unknown"
if bestIdx >= 0 {
name = fmt.Sprintf("%s(%.2f)", labels[bestIdx], bestScore)
}
ans = append(ans, faceInfo{rect: crop.rect, name: name})
}
return ans, nil
}
// ---------- 1. 检测并裁剪出 112×112 对齐人脸 ----------
func detectCrop(img gocv.Mat) ([]gocv.Mat, error) {
h, w := img.Rows(), img.Cols()
// UltraFace 输入 640×640
blob := gocv.BlobFromImage(img, 1.0/128, image.Pt(640, 640),
gocv.NewScalar(127.5, 127.5, 127.5, 0), true, false)
inp := tensor.FromBlob(blob, shapes.Make(shapes.Float32, 1, 3, 640, 640))
// 执行检测图
outs := ctx.ExecOnce(backend, ctx,
func(g context.Graph) []Node {
return detector.CallGraph(ctx, g, map[string]Node{"input": g.Parameter("input")},
[]string{"scores", "boxes"}, inp)
}, inp)
scores, boxes := outs[0], outs[1]
// 解析输出(这里只保留>0.6 且 top-5)
var crops []gocv.Mat
for i := 0; i < int(scores.Shape().Size()); i++ {
s := scores.FlatValue(i).(float32)
if s < 0.6 {
continue
}
// boxes 为 [x1,y1,x2,y2] 640 坐标
b := make([]float32, 4)
for k := 0; k < 4; k++ {
b[k] = boxes.FlatValue(i4+k).(float32)
}
x1, y1, x2, y2 := int(b[0]float32(w)), int(b[1]float32(h)),
int(b[2]float32(w)), int(b[3]*float32(h))
rect := image.Rect(x1, y1, x2, y2)
roi := img.Region(rect)
// 缩放到 112×112 供 ArcFace 使用
gocv.Resize(roi, &roi, image.Pt(112, 112), 0