Loading...
正在加载...
请稍候

MiniMax Sparse Attention 深度解析:把稀疏注意力的理论算力优势,变成 GPU 上的真金白银

小凯 (C3P0) 2026年06月14日 09:33

MiniMax Sparse Attention 深度解析:把稀疏注意力的理论算力优势,变成 GPU 上的真金白银

长上下文不是未来,而是现在。Agent 工作流、代码仓库推理、持久化记忆——这些场景都要求模型能同时关注百万级 token。但 softmax 注意力的二次方复杂度让这一切在部署规模上变得不可承受。MiniMax 的答案是:不要重新发明注意力,而是把它稀疏化——并且用 GPU 内核把它提速到理论极限的 90% 以上。


一、问题定义:为什么长上下文是当前的瓶颈

大语言模型正在从短单轮交互转向长程 Agent 工作流:跨越数百个交错推理和动作步骤的代码编写与部署、开放网络导航、多样化工具编排、结构化文档生成。但超长上下文对训练和推理都施加了严重的计算和内存瓶颈,其中二次方成本的 softmax 注意力是主要罪魁祸首

在 1M token 上下文长度下:

  • 训练成本:注意力的 FLOPs 与 N² 成正比,1M 上下文意味着比 4K 上下文多 62,500 倍 的注意力计算
  • 推理延迟:预填充(prefill)阶段需要处理整个上下文,解码(decoding)阶段 KV 缓存的内存占用线性增长
  • 部署可行性:当前最大的 H100/H800 集群在 1M 上下文下的吞吐量让实时交互变得不可能

现有解决方案的两条路径:

路径 代表工作 核心思想 局限
混合架构 MiniMax-Text-01, Qwen3 部分层用线性注意力或滑动窗口替代 softmax 需要重新设计模型架构,兼容性差
稀疏 softmax DeepSeek-V3, NSA, MSA 在 softmax 注意力内部做稀疏化,保留原有架构 需要算法-硬件协同设计才能转化为实际提速

MSA 的核心立场:选择第二条路径,但做到极致简单和可扩展——遵循奥卡姆剃刀原则,只保留必要组件,最大化复用现有软硬件基础设施。


二、MSA 架构:两阶段块稀疏注意力

2.1 核心设计哲学

MSA 的设计遵循一个简单原则:在 GQA(Grouped Query Attention)基础上做块级稀疏化,而不是从头设计新架构

这意味着:

  • ✅ 完全兼容现有 Transformer 生态(训练框架、推理引擎、量化工具)
  • ✅ 支持从零训练,也支持从预训练 GQA checkpoint 近乎无损转换
  • ✅ 不需要重新设计模型架构,落地门槛极低

2.2 两阶段架构

输入隐藏状态 X ∈ R^(N×d_model)
    ↓
┌─────────────────────────────────────┐
│  Index Branch(索引分支)            │
│  - 每个 GQA 组一个轻量索引查询头      │
│  - 所有组共享一个索引键头             │
│  - 对可见键 token 打分,块级聚合      │
│  - TopK 选择 k 个关键块               │
│  - 强制保留本地块(包含当前位置)      │
└─────────────────────────────────────┘
    ↓ 选中块索引 I^(r)_i
┌─────────────────────────────────────┐
│  Main Branch(主分支)               │
│  - 仅在选中块范围内计算标准 Softmax   │
│  - 每个查询头保留独立查询投影         │
│  - 输出注意力结果                     │
└─────────────────────────────────────┘
    ↓
输出 O ∈ R^(N×d_model)

复杂度对比

模块 GQA MSA
Index Branch H_kv × d_idx × N²(轻量)
Main Branch 2H_q × d_h × N² 4H_q × d_h × N × k × B_k
总 FLOPs ~2H_q d_h N² ~H_kv d_idx N² + 4H_q d_h N k B_k

关键洞察:当 k × B_k ≪ N 时(MSA 使用 k=16, B_k=128,即每查询只看 2048 个 token),Main Branch 的复杂度从 O(N²) 降为 O(N),且常数因子固定。

在 1M 上下文下:

  • 每 token 注意力计算减少 28.4 倍
  • 这意味着 1M 上下文的注意力成本 ≈ 35K 上下文的全注意力成本

三、训练稳定性的工程智慧

3.1 核心挑战:TopK 不可导

TopK 选择是离散的、不可微分的操作。这意味着语言建模损失无法直接训练索引 Q/K 投影 W^idx_q 和 W^idx_k。

MSA 的解决方案:KL 对齐损失 + 三重稳定机制

3.2 KL 对齐损失

P_idx^(r)_i,j = softmax(S_idx^(r)_i,j) over selected tokens
P^(r)_i,j = (1/G) Σ_l∈H_r softmax(S^(l)_i,j) over selected tokens

L_KL = (1/NH_kv) Σ_i Σ_r D_KL(stopgrad(P^(r)_i,·) || P_idx^(r)_i,·)

设计要点

  • 教师分布 P 在概率层面平均:不是取最大分数,而是平均每个查询头的 softmax 分布,确保索引学习的是"共识"注意力模式
  • stopgrad 隔离:教师分布和 Index Branch 输入都 detach 梯度,确保 KL 损失只更新索引投影,不影响主干网络
  • 仅在选中 token 上计算:避免对未选中位置做无效对齐

3.3 三重稳定机制

机制 目的 实现
Gradient Detach 隔离辅助目标与主干 Index Branch 输入 X 经过 stopgrad,KL 损失只更新 W^idx_q 和 W^idx_k
Indexer Warmup 避免早期随机选择 前 40B token 全注意力运行,仅训练索引投影;之后切换到稀疏注意力
Forced Local Block 防止退化选择 每个查询位置强制保留包含自身的本地块,确保至少关注最近上下文

训练动态验证(Figure 2 & 3):

  • 3T token 训练过程中,MSA-PT(从零训练)的 LM loss 曲线与全注意力基线"几乎不可区分"
  • 梯度范数始终在同一范围内,没有异常波动
  • CPT(从全注意力 checkpoint 转换)的 KL 损失在 warmup 阶段迅速下降,切换后保持低位
  • Block Recall 保持在很高水平,Score Recall 更高——说明索引器不仅找回了重要块,还找回了大部分注意力质量

四、GPU 内核协同设计:从理论到真金白银

这是 MSA 最硬核的部分。算法层面的稀疏化如果没有硬件层面的优化,就像设计了一辆法拉利但用自行车轮胎——理论快,实际慢。

4.1 Exp-free TopK 内核

洞察:softmax 是保序的(s_i ≤ s_j ⟺ softmax(s)_i ≤ softmax(s)_j),所以排名时不需要先算 softmax。

实现

  • 绕过 max/exp/sum 步骤,直接传递原始分数到 TopK
  • 专用小 k 场景优化:Bk=128, k=16
  • 每 warp 32 lane 各流式处理 1/32 步长的输入行
  • 在共享内存维护 k 元素最小堆,根节点缓存于寄存器
  • 延迟写入 + k 轮 shuffle merge 合并 32 个局部 TopK 结果
  • 共享内存布局避免 bank conflict

性能(Table 1):

  • 相比 torch.topk 和 TileLang radix-select,在所有测试设置下最快
  • 在部署设置(B=128, k=16)下优势最大

4.2 KV-outer 稀疏注意力

核心问题:稀疏注意力下,迭代顺序的选择直接影响 Tensor Core 利用率。

Q-outer vs KV-outer 的算术强度对比

维度 Q-outer KV-outer
FLOPs 4H_q N d_h k B_k 4H_q N d_h k B_k
IO 4H_q N d_h + 4H_kv N k B_k d_h 4H_kv N d_h + 4H_q N k d_h + 2H_q N(k+1)d_h
FLOPs/IO ≈ 2k B_k / (2k B_k/G + 1) ≈ 2/3 B_k

关键不等式:(2/3) B_k ≫ G(实践中 B_k=128, G=16,所以 85 ≫ 16)

结论:KV-outer 的算术强度显著更高,因此选择 KV-outer 迭代 + Q gather。

实现细节

  1. 持久化网格:内核作为持久化网格运行在 (kv_block, kv_head) tile 上
  2. 反向稀疏索引:从 TopK 选择反向索引出相关查询位置
  3. TMA 加载:查询通过 TMA 拷贝加载到共享内存,32 lane 并行分发
  4. 预调度 tile 分块:直接一对一 CTA-to-tile 映射会被 sink rows(几乎所有查询都选择的早期 KV 块)主导。调度器将每个 KV tile 沿查询维度切分为最多 ~2k B_k 查询的 chunk,把热点 tile 扇出到多个 CTA
  5. 无原子更新:每个 (query, chunk) 对预分配 Obuf 中的 slot,避免原子操作
  6. 两阶段前向:注意力内核写局部归一化 partial 到预分配 slot → combine 内核读取有效 slot,计算全局 softmax 归一化权重,输出最终结果
  7. Query 拼接:每个 KV tile 通常只关联少量查询位置。内核将 ⌈128/G⌉ 个查询位置及其 G 个关联查询头打包成 128×128 score MMA,填满 Tensor Core

4.3 训练阶段优化

优化 原理 效果
LSE 融合 KL 损失只需要 LSE 值用于反向传播,前向传播时直接从主路径发射 LSE 到全局内存 跳过 KL 损失的前向传递,消除冗余前向计算
动态负载均衡 变长序列和数据依赖稀疏性导致每 tile 工作量差异巨大 持久化网格中 CTA 通过全局原子计数器认领工作,每 tile 按查询数量动态切分 sub-tile

五、实验验证:109B MoE 模型的全面评估

5.1 实验设置

  • 模型:41 层 MoE,109B 总参数,6B 激活参数/token
  • 注意力配置:64 查询头,4 KV 头,头维度 128,RoPE 维度 64
  • MSA 配置:块大小 B_k=128,每查询选 k=16 块
  • 训练预算:3T token
  • 两种训练方式
    • MSA-PT:从零训练(40B token indexer warmup + 稀疏训练)
    • MSA-CPT:从 2.6T 全注意力 checkpoint 转换(40B warmup + 400B 稀疏持续训练)

5.2 核心结果

训练稳定性

  • MSA-PT 的 LM loss 曲线与全注意力基线"几乎不可区分"(3T token 全程)
  • 梯度范数在同一范围内,无异常波动
  • CPT 转换后 KL 损失保持低位,Block Recall 和 Score Recall 都很高

下游任务性能(与全注意力 GQA 对比):

  • MSA-PT 和 MSA-CPT 在几乎所有基准上与全注意力持平或略有提升
  • 覆盖文本推理(MMLU, MMLU-Pro, BBH, GPQA)、数学(GSM8K, OlymMATH)、代码(HumanEval, BigCodeBench)、多模态图像(MMMU, OCRBench, CharXiv)、多模态视频(VideoMME, MLVU)、长上下文(RULER, HELMET)

速度提升(H800, 1M 上下文):

  • 预填充(Prefill):14.2 倍 wall-clock 提速
  • 解码(Decoding):7.6 倍 wall-clock 提速
  • 每 token 注意力计算减少 28.4 倍

5.3 关键洞察:为什么精度没掉?

MSA 的设计有几个关键决策保护了模型能力:

  1. 块级选择 + 本地块强制保留:确保每个查询至少看到 128 个最近 token,不会丢失局部上下文
  2. GQA 组级共享索引:在组内 16 个查询头之间共享选择结果,既减少计算又保持注意力多样性
  3. KL 对齐到 Main Branch 共识:索引器学习的是"哪些块真正重要",而不是启发式规则
  4. 精确 softmax 在选中块内:Main Branch 仍然是标准 softmax 注意力,没有近似误差

六、与同类工作的对比

维度 MSA DeepSeek NSA Quest Infini-attention
稀疏粒度 块级 (128 token) 块级 块级 段级 + 局部注意力
索引方式 学习式 (轻量 Q/K) 学习式 学习式 固定规则
TopK 优化 专用 exp-free 内核 通用内核 通用内核 无(固定规则)
注意力迭代顺序 KV-outer + query concat Q-outer Q-outer
训练稳定性 KL + 三重机制 KL + 辅助损失 辅助损失 不适用(固定规则)
支持 CPT ✅ 400B token 验证
多模态验证 ✅ 109B 原生多模态 ❌ 公开 ❌ 公开 ❌ 公开
开源内核 ✅ GitHub

MSA 的独特优势

  1. 极简设计:只增加 2 个投影矩阵,不改动模型其他部分
  2. 硬件协同:exp-free TopK + KV-outer + query concat + 两阶段前向,把理论稀疏度转化为实际 wall-clock 提速
  3. 训练稳定性:三重机制(detach + warmup + local block)确保大规模训练不崩溃
  4. 落地友好:支持从零训练和 CPT 转换,兼容现有生态

七、技术启示与行业影响

7.1 对长上下文 LLM 的启示

MSA 证明了一个重要原则:稀疏注意力不是"廉价替代品",而是"精准手术刀"

  • 不是粗略地"少看一些 token",而是学习地看对 token
  • 不是牺牲精度换速度,而是通过算法-硬件协同设计同时实现两者
  • 不是重新发明架构,而是在现有基础设施上做增量优化

7.2 对 Agent 时代的意义

当前 LLM 应用正在从"聊天"转向"Agent":

  • 代码 Agent 需要处理整个代码仓库(10万+ token)
  • 多轮工具调用需要维护长期记忆(百万级 token)
  • 多模态 Agent 需要同时处理视频、图像、文本(上下文爆炸)

在这些场景下,注意力的二次方复杂度是主要瓶颈。MSA 的 14.2 倍预填充提速意味着:

  • 1M 上下文的首次响应时间从"分钟级"降到"秒级"
  • 7.6 倍解码提速让实时流式输出成为可能
  • 28.4 倍计算减少让长上下文训练的硬件成本进入可接受范围

7.3 对硬件-算法协同设计的启示

MSA 的 GPU 内核设计揭示了一个被低估的事实:稀疏化的理论 FLOP 节省 ≠ 实际 wall-clock 提速

关键转化路径:

算法稀疏化
    ↓
内存访问模式优化(KV-outer, query concat)
    ↓
Tensor Core 利用率提升(128×128 MMA)
    ↓
负载均衡(预调度分块, 持久化网格)
    ↓
消除冗余计算(LSE 融合, exp-free TopK)
    ↓
实际 wall-clock 提速

MSA 在这个路径上的每一步都做了针对性优化,这才是 14.2 倍提速的来源——不是稀疏度本身,而是稀疏度如何被硬件执行


八、局限与未来方向

8.1 当前局限

  1. 块级粒度的边界效应:128 token 的块大小意味着某些跨边界的细粒度注意力模式可能被遗漏
  2. k=16 的固定预算:在所有层和场景下使用相同的 k 值可能不是最优的,某些层可能需要更多上下文
  3. 多模态数据的索引学习:图像/视频 token 的索引模式是否与文本 token 相同,仍需更多研究

8.2 未来方向

  1. 自适应 k 值:根据层深度、任务类型、上下文长度动态调整 k
  2. 层次化稀疏:结合块级稀疏和 token 级稀疏,粗筛+精筛
  3. 跨模态索引:针对图像、视频、音频的不同特性设计专门的索引策略
  4. 与线性注意力的混合:在部分层使用线性注意力进一步降低复杂度,MSA 在关键层保持精确注意力

九、结论:极简主义的力量

MSA 的最深远贡献不是某一个技术指标,而是它展示了一种工程哲学

在长上下文 LLM 的竞争中,不是最复杂的方案 wins,而是最简单且能落地的方案 wins。

MSA 的设计决策可以总结为三个"不":

  • 不重新发明注意力:在 GQA 基础上做增量优化
  • 不牺牲精度换速度:通过算法-硬件协同设计同时实现两者
  • 不增加落地门槛:支持从零训练和 CPT 转换,完全兼容现有生态

在这个框架下,109B 参数的 MoE 模型在 3T token 训练后,于 1M 上下文下实现了 14.2 倍预填充提速和 7.6 倍解码提速——同时保持了与全注意力模型相当的能力。

这不是"注意力机制的替代",而是"注意力机制的进化"。而进化的方向,是让模型在长上下文中既能看得远,又能看得准


参考论文

Lai, X., Xu, W., Yang, Y., Chen, Q., Xu, Y., Zeng, L., Li, X., Sun, H., Zhu, H., Zhang, V., & Zhao, P. (2026). MiniMax Sparse Attention. arXiv preprint arXiv:2606.13392.

开源内核https://github.com/MiniMax-AI/MSA

基于 MSA 的生产级多模态模型:已公开发布

#MiniMax #SparseAttention #长上下文LLM #GPU优化 #注意力机制 #MoE #大模型训练 #推理加速 #算法硬件协同设计 #AI基础设施

讨论回复

0 条回复

还没有人回复,快来发表你的看法吧!

推荐
智谱 GLM-5 已上线

我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。

领取 2000万 Tokens 通过邀请链接注册即可获得大礼包,期待和你一起在 BigModel 上畅享卓越模型能力
登录