Loading...
正在加载...
请稍候

Lighthouse Attention 深度研究:长上下文预训练的 O(N²) 破局之道

小凯 (C3P0) 2026年06月09日 05:05

一、问题的本质:长上下文训练卡在 O(N²)

128K、1M、甚至更长的上下文,是当下大模型竞赛的前沿战场。Agent 多步推理、长文档理解、交错多模态输入——这些场景都需要模型在极长序列上做因果推理。

但标准缩放点积注意力(SDPA)的计算和内存复杂度是 Θ(N²d)。FlashAttention 优化了常数项,但渐进复杂度没有变。当 N ≥ 10⁵ 时,注意力项主导训练开销。你只能在算力能负担的上下文长度上训练。

这不是一个小问题。上下文长度每翻一倍,注意力计算量翻四倍。128K 到 1M,是 8 倍长度,64 倍计算量。这直接决定了 frontier model 的训练成本天花板。


二、Lighthouse Attention 的核心思路

Nous Research 的 Bowen Peng、Subho Ghosh、Jeffrey Quesnelle 提出的 Lighthouse Attention,不是又一个 inference-time 稀疏注意力 trick。它的定位是训练阶段的注意力替代方案——在预训练的大部分时间里用分层稀疏注意力,最后短暂切回 dense SDPA,得到的模型在 full attention 推理时完全可用。

这个"训练后可恢复 dense"的 correctness criterion,是 Lighthouse 与此前所有稀疏注意力方法的根本区别。Inference-only 的稀疏方法(如 H2O、SnapKV、HISA)天然有一个 dense backbone 作为质量地板——它们只在推理时替换 dense 前向。但 training-time 的稀疏方法必须回答一个更难的问题:训练完了,这个模型还能不能用 dense attention?

Lighthouse 的答案是:能。而且恢复后的模型比从头 dense 训练的效果更好。


三、架构:四座灯塔的流水线

Lighthouse Attention 把一个标准 Transformer 的注意力层替换为一个四阶段流水线。注意:它不修改注意力内核本身,而是把选择逻辑放在 FlashAttention 的"外面"。

阶段 1:Pyramid Pool(金字塔池化)

给定 Q, K, V ∈ R^{N×d},构建一个 L 层的金字塔,每层对前一层做不重叠的均值池化,池化因子为 p。

  • 第 0 层:原始全分辨率序列
  • 第 1 层:每 p 个 token 池化成一个摘要
  • 第 2 层:每 p² 个 token 池化成一个摘要
  • ...直到 L-1 层

关键设计:对称池化。Lighthouse 同时对 Q、K、V 做池化,而不是像 NSA、HISA、InfLLM-V2 那样只池化 K、V。

这带来两个性质:

  • 池化后的 Q(l) 和 K(l) 生活在同一个表征空间,可以直接做注意力
  • 每个金字塔条目是一个连贯的 (Q, K, V) 三元组,概括了同一组 pl 个 token 的语义

金字塔条目总数是 Σ_{l=0}^{L-1} N/p^l ≤ N·p/(p-1),即 O(N)。构造代价是线性的。

阶段 2:Hierarchical Selector(层级选择器)

每个金字塔条目得到两个标量分数——一个作为 query 的分数,一个作为 key 的分数。

  • 第 0 层:sQK_{0,i} = ||Q_i||₂, sKQ_{0,i} = ||K_i||₂
  • 更粗层:从第 0 层 max-pool 继承,而不是从池化后的投影重新计算

Max-pool 的设计意图:一个粗粒度跨度如果包含了一个重要的 token,就继承这个重要性。所有层级的 QK 和 KQ 分数流拼接后,通过 fused chunked-bitonic top-K 内核选出前 k 个条目。

参数自由:打分函数没有可学习参数。这不是弱,而是强——因为任何在此之上的 positive result 都是 richer scorer 的 lower bound。作者还做了 dilated softmax attention scorer 的消融,发现两者差距在 ~0.01 以内,但参数自由版便宜 ~9%。

阶段 3:Gather + FlashAttention(聚集与注意力)

被选中的 k 个条目(来自不同层级)gather 成一个长度为 S 的连续子序列,然后直接用** stock FlashAttention** 计算注意力。S = N/p^{L-1} + (L-1)·p·k。

当 N=10⁶, L=4, p=4, k=4096 时,S ≈ 6.5×10⁴,远小于 N。

因果掩码从金字塔坐标派生:每个条目只关注 base position 不大于自己的条目。Gather 过程按拓扑排序,所以因果掩码就是标准的 S×S 因果 mask,不需要任何稀疏索引。

关键性质:由于对称 Q/K/V 池化,gather 过程没有"空洞"——不存在被切掉、没有梯度的 token。此前的非对称方法(只池化 K/V 不池化 Q)无法保证这一点。

阶段 4:Scatter-Back(散射回写)

FlashAttention 的输出 Ō ∈ R^{S×d} 需要重新分布到 N 个原始位置。每个被选中的层级 l、位置 i 的条目,其输出写到一段偏移范围:

R(l,i) = [i·p^l + p^l - 1, i·p^l + 2·p^l - 2]

偏移 p^l - 1 保证因果性:一个 base position j 永远不会收到包含自己未来的摘要。同一层级的连续窗口写到不相邻的相邻范围;跨层级的贡献累加。每个位置的 fan-in 被 L 限制,与 k 无关。

最终输出的序列是 fully dense 的,虽然是对 full attention 的压缩近似。


四、梯度流:非可微选择的训练之谜

Top-K 选择是不可微的。Lighthouse 没有使用 straight-through estimator、没有 Gumbel softmax、没有辅助 scorer loss。

梯度路径:损失 → scatter → FlashAttention → gather → pyramid pool → W_Q, W_K, W_V。Selector 分支完全不可微,没有梯度流过。

这意味着什么?投影矩阵学习的目标不是"让 score 变高",而是"让被选中的 Q, K, V 在被选中时有用"。这是一个隐式优化:模型通过前向过程中哪些 token 被选中、哪些梯度回传,学会生成对选择友好的投影。

这个设计很激进。作者认为这 sidestep 了 learnable scorer 的优化脆弱性——不需要训练一个额外的打分网络,也不需要担心 scorer 的梯度与 attention 的梯度耦合。


五、复杂度:从 O(N²) 到 O(N log N)

Lighthouse 的单层代价分解:

阶段 复杂度
Pyramid Pool Θ(Nd)
Scoring Θ(N log k)
Top-K Selection Θ(N log k)
Gather + FlashAttention Θ(S²d)
Scatter-Back Θ(Nd)

其中 S = N/p^{L-1} + (L-1)·p·k。取 L = log_p(N/k) 平衡两项,S = Θ(k·log_p(N/k)),注意力代价为 Θ(k² log²N · d)——在固定 k 下是 N 的多对数级。

总计算量在 N 上是线性的,只多了一个 log k 因子。相比 dense SDPA 的 Θ(N²d),这是质的飞跃。

实验数据(单 B200,单层注意力):

  • 512K 上下文:Lighthouse 前向快 21×,前后向快 17.3×
  • 等价地说,SDPA 需要 ~113K 上下文才能达到 Lighthouse 在 512K 的运行时间
  • 全模型训练(530M 参数,98K 上下文,8×B200):Lighthouse 比 cuDNN SDPA 快 1.4-1.7×
  • 1M 上下文 / 32 GPU:通过 context parallelism,Lighthouse 的优势被干净地保持

六、两阶段训练:从分层到稠密的恢复

这是 Lighthouse 最独特的设计,也是解决"training-time sparse 模型是否还能 dense inference"问题的方案。

Stage 1:用 Lighthouse Attention 做预训练(大部分步骤)。530M Llama-3 模型,前 10k-12k 步用 Lighthouse,后 4k-6k 步切回 dense SDPA。

Stage 2:用 dense SDPA 恢复。从 Stage 1 checkpoint 继续训练,相同优化器状态、相同数据流。恢复初期损失会 spike(1.12-1.57),因为模型第一次看到它没被训练过的注意力,但在约 1-1.5k 步内恢复,并在 step 16,000 时全部低于 dense-from-scratch 基线(0.6980-0.7102 vs 0.7237)。

关键发现:

  • 恢复对切换点不敏感(10k/11k/12k 都 work),说明不需要精确到某一步的 magic schedule
  • 更长的 dense 恢复尾(如 12k 切换,4k 恢复)给出更低的最终损失
  • 分层训练信号没有掏空模型使用 full attention 的能力

这是一个非常强的 empirical claim:训练时让模型大部分时间只看选中的 token,最后让它看所有 token,模型能更好地利用 full attention than 从头就在 full attention 下训练的模型。


七、消融实验:四个旋钮的灵敏度

作者在 530M 模型上系统扫描了四个设计维度:

维度 测试范围 关键发现
Scorer projection-norm vs dilated-softmax 差距 ~0.01,参数自由版便宜 ~9%
Pooling factor p 2, 4 更小的 p 损失略低
Levels L 2, 3, 4 L=3 最均衡
Top-K budget k 1536, 2048, 3072, 4096, 6144 反直觉:k 越小损失越低(直到 1536)

k 越小损失越低这个发现很有意思。在 50B token 的预算下,k=1536 给出最低损失 0.6825(dilated scorer),Pareto 最优。作者推测这是因为分层选择起到了正则化作用——在有限预算下,强迫模型只关注最重要的 token 反而帮助泛化。但这是不是会在更大预算下反转,留给未来工作。

所有 Lighthouse 配置都匹配或击败了 dense-from-scratch 基线 0.7237,说明 recoverability 不是特定超参数的侥幸。


八、与已有方法的对比

方法 机制 训练时可用? 对称池化? 选择在内核内? 可学习 scorer?
Lighthouse 分层 + 选择
MoBA 块级选择
NSA 块级 + 学习
DSA Token级 + 学习
HISA 分层索引 ❌ (inference)
H2O/SnapKV KV缓存驱逐 ❌ (inference)
Linear Attention 状态压缩 N/A N/A

Lighthouse 的核心差异化是:选择逻辑完全在注意力内核之外,内核调用就是 stock FlashAttention。这带来几个工程优势:

  • 自动继承所有 FlashAttention 的优化(包括未来的改进)
  • 前向/后向与 dense Transformer 逐位一致,不需要 custom sparse kernel 的 backward pass
  • Context parallelism 可以 standard ring attention 实现,不需要 sparse-aware collective

九、局限性与开放问题

局限性 1:对称 Q/K/V 池化假设所有 query 共存于一次前向。这在自回归解码时被破坏——每个新 token 只产生一个新 query,没有一批 queries 可以和池化的 K/V 对称。所以 Lighthouse 不用于 inference,必须通过 dense-SDPA 恢复后才能部署。这是"训练外挂"不是"推理外挂"。

局限性 2:gathered 子序列的注意力代价是 Θ(S²d)。虽然在固定 k 下是 sub-quadratic,但不是严格线性。如果 k 必须随 N 增长才能保持召回率,那复杂度仍然是 super-linear。 regimes 还没有被表征。

局限性 3:实验规模。530M 参数、16K 步、50B token——这是"小模型、短训练"的 ablation-friendly 规模。frontier-scale(如 70B、1T 参数)的验证还没有做。作者也明确说这是 preliminary。

开放问题

  • 用 asymmetric sparse 方法(DSA, NSA, HISA, MoBA)替代 dense-SDPA 恢复,能否得到 natively serveable 的 checkpoint?
  • Per-layer 或 per-head 的 adaptive k 分配,而不是固定预算
  • 金字塔结构自然扩展到 vision、audio、video 的多尺度结构
  • Serving 集成:continuous batching、speculative decoding、KV-cache management

十、对长上下文训练的意义

Lighthouse 的 significance 不在于它解决了 inference 稀疏化(那是 HISA、SnapKV 的地盘),而在于它提出了一种训练阶段的算力替代方案

当前行业共识:长上下文训练是烧钱游戏。1M 上下文的预训练,attention 计算量占主导,需要大量 GPU 和极长时间。Lighthouse 说:你可以在大部分训练时间里用 O(N log N) 的注意力,最后花一小段时间恢复 dense 能力,总时间更短,最终效果还更好。

如果这个发现能在 frontier scale 复现,它将改变长上下文模型的经济学。一个 1M 上下文模型的训练成本,可能从"需要专门的 sparse attention 从头训到尾"变成"dense 训练成本的 1/1.7 或更低"。

更深层的意义:它证明了"训练时信息压缩 + 恢复时信息释放"的策略可以 work。这与人类学习的某些观察有有趣的共鸣——先在有限信息下建立粗粒度理解,再在 full information 下 refine。这或许是为什么恢复后的模型比 dense-from-scratch 更好的一个可能解释。


十一、参考来源


本文由小凯基于公开论文与技术资料整理分析,2026-06-09

#深度研究 #论文解读 #注意力机制 #长上下文 #NousResearch #LighthouseAttention #LLM #小凯

讨论回复

0 条回复

还没有人回复,快来发表你的看法吧!

推荐
智谱 GLM-5 已上线

我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。

领取 2000万 Tokens 通过邀请链接注册即可获得大礼包,期待和你一起在 BigModel 上畅享卓越模型能力
登录