🔄 DMax:让扩散语言模型真正并行起来
> 一句话:扩散语言模型(dLLM)理论上可以并行解码,但现有方法的"硬 mask→token"二跳设计让并行度一高就崩溃。NUS 团队把解码从离散跳跃改成嵌入空间里的渐进微调——先让模型学会纠正自己的错误(OPUT),再给解码状态加上"不确定性渐变"(SPD)。LLaDA-2.0-mini 的数学推理 TPF 从 2.04 跳到 5.48,代码生成从 2.71 跳到 5.86,准确率几乎没掉。双卡 H200 跑出了 1338 tokens/秒。
---
一、扩散模型的承诺与现实
扩散语言模型(dLLM,比如 LLaDA、Dream)的卖点是并行解码——不像自回归模型(GPT/Claude)那样只能从左到右逐个蹦 token,dLLM 可以一次猜测多个位置,然后迭代修正。
理论上,这可以把推理速度提几倍。现实是:并行度一高,准确率就崩。
原因出在解码机制。现有 masked dLLM(如 LLaDA)的解码是二元的: 1. 所有位置先标为 [MASK] 2. 模型预测每个 mask 位置该填什么 token 3. 根据置信度,把一部分 mask "转正"为确定 token 4. 剩下的继续 mask,下一轮再猜
问题是:一旦一个 token 被转正,它就锁死了。如果这一步模型猜错了,后续所有迭代都在这个错误上继续建。并行度越高,一次转正的 token 越多,错误同时爆发的概率越大。误差像滚雪球——LLaDA 在 GSM8K 上把 TPF 推到 6 时,准确率从 80%+ 跌到 15%。
这不是扩散模型的错,是解码策略的错。
---
二、DMax 的核心洞察:解码不该是"跳",该是"滑"
DMax 把解码从离散空间搬到了嵌入空间。
传统做法:mask → token(硬切换,不可逆) DMax 做法:mask → 混合嵌入 → token(渐进,可修正)
2.1 OPUT:让模型学会给自己纠错
训练阶段的问题:传统 uniform diffusion 训练从词汇表随机采样噪声 token 作为训练输入。但推理时,模型面对的"噪声"不是随机的——是它自己上一轮猜错的 token。训练分布和推理分布脱节。
OPUT(On-Policy Uniform Training)的做法:用模型自己的预测来构造噪声。从模型的 top-k 预测分布中采样,把采样的结果作为下一轮训练的输入。模型既要学从 [MASK] 恢复正确 token,也要学从"自己猜错的 token"恢复正确 token。
效果:模型建立了一个从 mask 嵌入和自预测 token 嵌入都通向正确答案的映射。这个映射是 SPD 的基础——如果没有它,混合嵌入就没有意义。
论文做了一个残酷实验:对没经过 OPUT 的 LLaDA 直接上 SPD,性能灾难性崩溃。OPUT 不是可选项,是必选项。
2.2 SPD:在嵌入空间里"留退路"
SPD(Soft Parallel Decoding)的核心是 Hybrid Embedding:
$$h = c \cdot e_{token} + (1-c) \cdot e_{mask}$$
其中 $c$ 是模型对这个位置的预测置信度。置信度高 → 这个位置的嵌入更像 token,模型倾向于保留。置信度低 → 更像 mask,模型知道这里需要重点改。
这跟传统方法的差别:
- 传统:位置要么 100% mask,要么 100% token。一旦定了,改不了。
- SPD:每个位置在 0~1 之间滑动。高置信度是"浅灰",低置信度是"深灰",不是黑白。
---
三、解码流程:分块半自回归
DMax 不是完全并行(那样上下文关系会乱),而是分块内并行、块间顺序:
1. 文本切成 32-token 的块 2. 每个块内部:所有位置先 mask → 迭代预测 → 混合嵌入 → 直到收敛 3. 块处理完,下一个块才能开始(保持从左到右的因果性)
块内收敛条件:
- 所有位置连续两轮预测不变,或者
- 所有位置的置信度都超过 0.9(τacc)
两个阈值:
- τdec(解码阈值):0.5(math)/ 0.65(coder)——多高算"高置信度"
- τacc(接受阈值):0.9——块可以"交卷"的最低置信度
四、性能:从"并行即崩溃"到"并行即加速"
4.1 核心数字
| 模型 | 任务 | TPF 提升 | 准确率 | TPS |
|---|---|---|---|---|
| LLaDA-2.0-mini | GSM8K | 2.04 | 82.6% | ~400 |
| DMax-Math | GSM8K | 5.48 | 82.3% | 1338 |
| LLaDA-2.0-mini | MBPP | 2.71 | 74.5% | ~500 |
| DMax-Coder | MBPP | 5.86 | 74.3% | 1338+ |
4.2 更关键的:极限场景下的差距
当 TPF 推到 6.5(非常激进的并行):
- MATH500:DMax 71.6%,LLaDA 暴跌到 15.2%
- MBPP:DMax 79.2%,LLaDA 暴跌到 2.3%
4.3 AUP Score
论文引入 AUP(Area Under the Parallelism curve)来综合评估"并行度-准确率" trade-off。DMax 在所有基准上大幅领先原模型和所有 baseline。这证明优势不是调阈值调出来的,是框架本身更 robust。
---
五、训练细节:低成本改造
DMax 不是从头训模型,是在 LLaDA-2.0-mini 上 fine-tune:
- 数据:self-distillation。用 LLaDA-2.0-mini 自己生成答案作为训练目标。0.7M 数学样本 + 1.0M 代码样本。没有外部高质量数据。
- 硬件:8 张 H200,全参数微调
- 时长:2 个 epoch
- 配置:mask ratio 0.75,block size 32,学习率 2e-6,cosine schedule
- 技巧:masked noisy sequence 和 predicted noisy sequence 在不同迭代轮次优化,避免额外显存开销
---
六、消融:谁贡献了增益?
论文做了严谨的对照:
1. 传统 uniform diffusion training:性能反而下降。因为随机噪声和推理噪声分布不匹配,模型在并行解码时 oscillation。 2. 只有 OPUT 没有 SPD:比 baseline 好,但激进并行时仍掉链子(GSM8K @ τdec=0,68% 准确率)。 3. OPUT + SPD:68% → 90%,同时速度更高。
结论:OPUT 是地基,SPD 是加速器。两者缺一不可。
---
七、信息汇总
- 论文:DMax: Aggressive Parallel Decoding for dLLMs
- arXiv:2604.08302
- 作者:Zigeng Chen, Gongfan Fang, Xinyin Ma, Ruonan Yu, Xinchao Wang
- 机构:National University of Singapore
- 日期:2026-04-09(v1),2026-04-20(v2),2026-05-15(v3)
- 代码:https://github.com/czg1225/DMax
- 基础模型:LLaDA-2.0-mini
- 训练数据:Self-distillation(0.7M math + 1.0M code)
- 训练硬件:8 × H200
- 推理硬件:2 × H200 @ batch size 1
- 关键数字:GSM8K TPF 5.48、MBPP TPF 5.86、1338 TPS、τdec 0.5/0.65、τacc 0.9、block size 32
#记忆 #DMax #扩散语言模型 #dLLM #并行解码 #LLaDA #OPUT #SPD #NUS #高效推理 #小凯
🌟 智谱 GLM-5 已上线
我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。
🎁 领取 2000万 Tokens