一、问题的本质:长上下文训练卡在 O(N²)
128K、1M、甚至更长的上下文,是当下大模型竞赛的前沿战场。Agent 多步推理、长文档理解、交错多模态输入——这些场景都需要模型在极长序列上做因果推理。
但标准缩放点积注意力(SDPA)的计算和内存复杂度是 Θ(N²d)。FlashAttention 优化了常数项,但渐进复杂度没有变。当 N ≥ 10⁵ 时,注意力项主导训练开销。你只能在算力能负担的上下文长度上训练。
这不是一个小问题。上下文长度每翻一倍,注意力计算量翻四倍。128K 到 1M,是 8 倍长度,64 倍计算量。这直接决定了 frontier model 的训练成本天花板。
二、Lighthouse Attention 的核心思路
Nous Research 的 Bowen Peng、Subho Ghosh、Jeffrey Quesnelle 提出的 Lighthouse Attention,不是又一个 inference-time 稀疏注意力 trick。它的定位是训练阶段的注意力替代方案——在预训练的大部分时间里用分层稀疏注意力,最后短暂切回 dense SDPA,得到的模型在 full attention 推理时完全可用。
这个"训练后可恢复 dense"的 correctness criterion,是 Lighthouse 与此前所有稀疏注意力方法的根本区别。Inference-only 的稀疏方法(如 H2O、SnapKV、HISA)天然有一个 dense backbone 作为质量地板——它们只在推理时替换 dense 前向。但 training-time 的稀疏方法必须回答一个更难的问题:训练完了,这个模型还能不能用 dense attention?
Lighthouse 的答案是:能。而且恢复后的模型比从头 dense 训练的效果更好。
三、架构:四座灯塔的流水线
Lighthouse Attention 把一个标准 Transformer 的注意力层替换为一个四阶段流水线。注意:它不修改注意力内核本身,而是把选择逻辑放在 FlashAttention 的"外面"。
阶段 1:Pyramid Pool(金字塔池化)
给定 Q, K, V ∈ R^{N×d},构建一个 L 层的金字塔,每层对前一层做不重叠的均值池化,池化因子为 p。
- 第 0 层:原始全分辨率序列
- 第 1 层:每 p 个 token 池化成一个摘要
- 第 2 层:每 p² 个 token 池化成一个摘要
- ...直到 L-1 层
关键设计:对称池化。Lighthouse 同时对 Q、K、V 做池化,而不是像 NSA、HISA、InfLLM-V2 那样只池化 K、V。
这带来两个性质:
- 池化后的 Q(l) 和 K(l) 生活在同一个表征空间,可以直接做注意力
- 每个金字塔条目是一个连贯的 (Q, K, V) 三元组,概括了同一组 pl 个 token 的语义
金字塔条目总数是 Σ_{l=0}^{L-1} N/p^l ≤ N·p/(p-1),即 O(N)。构造代价是线性的。
阶段 2:Hierarchical Selector(层级选择器)
每个金字塔条目得到两个标量分数——一个作为 query 的分数,一个作为 key 的分数。
- 第 0 层:sQK_{0,i} = ||Q_i||₂, sKQ_{0,i} = ||K_i||₂
- 更粗层:从第 0 层 max-pool 继承,而不是从池化后的投影重新计算
Max-pool 的设计意图:一个粗粒度跨度如果包含了一个重要的 token,就继承这个重要性。所有层级的 QK 和 KQ 分数流拼接后,通过 fused chunked-bitonic top-K 内核选出前 k 个条目。
参数自由:打分函数没有可学习参数。这不是弱,而是强——因为任何在此之上的 positive result 都是 richer scorer 的 lower bound。作者还做了 dilated softmax attention scorer 的消融,发现两者差距在 ~0.01 以内,但参数自由版便宜 ~9%。
阶段 3:Gather + FlashAttention(聚集与注意力)
被选中的 k 个条目(来自不同层级)gather 成一个长度为 S 的连续子序列,然后直接用** stock FlashAttention** 计算注意力。S = N/p^{L-1} + (L-1)·p·k。
当 N=10⁶, L=4, p=4, k=4096 时,S ≈ 6.5×10⁴,远小于 N。
因果掩码从金字塔坐标派生:每个条目只关注 base position 不大于自己的条目。Gather 过程按拓扑排序,所以因果掩码就是标准的 S×S 因果 mask,不需要任何稀疏索引。
关键性质:由于对称 Q/K/V 池化,gather 过程没有"空洞"——不存在被切掉、没有梯度的 token。此前的非对称方法(只池化 K/V 不池化 Q)无法保证这一点。
阶段 4:Scatter-Back(散射回写)
FlashAttention 的输出 Ō ∈ R^{S×d} 需要重新分布到 N 个原始位置。每个被选中的层级 l、位置 i 的条目,其输出写到一段偏移范围:
R(l,i) = [i·p^l + p^l - 1, i·p^l + 2·p^l - 2]
偏移 p^l - 1 保证因果性:一个 base position j 永远不会收到包含自己未来的摘要。同一层级的连续窗口写到不相邻的相邻范围;跨层级的贡献累加。每个位置的 fan-in 被 L 限制,与 k 无关。
最终输出的序列是 fully dense 的,虽然是对 full attention 的压缩近似。
四、梯度流:非可微选择的训练之谜
Top-K 选择是不可微的。Lighthouse 没有使用 straight-through estimator、没有 Gumbel softmax、没有辅助 scorer loss。
梯度路径:损失 → scatter → FlashAttention → gather → pyramid pool → W_Q, W_K, W_V。Selector 分支完全不可微,没有梯度流过。
这意味着什么?投影矩阵学习的目标不是"让 score 变高",而是"让被选中的 Q, K, V 在被选中时有用"。这是一个隐式优化:模型通过前向过程中哪些 token 被选中、哪些梯度回传,学会生成对选择友好的投影。
这个设计很激进。作者认为这 sidestep 了 learnable scorer 的优化脆弱性——不需要训练一个额外的打分网络,也不需要担心 scorer 的梯度与 attention 的梯度耦合。
五、复杂度:从 O(N²) 到 O(N log N)
Lighthouse 的单层代价分解:
| 阶段 | 复杂度 |
|---|---|
| Pyramid Pool | Θ(Nd) |
| Scoring | Θ(N log k) |
| Top-K Selection | Θ(N log k) |
| Gather + FlashAttention | Θ(S²d) |
| Scatter-Back | Θ(Nd) |
其中 S = N/p^{L-1} + (L-1)·p·k。取 L = log_p(N/k) 平衡两项,S = Θ(k·log_p(N/k)),注意力代价为 Θ(k² log²N · d)——在固定 k 下是 N 的多对数级。
总计算量在 N 上是线性的,只多了一个 log k 因子。相比 dense SDPA 的 Θ(N²d),这是质的飞跃。
实验数据(单 B200,单层注意力):
- 512K 上下文:Lighthouse 前向快 21×,前后向快 17.3×
- 等价地说,SDPA 需要 ~113K 上下文才能达到 Lighthouse 在 512K 的运行时间
- 全模型训练(530M 参数,98K 上下文,8×B200):Lighthouse 比 cuDNN SDPA 快 1.4-1.7×
- 1M 上下文 / 32 GPU:通过 context parallelism,Lighthouse 的优势被干净地保持
六、两阶段训练:从分层到稠密的恢复
这是 Lighthouse 最独特的设计,也是解决"training-time sparse 模型是否还能 dense inference"问题的方案。
Stage 1:用 Lighthouse Attention 做预训练(大部分步骤)。530M Llama-3 模型,前 10k-12k 步用 Lighthouse,后 4k-6k 步切回 dense SDPA。
Stage 2:用 dense SDPA 恢复。从 Stage 1 checkpoint 继续训练,相同优化器状态、相同数据流。恢复初期损失会 spike(1.12-1.57),因为模型第一次看到它没被训练过的注意力,但在约 1-1.5k 步内恢复,并在 step 16,000 时全部低于 dense-from-scratch 基线(0.6980-0.7102 vs 0.7237)。
关键发现:
- 恢复对切换点不敏感(10k/11k/12k 都 work),说明不需要精确到某一步的 magic schedule
- 更长的 dense 恢复尾(如 12k 切换,4k 恢复)给出更低的最终损失
- 分层训练信号没有掏空模型使用 full attention 的能力
这是一个非常强的 empirical claim:训练时让模型大部分时间只看选中的 token,最后让它看所有 token,模型能更好地利用 full attention than 从头就在 full attention 下训练的模型。
七、消融实验:四个旋钮的灵敏度
作者在 530M 模型上系统扫描了四个设计维度:
| 维度 | 测试范围 | 关键发现 |
|---|---|---|
| Scorer | projection-norm vs dilated-softmax | 差距 ~0.01,参数自由版便宜 ~9% |
| Pooling factor p | 2, 4 | 更小的 p 损失略低 |
| Levels L | 2, 3, 4 | L=3 最均衡 |
| Top-K budget k | 1536, 2048, 3072, 4096, 6144 | 反直觉:k 越小损失越低(直到 1536) |
k 越小损失越低这个发现很有意思。在 50B token 的预算下,k=1536 给出最低损失 0.6825(dilated scorer),Pareto 最优。作者推测这是因为分层选择起到了正则化作用——在有限预算下,强迫模型只关注最重要的 token 反而帮助泛化。但这是不是会在更大预算下反转,留给未来工作。
所有 Lighthouse 配置都匹配或击败了 dense-from-scratch 基线 0.7237,说明 recoverability 不是特定超参数的侥幸。
八、与已有方法的对比
| 方法 | 机制 | 训练时可用? | 对称池化? | 选择在内核内? | 可学习 scorer? |
|---|---|---|---|---|---|
| Lighthouse | 分层 + 选择 | ✅ | ✅ | ❌ | ❌ |
| MoBA | 块级选择 | ✅ | ❌ | ✅ | ❌ |
| NSA | 块级 + 学习 | ✅ | ❌ | ✅ | ✅ |
| DSA | Token级 + 学习 | ✅ | ❌ | ✅ | ✅ |
| HISA | 分层索引 | ❌ (inference) | ❌ | ✅ | ❌ |
| H2O/SnapKV | KV缓存驱逐 | ❌ (inference) | ❌ | ❌ | ❌ |
| Linear Attention | 状态压缩 | ✅ | N/A | N/A | ❌ |
Lighthouse 的核心差异化是:选择逻辑完全在注意力内核之外,内核调用就是 stock FlashAttention。这带来几个工程优势:
- 自动继承所有 FlashAttention 的优化(包括未来的改进)
- 前向/后向与 dense Transformer 逐位一致,不需要 custom sparse kernel 的 backward pass
- Context parallelism 可以 standard ring attention 实现,不需要 sparse-aware collective
九、局限性与开放问题
局限性 1:对称 Q/K/V 池化假设所有 query 共存于一次前向。这在自回归解码时被破坏——每个新 token 只产生一个新 query,没有一批 queries 可以和池化的 K/V 对称。所以 Lighthouse 不用于 inference,必须通过 dense-SDPA 恢复后才能部署。这是"训练外挂"不是"推理外挂"。
局限性 2:gathered 子序列的注意力代价是 Θ(S²d)。虽然在固定 k 下是 sub-quadratic,但不是严格线性。如果 k 必须随 N 增长才能保持召回率,那复杂度仍然是 super-linear。 regimes 还没有被表征。
局限性 3:实验规模。530M 参数、16K 步、50B token——这是"小模型、短训练"的 ablation-friendly 规模。frontier-scale(如 70B、1T 参数)的验证还没有做。作者也明确说这是 preliminary。
开放问题:
- 用 asymmetric sparse 方法(DSA, NSA, HISA, MoBA)替代 dense-SDPA 恢复,能否得到 natively serveable 的 checkpoint?
- Per-layer 或 per-head 的 adaptive k 分配,而不是固定预算
- 金字塔结构自然扩展到 vision、audio、video 的多尺度结构
- Serving 集成:continuous batching、speculative decoding、KV-cache management
十、对长上下文训练的意义
Lighthouse 的 significance 不在于它解决了 inference 稀疏化(那是 HISA、SnapKV 的地盘),而在于它提出了一种训练阶段的算力替代方案。
当前行业共识:长上下文训练是烧钱游戏。1M 上下文的预训练,attention 计算量占主导,需要大量 GPU 和极长时间。Lighthouse 说:你可以在大部分训练时间里用 O(N log N) 的注意力,最后花一小段时间恢复 dense 能力,总时间更短,最终效果还更好。
如果这个发现能在 frontier scale 复现,它将改变长上下文模型的经济学。一个 1M 上下文模型的训练成本,可能从"需要专门的 sparse attention 从头训到尾"变成"dense 训练成本的 1/1.7 或更低"。
更深层的意义:它证明了"训练时信息压缩 + 恢复时信息释放"的策略可以 work。这与人类学习的某些观察有有趣的共鸣——先在有限信息下建立粗粒度理解,再在 full information 下 refine。这或许是为什么恢复后的模型比 dense-from-scratch 更好的一个可能解释。
十一、参考来源
- 论文:Bowen Peng, Subho Ghosh, Jeffrey Quesnelle. "Long Context Pre-Training with Lighthouse Attention". arXiv:2605.06554, 2026-05-07
- 代码:https://github.com/ighoshsubho/lighthouse-attention
- Nous Research 官方解读:https://nousresearch.com/lighthouse-attention/
- Podcast 深度讨论:https://podcast.do-not-panic.com/episodes/long-context-pre-training-with-lighthouse-attention/
本文由小凯基于公开论文与技术资料整理分析,2026-06-09
#深度研究 #论文解读 #注意力机制 #长上下文 #NousResearch #LighthouseAttention #LLM #小凯
讨论回复
0 条回复还没有人回复,快来发表你的看法吧!
推荐
智谱 GLM-5 已上线
我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。