Grokking:从死记硬背到突然理解——一万步里发生了什么?
什么是 Grokking?诡异之处何在?
Grokking 是机器学习中一个引人注目的现象:模型在训练初期死记硬背训练数据,然后在漫长的等待后突然理解并泛化。具体来说,一个 Transformer 在模加法等小数据集上训练时,往往会在很短的时间内将训练集的每个例子都记住,达到训练准确率100%,但对未见过的数据(同一模运算的新组合)却几乎随机猜测,泛化准确率只有30%左右。训练继续进行,几千步、上万步过去,模型的表现似乎停滞不前,毫无起色。然而,大约在某一点(例如一万两千步左右),模型的泛化准确率会骤然攀升,短短几十步内从30%跃升至95%以上,仿佛一下子“顿悟”了模加法的规律。这种从记忆到理解的延迟转变就是 grokking,源自罗伯特·海因莱因小说中的火星词“grok”,意为完全地、水乳交融地理解。
诡异之处并不在于模型最终学会了泛化——我们当然希望如此。而在于为什么它要等那么久。按理说,当模型已经完美记忆了所有训练数据,它应该“顺便”也学会了背后的规律,从而逐步提高泛化能力。但事实是,模型在记忆和理解之间徘徊了上万步训练,毫无进展,然后瞬间完成过渡。这种漫长的等待期让人困惑:在这漫长的“沉默期”里,模型内部究竟发生了什么?
之前的解释:规范最小化、特征涌现与彩票假说
研究者们此前提出了多种解释来阐明 grokking 等待期的成因。这些解释从不同角度揭示了延迟泛化的一些侧面,但往往忽略了注意力机制本身的特殊性:
已有理论视角
- 规范最小化(Norm Minimization):有观点认为,模型在记忆阶段学到了权重范数较大的参数,这些参数对记忆训练数据有效但对泛化无益。随着训练进行,权重衰减(weight decay)等正则化手段会逐步削减这些大权重,使模型朝向更小的范数解移动。当范数小到一定程度,模型被迫从记忆转向理解,从而出现泛化能力的突然提升。这个解释强调了权重范数在 grokking 中的作用,与 Goldilocks 约束(容量约束)的思想一致:模型容量太大时容易通过大权重死记硬背,容量适中时才倾向于学习可泛化的特征。
- 特征涌现(Feature Emergence):另一种解释关注内部特征的演化。在训练初期,模型可能主要依赖低层、简单的特征来拟合数据,这些特征足以记忆训练集但缺乏泛化能力。随着训练推进,更抽象、更具泛化性的特征逐渐涌现并被模型利用,使得模型从记忆转向理解。这种解释将 grokking 视为特征演变的结果,暗示模型需要时间来“发现”正确的特征表示。
- 彩票假说(Lottery Ticket Hypothesis):彩票假说认为,在随机初始化的神经网络中隐藏着稀疏的、可训练的子网络,称为“中奖彩票”。训练过程本质上是在寻找这些子网络。在 grokking 的情境下,可以理解为模型在训练初期尚未找到那组能够泛化的参数子集,因此只能靠记忆;当某次训练偶然发现了(或逐步逼近)那组“中奖彩票”参数时,模型突然具备了泛化能力。这种解释把 grokking 等待期归因于子网络的发现,强调稀疏性和参数子集的作用。
以上解释各有道理,也各自捕捉到了 grokking 转变的一部分机制。然而,它们都未充分考虑 Transformer 中注意力机制的独特约束。具体来说,这些理论要么将 Transformer 视作一般的参数模型(如规范最小化和特征涌现的解释),要么关注参数子集的发现(如彩票假说),但忽略了注意力作为信息瓶颈的角色:如果注意力机制在某一层丢弃了某个关键信息,那么后续任何计算都无法再恢复该信息。这一约束是注意力模型所特有的,对理解 grokking 至关重要。
注意力机制:一个赌徒的隐喻
Hidajat 等人的新论文从注意力机制出发,提出了一个全新的视角。他们指出,注意力在本质上是一个赌徒——它在对输入序列进行下注,决定哪些信息(token)重要、哪些可以忽略。这种“下注”可以用贝叶斯推断的语言来描述:注意力实际上是在隐式地推断任务的依赖结构(dependency graph),为每个输入 token 分配一个权重,相当于赋予其一个后验概率,表示“这个 token 对于当前任务有多重要”。
关键在于,如果注意力在某一步丢弃了一个实际上对任务信息量很大的 token(即没有给它足够的权重),那么后续的任何有限计算都无法弥补这一损失。因为信息一旦被丢弃,就像赌徒输掉的筹码,再也无法通过后续操作赢回。这与卷积网络或全连接网络不同——后者即使在中间层丢失了某些信号,后续层仍有参数可以重新组合出所需的信息。而注意力模型没有这样的第二次机会:丢弃的信息永远丢失。
因此,泛化要求注意力必须为每一个信息性 token 放置足够的质量。所谓“信息性 token”,是指对解决当前任务有用的输入元素;而“放置足够质量”则意味着注意力权重不能太小,以至于在 Softmax 归一化后该 token 被忽略。只有当注意力正确地识别出所有相关 token 并给予足够关注时,模型才能基于完整的任务结构进行推理,从而实现泛化。
结构推断与贝叶斯彩票:两个可分离的条件
基于上述洞察,作者将 Transformer 的泛化条件拆解为两个独立的部分:
1. MLP 容量的 Goldilocks 约束:这是指 Transformer 中前馈层(MLP)的容量必须适中,不能过大也不能过小。如果 MLP 容量不足,模型根本没有足够的参数去拟合训练数据;但如果容量过大,模型会倾向于通过记忆(死记硬背)来解决问题,因为这样更省事。只有当 MLP 的参数量恰到好处时,模型才被迫学习可泛化的表示。这个条件与之前基于权重范数的解释相吻合:容量过大对应于允许模型通过大范数权重记忆数据,而容量适中则迫使模型寻找更高效的解。简而言之,MLP 容量必须不大不小,这是泛化的第一个必要条件。
2. 注意力的贝叶斯结构条件:这是新提出的条件,强调注意力必须正确推断出任务的依赖结构。具体来说,注意力需要为每一个对任务有用的 token 分配足够的权重,不能遗漏任何关键信息。作者将这一要求形式化为贝叶斯结构条件:注意力权重应被视为对任务依赖图的后验分布,泛化要求这个后验在所有信息性节点上都有非零质量。如果某个关键 token 被注意力忽略(后验质量为零),相当于模型误以为该 token 与任务无关,那么模型就无法学到正确的任务结构,只能靠记忆训练集来弥补。因此,注意力的结构推断(structural inference)是泛化的第二个必要条件。
这两个条件是可分离的:一个条件满足并不意味着另一个也满足。MLP 容量的约束与注意力的推断是独立的两个方面。这种拆解揭示了 grokking 延迟的根源:延迟泛化本质上就是延迟的结构推断。模型之所以要等,是因为结构推断被推迟了。
解释消除:为什么理解要等记忆先消退?
那么,为什么注意力的结构推断会被推迟?答案在于“解释消除”(explaining away)效应。在训练初期,MLP 有足够的容量直接记忆训练数据——它可以通过学习一些与任务本质无关的捷径特征来拟合所有样本。一旦 MLP 将交叉熵损失压到接近零,模型就“认为”自己已经做得很好了,此时梯度信号变得非常微弱。对于注意力来说,这意味着它几乎得不到关于任务结构的有用反馈。
更糟糕的是,由于 MLP 已经解释了所有的输出(即通过记忆方式给出了正确答案),注意力就失去了学习任务结构的动力——它没有动力去弄清楚到底哪些 token 是关键的,因为即使它不提供有用信息,MLP 也能靠记忆把答案“猜”对。这种现象在因果推理中被称为解释消除:一个变量(这里是 MLP 的记忆)已经解释了结果,另一个变量(这里是注意力的结构推断)就被“消除”了影响,没有机会发挥作用。
因此,在训练早期,注意力被“屏蔽”了,无法学习到正确的任务依赖结构。只有当权重衰减等正则化机制逐步侵蚀掉 MLP 的记忆痕迹,使模型对训练数据的拟合不再完美时,损失才会上升,梯度信号重新出现。此时,注意力才能收到反馈,开始学习那些被遗漏的依赖关系。换句话说,grokking 之所以要等,是因为在等待记忆的消退。当 MLP 的记忆被削弱到一定程度,注意力终于有机会推断出正确的任务结构,模型也就完成了从记忆到理解的过渡。
这解释了为何 grokking 时间与权重衰减强度呈反比关系:权重衰减越强,记忆被啃掉得越快,注意力越早得到反馈,grokking 越早发生;反之亦然。这种延迟实际上是一种结构性等待时间,源于记忆对结构推断的抑制。
打破等待:结构干预的标度律
既然 grokking 延迟的根本原因是注意力缺乏结构梯度,那么一个自然的想法是:能否不给注意力“放假”,直接给它一些结构上的指导? 作者正是这样做的。他们引入了一种结构干预:在目标函数中加入一个KL 散度项,将注意力的分布约束向一个先验靠近。这个先验可以是任务依赖结构的某种先验知识,或者简单地鼓励注意力更均匀地关注各个 token,避免过早忽略任何信息。
实验结果非常漂亮:加入 KL 干预后,grokking 的等待时间大幅缩短。更令人惊喜的是,作者发现干预强度与 grokking 时间之间存在清晰的标度律——干预强度每增加一倍,等待时间近似减半。换言之,通过给注意力一点“提示”,我们几乎可以线性加速模型从记忆到理解的过渡。这种 KL 干预巧妙地绕过了“解释消除”的困境:即使 MLP 已经记忆了数据,注意力也不再完全依赖微弱的梯度信号,因为 KL 项直接为注意力提供了方向,让它更快地朝正确的结构推断迈进。
未知与局限:从模加法到更复杂的任务
尽管这一理论为 grokking 提供了一个新颖且有力的解释,仍有一些开放的问题和局限需要承认:
- 与彩票假说的关系:论文标题中出现了“Bayesian Lottery Tickets”,明显是在与彩票假说对话。彩票假说强调参数子网络的重要性,而本文强调注意力的结构推断。两者似乎解释了 grokking 的不同侧面:彩票假说关注哪组参数能泛化,本文关注哪些信息被模型利用。然而,这两种解释是互补的还是相互竞争的?或者说,它们是否可以在一个统一的框架下共存?目前尚不清楚。作者在论文中也没有明确回答这一点,留给读者一个思考的空间。
- 任务范围的扩展:该研究的实验主要在算法序列任务上进行,例如模加法、奇偶判断等。这些任务具有清晰、离散的依赖结构,非常适合验证注意力的结构推断条件。然而,对于更复杂的任务(例如自然语言的句法结构推断),依赖图远没有模加法那么“干净”,充满噪声和模糊性。注意力在这种环境下的贝叶斯推断行为可能截然不同,泛化延迟的机制或许也不同。目前尚不确定本文的结论能否直接推广到这类更模糊的任务上。未来的研究需要在更广泛的任务上检验这一理论的有效性。
- 其他模型架构:本文聚焦于 Transformer 的注意力机制,但 grokking 并非 Transformer 独有。在其他架构(如 RNN、MLP)上也观察到了类似的延迟泛化现象。这些模型没有显式的注意力机制,那么它们的 grokking 是否也能用类似的结构推断延迟来解释?还是需要其他机制?这也是一个值得探索的方向。
总结:等待中的结构与彩票的对话
Hidajat 等人的工作将 grokking 这一机器学习中的谜题拆解为两个可分离的条件:一个是关于模型容量的 Goldilocks 约束,另一个是关于注意力结构的贝叶斯推断条件。他们揭示了延迟泛化的本质——延迟的结构推断,并指出这种延迟源于记忆对结构推断的抑制(解释消除效应)。通过巧妙的结构干预,他们成功缩短了 grokking 的等待时间,并发现了干预强度与等待时间之间的标度律。
这篇论文不仅解释了“一万步里发生了什么”,还提供了一种加速模型理解的方法。它告诉我们:grokking 的等待并非不可避免,只要我们理解了其中的机制,就能主动打破记忆的束缚,让模型更早地“顿悟”。这无疑为理解和控制神经网络的学习过程提供了新的视角。然而,也正如作者所坦诚的,彩票假说与结构推断之间的关系、以及这一理论在更复杂任务上的普适性,仍是未解之谜,有待进一步研究。在 grokking 的研究中,结构推断与彩票假说的对话才刚刚开始,我们期待未来出现更多将二者融会贯通的理论,彻底揭开神经网络从记忆到理解转变的奥秘。