> 论文:Accelerating Speculative Decoding with Block Diffusion Draft Trees > 作者:Liran Ringel, Yaniv Romano (Technion) > arXiv: 2604.12989 (2026.4.14) > 代码:https://github.com/liranringel/ddtree
---
一、问题意识:DFlash 的"浪费"
DFlash 用 block diffusion 一次前向传播生成整个 token block 的 marginal distributions,但只验证单条轨迹。这意味着:
- 每个位置 i 的分布 qi 包含大量信息
- 但 DFlash 只采样一条路径,其余概率质量被丢弃
- 如果能利用这些 per-position distributions 探索多条候选路径,可能大幅提升接受率
---
二、核心贡献:DDTree 三步走
2.1 数学框架
设定:
- 当前上下文 c,bonus token b(已由 target model 生成但尚未前向传播)
- Drafter 一次前向传播产生 L 个 per-position distributions: qi(·|c,b), i=1,...,L
- 这些 marginals 定义 factorized distribution: Q(y₁:L|c,b) = ∏ᵢ qi(yᵢ|c,b)
理想目标(不可行):
max_T E_{Y~p(·|c,b)}[α_T(Y)] ← 需要 target model 的 path-conditioned probs
替代目标(可行):
max_T E_{Y~Q(·|c,b)}[α_T(Y)] ← 只用 drafter 的 factorized distribution
2.2 关键定理
Proposition 1(目标分解):
E_{Y~Q}[α_T(Y)] = Σ_{u∈T} q(u|c,b)
其中 q(u|c,b) = ∏ᵢ qi(uᵢ|c,b) 是 prefix u 的概率。这意味着:目标函数是 prefix 概率的可加和,最优策略就是选概率最高的 B 个 prefix。
Proposition 2(最优性): 取概率最高的 B 个 prefix,它们自动构成有效的 tree(prefix-closed)。这就是最优解。
2.3 Best-First Heap 算法(Algorithm 1)
不枚举所有 O(|V|^L) 个 prefix,而是用 max-heap 高效找 top-B:
# 核心思想:按 rank tuple ρ = (ρ₁,...,ρ_d) 索引 prefix
# ρ_i = k 表示位置 i 取第 k 大概率的 token
# prefix 概率 = ∏_i q_i^{(ρ_i)}
# Heap 初始化:ρ = (1),即每个位置都取 top-1 token
# 每次 pop 最大概率的 prefix,push:
# - sibling: (ρ₁,...,ρ_{d-1}, ρ_d+1) ← 当前位置换次优 token
# - child: (ρ₁,...,ρ_d, 1) ← 扩展到下一位置,取最优 token
复杂度:O(B log B),heap size O(B)。
Lemma 1:只需在每个位置考虑 top-K tokens(K=min(B,|V|)),即可保证最优性。
---
三、代码架构解析
3.1 核心文件结构
ddtree/
├── ddtree.py # DDTree 核心:build_tree + ddtree_generate
├── dflash.py # DFlash 基础实现
├── model/
│ ├── __init__.py # DFlashDraftModel + load_and_process_dataset
│ ├── dflash.py # Draft model 架构(Transformer-based)
│ ├── utils.py # RMSNorm, repeat_kv, apply_rotary_emb
│ └── distributed.py # 多卡分布式支持
├── benchmark.py # 实验框架
├── plot_results.py # 绘图
└── make_latex_table.py # 表格生成
3.2 ddtree.py 核心实现
def build_tree(draft_logits: torch.Tensor, tree_budget: int):
"""
draft_logits: [batch_size, block_size, vocab_size]
返回: tree_tokens [tree_budget], tree_indices [tree_budget, 2]
"""
# 1. 获取 top-K tokens 和概率
draft_probs, draft_topk_ids = draft_logits.softmax(-1).topk(k=min(tree_budget, vocab_size))
# 2. 初始化 heap,从 (1,) 开始
# 3. Best-first 搜索 B 次
# 4. 返回 tree_tokens 和 tree_indices
tree_indices 编码:[index_in_tree, parent_index_in_tree],用于 tree attention mask。
3.3 Tree Attention Mask
def compute_tree_attention_mask(tree_indices: torch.Tensor, bonus_token_offset: int):
"""
构建 ancestor-only attention mask:
- 每个 token 只能 attend to: bonus token + 所有祖先 + 自己
- 不能 attend to siblings 或其他分支
"""
attention_mask = torch.zeros(total_tree_size, total_tree_size, dtype=torch.bool)
# ... 基于 tree_indices 的 parent 关系构建 mask
return attention_mask
关键:ancestor-only mask 保证验证时每个位置的计算不依赖兄弟节点,避免相互干扰。
3.4 KV Cache 管理
def compact_cache(kv_cache, accepted_indices):
"""
每轮结束后,只保留 accepted path 的 KV cache
丢弃未接受的分支,释放内存
"""
# 使用 C++ 扩展加速(compact_attention.cpp)
3.5 验证流程(Verifier Walk)
def verify_draft_tree(target_logits, tree_tokens, tree_indices, bonus_token):
"""
1. 从 bonus token 开始
2. 用 target model 的解码规则(greedy/sampling)选择下一个 token
3. 检查是否匹配 tree 中的某个 child
4. 匹配则继续,不匹配则停止
5. 返回 accepted path 和新的 bonus token
"""
---
四、与相关工作的对比
| 方法 | Drafter 类型 | Tree 构建方式 | 关键区别 |
|---|---|---|---|
| DFlash | Block diffusion | 单路径(贪心) | 只验证一条轨迹 |
| DDTree | Block diffusion | Best-first heap from per-position marginals | 单次 diffusion pass → 最优 tree |
| OPT-Tree | Autoregressive | 逐层前向传播 + 动态选择 | 每层需要一次 drafter forward |
| DART | Parallel logits | N-gram continuity pruning + trie | 需要外部 N-gram 评分 |
| EAGLE-3 | Autoregressive | Feature-based drafting | 多层 feature fusion |
---
五、实验结果分析
5.1 设置
- 模型:Qwen3-4B, Qwen3-8B, Qwen3-Coder-30B-A3B-Instruct
- Drafter:对应 DFlash checkpoints (z-lab/dflash)
- 数据集:MATH-500, GSM8K, AIME 2024/2025, HumanEval, MBPP, LiveCodeBench, SWE-bench Lite, MT-Bench, Alpaca
- 硬件:8× H200
- 温度:0.0 (greedy) 和 1.0 (sampling)
5.2 核心结果
DDTree 在所有 60 个 setting(10 数据集 × 3 模型 × 2 温度)上均优于 vanilla DFlash。
Speedup 范围(相对于 autoregressive decoding):
- Qwen3-4B: ~5-7x
- Qwen3-8B: ~6-8x
- Qwen3-Coder-30B: ~4-6x
- Vanilla DFlash: ~3-5 tokens
- DDTree (best budget): ~4-7 tokens
- 提升幅度:约 10-60%(取决于 budget 和数据集)
5.3 Budget-Quality Tradeoff
Node budget B 的选择:
- B=16:轻量级,适合 latency-sensitive 场景
- B=64-128:sweet spot,性价比最高
- B=256-512:边际收益递减,但某些数据集仍有提升
- B=1024:论文未报告,可能收益有限
---
六、局限与讨论
6.1 论文中提到的局限
1. 固定 block size:当前使用固定 L(如 16),未来可探索自适应 block size 2. 依赖 target hidden states:DFlash 需要 target model 的 5 层 feature,内存开销随 block 增大 3. 长上下文:超长上下文场景仍需优化(虽然 sliding window 已部分缓解) 4. Surrogate objective:基于 drafter 的 factorized distribution,非 target model 的真实分布
6.2 代码层面的观察
1. Flash Attention 依赖:drafter 必须使用 FlashAttention,target model 可用 sdpa 2. C++ 扩展:KV cache compact 有 C++ 加速(compact_attention.cpp),对性能关键 3. Tree attention 的兼容性:需要确保 target model 支持 custom attention mask(如 Qwen 的 GQA)
6.3 与 DFlash 的兼容性
DDTree 完全兼容现有 DFlash checkpoints,无需重新训练 draft model。这是巨大的工程优势——用户可以直接复用 z-lab 发布的所有 DFlash 模型。
---
七、未来方向
1. 自适应 Block Size:根据上下文动态调整 L,而非固定值 2. 多块级联:DDTree → 更大的 block size(32/64),进一步提升接受率 3. 与 Target Model 联合微调:当前 draft model 独立训练,未来可端到端联合优化 4. 多模态扩展:将 block diffusion 思想扩展到 vision-language 模型 5. 硬件协同优化:针对特定硬件(如 Apple Silicon MLX, AMD)优化 tree attention kernel
---
八、关键洞察总结
8.1 理论层面
DDTree 的核心洞察是:block diffusion 的 per-position marginals 虽然丢失了路径条件信息,但仍然足以构建一个高效的 draft tree。通过最大化 factorized distribution 下的 expected acceptance length,DDTree 将 tree construction 转化为一个可解的优化问题,并给出了最优的 greedy 算法。
8.2 工程层面
DDTree 的工程优雅之处在于: 1. 零额外训练:完全复用现有 DFlash 模型 2. 单次 drafter pass:不增加 drafting latency 3. O(B log B) 的 tree 构建:开销极小 4. Ancestor-only attention:利用现有 tree attention 机制,无需新 kernel
8.3 生态系统意义
DDTree 代表了 speculative decoding 的一个关键转折点:
- 从单路径到多路径:充分利用并行 draft 的信息
- 从启发式到最优化:理论保证替代经验调参
- 从专用到通用:兼容所有 DFlash 模型,即插即用
九、推荐阅读顺序
1. 先读:DFlash 原始论文(arXiv 2602.06036)——理解 block diffusion 基础
2. 再读:DDTree 论文 Section 3-4 —— 理解核心算法
3. 代码:从 ddtree.py 的 build_tree() 开始 —— 理解实现细节
4. 实验:跑 benchmark.py 复现 Table 1 —— 验证实际效果
5. 扩展:读 OPT-Tree 和 DART 论文 —— 理解 tree-based speculative decoding 的全貌
---
> 一句话总结:DDTree 用一次 block diffusion pass 的 per-position distributions,通过 best-first heap 构建最优 draft tree,在零额外训练成本下将 speculative decoding 的加速比从 6x 推向 8x+。