MiniMax Sparse Attention 深度解析:把稀疏注意力的理论算力优势,变成 GPU 上的真金白银
MiniMax Sparse Attention 深度解析:把稀疏注意力的理论算力优势,变成 GPU 上的真金白银
> 长上下文不是未来,而是现在。Agent 工作流、代码仓库推理、持久化记忆——这些场景都要求模型能同时关注百万级 token。但 softmax 注意力的二次方复杂度让这一切在部署规模上变得不可承受。MiniMax 的答案是:不要重新发明注意力,而是把它稀疏化——并且用 GPU 内核把它提速到理论极限的 90% 以上。
---
一、问题定义:为什么长上下文是当前的瓶颈
大语言模型正在从短单轮交互转向长程 Agent 工作流:跨越数百个交错推理和动作步骤的代码编写与部署、开放网络导航、多样化工具编排、结构化文档生成。但超长上下文对训练和推理都施加了严重的计算和内存瓶颈,其中二次方成本的 softmax 注意力是主要罪魁祸首。
在 1M token 上下文长度下:
- 训练成本:注意力的 FLOPs 与 N² 成正比,1M 上下文意味着比 4K 上下文多 62,500 倍 的注意力计算
- 推理延迟:预填充(prefill)阶段需要处理整个上下文,解码(decoding)阶段 KV 缓存的内存占用线性增长
- 部署可行性:当前最大的 H100/H800 集群在 1M 上下文下的吞吐量让实时交互变得不可能
| 路径 | 代表工作 | 核心思想 | 局限 |
|---|---|---|---|
| 混合架构 | MiniMax-Text-01, Qwen3 | 部分层用线性注意力或滑动窗口替代 softmax | 需要重新设计模型架构,兼容性差 |
| 稀疏 softmax | DeepSeek-V3, NSA, MSA | 在 softmax 注意力内部做稀疏化,保留原有架构 | 需要算法-硬件协同设计才能转化为实际提速 |
---
二、MSA 架构:两阶段块稀疏注意力
2.1 核心设计哲学
MSA 的设计遵循一个简单原则:在 GQA(Grouped Query Attention)基础上做块级稀疏化,而不是从头设计新架构。
这意味着:
- ✅ 完全兼容现有 Transformer 生态(训练框架、推理引擎、量化工具)
- ✅ 支持从零训练,也支持从预训练 GQA checkpoint 近乎无损转换
- ✅ 不需要重新设计模型架构,落地门槛极低
2.2 两阶段架构
输入隐藏状态 X ∈ R^(N×d_model)
↓
┌─────────────────────────────────────┐
│ Index Branch(索引分支) │
│ - 每个 GQA 组一个轻量索引查询头 │
│ - 所有组共享一个索引键头 │
│ - 对可见键 token 打分,块级聚合 │
│ - TopK 选择 k 个关键块 │
│ - 强制保留本地块(包含当前位置) │
└─────────────────────────────────────┘
↓ 选中块索引 I^(r)_i
┌─────────────────────────────────────┐
│ Main Branch(主分支) │
│ - 仅在选中块范围内计算标准 Softmax │
│ - 每个查询头保留独立查询投影 │
│ - 输出注意力结果 │
└─────────────────────────────────────┘
↓
输出 O ∈ R^(N×d_model)
复杂度对比:
| 模块 | GQA | MSA |
|---|---|---|
| Index Branch | 无 | H_kv × d_idx × N²(轻量) |
| Main Branch | 2H_q × d_h × N² | 4H_q × d_h × N × k × B_k |
| 总 FLOPs | ~2H_q d_h N² | ~H_kv d_idx N² + 4H_q d_h N k B_k |
在 1M 上下文下:
- 每 token 注意力计算减少 28.4 倍
- 这意味着 1M 上下文的注意力成本 ≈ 35K 上下文的全注意力成本
三、训练稳定性的工程智慧
3.1 核心挑战:TopK 不可导
TopK 选择是离散的、不可微分的操作。这意味着语言建模损失无法直接训练索引 Q/K 投影 W^idx_q 和 W^idx_k。
MSA 的解决方案:KL 对齐损失 + 三重稳定机制
3.2 KL 对齐损失
P_idx^(r)_i,j = softmax(S_idx^(r)_i,j) over selected tokens
P^(r)_i,j = (1/G) Σ_l∈H_r softmax(S^(l)_i,j) over selected tokens
L_KL = (1/NH_kv) Σ_i Σ_r D_KL(stopgrad(P^(r)_i,·) || P_idx^(r)_i,·)
设计要点:
- 教师分布 P 在概率层面平均:不是取最大分数,而是平均每个查询头的 softmax 分布,确保索引学习的是"共识"注意力模式
- stopgrad 隔离:教师分布和 Index Branch 输入都 detach 梯度,确保 KL 损失只更新索引投影,不影响主干网络
- 仅在选中 token 上计算:避免对未选中位置做无效对齐
3.3 三重稳定机制
| 机制 | 目的 | 实现 |
|---|---|---|
| Gradient Detach | 隔离辅助目标与主干 | Index Branch 输入 X 经过 stopgrad,KL 损失只更新 W^idx_q 和 W^idx_k |
| Indexer Warmup | 避免早期随机选择 | 前 40B token 全注意力运行,仅训练索引投影;之后切换到稀疏注意力 |
| Forced Local Block | 防止退化选择 | 每个查询位置强制保留包含自身的本地块,确保至少关注最近上下文 |
- 3T token 训练过程中,MSA-PT(从零训练)的 LM loss 曲线与全注意力基线"几乎不可区分"
- 梯度范数始终在同一范围内,没有异常波动
- CPT(从全注意力 checkpoint 转换)的 KL 损失在 warmup 阶段迅速下降,切换后保持低位
- Block Recall 保持在很高水平,Score Recall 更高——说明索引器不仅找回了重要块,还找回了大部分注意力质量
四、GPU 内核协同设计:从理论到真金白银
这是 MSA 最硬核的部分。算法层面的稀疏化如果没有硬件层面的优化,就像设计了一辆法拉利但用自行车轮胎——理论快,实际慢。
4.1 Exp-free TopK 内核
洞察:softmax 是保序的(s_i ≤ s_j ⟺ softmax(s)_i ≤ softmax(s)_j),所以排名时不需要先算 softmax。
实现:
- 绕过 max/exp/sum 步骤,直接传递原始分数到 TopK
- 专用小 k 场景优化:Bk=128, k=16
- 每 warp 32 lane 各流式处理 1/32 步长的输入行
- 在共享内存维护 k 元素最小堆,根节点缓存于寄存器
- 延迟写入 + k 轮 shuffle merge 合并 32 个局部 TopK 结果
- 共享内存布局避免 bank conflict
- 相比 torch.topk 和 TileLang radix-select,在所有测试设置下最快
- 在部署设置(B=128, k=16)下优势最大
4.2 KV-outer 稀疏注意力
核心问题:稀疏注意力下,迭代顺序的选择直接影响 Tensor Core 利用率。
Q-outer vs KV-outer 的算术强度对比:
| 维度 | Q-outer | KV-outer |
|---|---|---|
| FLOPs | 4H_q N d_h k B_k | 4H_q N d_h k B_k |
| IO | 4H_q N d_h + 4H_kv N k B_k d_h | 4H_kv N d_h + 4H_q N k d_h + 2H_q N(k+1)d_h |
| FLOPs/IO | ≈ 2k B_k / (2k B_k/G + 1) | ≈ 2/3 B_k |
结论:KV-outer 的算术强度显著更高,因此选择 KV-outer 迭代 + Q gather。
实现细节:
1. 持久化网格:内核作为持久化网格运行在 (kv_block, kv_head) tile 上 2. 反向稀疏索引:从 TopK 选择反向索引出相关查询位置 3. TMA 加载:查询通过 TMA 拷贝加载到共享内存,32 lane 并行分发 4. 预调度 tile 分块:直接一对一 CTA-to-tile 映射会被 sink rows(几乎所有查询都选择的早期 KV 块)主导。调度器将每个 KV tile 沿查询维度切分为最多 ~2k B_k 查询的 chunk,把热点 tile 扇出到多个 CTA 5. 无原子更新:每个 (query, chunk) 对预分配 Obuf 中的 slot,避免原子操作 6. 两阶段前向:注意力内核写局部归一化 partial 到预分配 slot → combine 内核读取有效 slot,计算全局 softmax 归一化权重,输出最终结果 7. Query 拼接:每个 KV tile 通常只关联少量查询位置。内核将 ⌈128/G⌉ 个查询位置及其 G 个关联查询头打包成 128×128 score MMA,填满 Tensor Core
4.3 训练阶段优化
| 优化 | 原理 | 效果 |
|---|---|---|
| LSE 融合 | KL 损失只需要 LSE 值用于反向传播,前向传播时直接从主路径发射 LSE 到全局内存 | 跳过 KL 损失的前向传递,消除冗余前向计算 |
| 动态负载均衡 | 变长序列和数据依赖稀疏性导致每 tile 工作量差异巨大 | 持久化网格中 CTA 通过全局原子计数器认领工作,每 tile 按查询数量动态切分 sub-tile |
五、实验验证:109B MoE 模型的全面评估
5.1 实验设置
- 模型:41 层 MoE,109B 总参数,6B 激活参数/token
- 注意力配置:64 查询头,4 KV 头,头维度 128,RoPE 维度 64
- MSA 配置:块大小 B_k=128,每查询选 k=16 块
- 训练预算:3T token
- 两种训练方式:
- MSA-PT:从零训练(40B token indexer warmup + 稀疏训练)
- MSA-CPT:从 2.6T 全注意力 checkpoint 转换(40B warmup + 400B 稀疏持续训练)
5.2 核心结果
训练稳定性:
- MSA-PT 的 LM loss 曲线与全注意力基线"几乎不可区分"(3T token 全程)
- 梯度范数在同一范围内,无异常波动
- CPT 转换后 KL 损失保持低位,Block Recall 和 Score Recall 都很高
- MSA-PT 和 MSA-CPT 在几乎所有基准上与全注意力持平或略有提升
- 覆盖文本推理(MMLU, MMLU-Pro, BBH, GPQA)、数学(GSM8K, OlymMATH)、代码(HumanEval, BigCodeBench)、多模态图像(MMMU, OCRBench, CharXiv)、多模态视频(VideoMME, MLVU)、长上下文(RULER, HELMET)
- 预填充(Prefill):14.2 倍 wall-clock 提速
- 解码(Decoding):7.6 倍 wall-clock 提速
- 每 token 注意力计算减少 28.4 倍
5.3 关键洞察:为什么精度没掉?
MSA 的设计有几个关键决策保护了模型能力:
1. 块级选择 + 本地块强制保留:确保每个查询至少看到 128 个最近 token,不会丢失局部上下文 2. GQA 组级共享索引:在组内 16 个查询头之间共享选择结果,既减少计算又保持注意力多样性 3. KL 对齐到 Main Branch 共识:索引器学习的是"哪些块真正重要",而不是启发式规则 4. 精确 softmax 在选中块内:Main Branch 仍然是标准 softmax 注意力,没有近似误差
---
六、与同类工作的对比
| 维度 | MSA | DeepSeek NSA | Quest | Infini-attention |
|---|---|---|---|---|
| 稀疏粒度 | 块级 (128 token) | 块级 | 块级 | 段级 + 局部注意力 |
| 索引方式 | 学习式 (轻量 Q/K) | 学习式 | 学习式 | 固定规则 |
| TopK 优化 | 专用 exp-free 内核 | 通用内核 | 通用内核 | 无(固定规则) |
| 注意力迭代顺序 | KV-outer + query concat | Q-outer | Q-outer | — |
| 训练稳定性 | KL + 三重机制 | KL + 辅助损失 | 辅助损失 | 不适用(固定规则) |
| 支持 CPT | ✅ 400B token 验证 | ✅ | ❌ | ❌ |
| 多模态验证 | ✅ 109B 原生多模态 | ❌ 公开 | ❌ 公开 | ❌ 公开 |
| 开源内核 | ✅ GitHub | ❌ | ❌ | ❌ |
---
七、技术启示与行业影响
7.1 对长上下文 LLM 的启示
MSA 证明了一个重要原则:稀疏注意力不是"廉价替代品",而是"精准手术刀"。
- 不是粗略地"少看一些 token",而是学习地看对 token
- 不是牺牲精度换速度,而是通过算法-硬件协同设计同时实现两者
- 不是重新发明架构,而是在现有基础设施上做增量优化
7.2 对 Agent 时代的意义
当前 LLM 应用正在从"聊天"转向"Agent":
- 代码 Agent 需要处理整个代码仓库(10万+ token)
- 多轮工具调用需要维护长期记忆(百万级 token)
- 多模态 Agent 需要同时处理视频、图像、文本(上下文爆炸)
- 1M 上下文的首次响应时间从"分钟级"降到"秒级"
- 7.6 倍解码提速让实时流式输出成为可能
- 28.4 倍计算减少让长上下文训练的硬件成本进入可接受范围
7.3 对硬件-算法协同设计的启示
MSA 的 GPU 内核设计揭示了一个被低估的事实:稀疏化的理论 FLOP 节省 ≠ 实际 wall-clock 提速。
关键转化路径:
算法稀疏化
↓
内存访问模式优化(KV-outer, query concat)
↓
Tensor Core 利用率提升(128×128 MMA)
↓
负载均衡(预调度分块, 持久化网格)
↓
消除冗余计算(LSE 融合, exp-free TopK)
↓
实际 wall-clock 提速
MSA 在这个路径上的每一步都做了针对性优化,这才是 14.2 倍提速的来源——不是稀疏度本身,而是稀疏度如何被硬件执行。
---
八、局限与未来方向
8.1 当前局限
1. 块级粒度的边界效应:128 token 的块大小意味着某些跨边界的细粒度注意力模式可能被遗漏 2. k=16 的固定预算:在所有层和场景下使用相同的 k 值可能不是最优的,某些层可能需要更多上下文 3. 多模态数据的索引学习:图像/视频 token 的索引模式是否与文本 token 相同,仍需更多研究
8.2 未来方向
1. 自适应 k 值:根据层深度、任务类型、上下文长度动态调整 k 2. 层次化稀疏:结合块级稀疏和 token 级稀疏,粗筛+精筛 3. 跨模态索引:针对图像、视频、音频的不同特性设计专门的索引策略 4. 与线性注意力的混合:在部分层使用线性注意力进一步降低复杂度,MSA 在关键层保持精确注意力
---
九、结论:极简主义的力量
MSA 的最深远贡献不是某一个技术指标,而是它展示了一种工程哲学:
> 在长上下文 LLM 的竞争中,不是最复杂的方案 wins,而是最简单且能落地的方案 wins。
MSA 的设计决策可以总结为三个"不":
- 不重新发明注意力:在 GQA 基础上做增量优化
- 不牺牲精度换速度:通过算法-硬件协同设计同时实现两者
- 不增加落地门槛:支持从零训练和 CPT 转换,完全兼容现有生态
这不是"注意力机制的替代",而是"注意力机制的进化"。而进化的方向,是让模型在长上下文中既能看得远,又能看得准。
---
参考论文
Lai, X., Xu, W., Yang, Y., Chen, Q., Xu, Y., Zeng, L., Li, X., Sun, H., Zhu, H., Zhang, V., & Zhao, P. (2026). MiniMax Sparse Attention. *arXiv preprint arXiv:2606.13392*.
开源内核:https://github.com/MiniMax-AI/MSA
基于 MSA 的生产级多模态模型:已公开发布
#MiniMax #SparseAttention #长上下文LLM #GPU优化 #注意力机制 #MoE #大模型训练 #推理加速 #算法硬件协同设计 #AI基础设施
🌟 智谱 GLM-5 已上线
我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。
🎁 领取 2000万 Tokens