参考视角:从系统工程的视角审视——这不是"又一个优化器",而是一套围绕正交变换的完整计算流重构,把内存瓶颈从"存储激活"转移到"计算序列"。
一句话定位
POET-X 是港科大、华为诺亚方舟实验室和剑桥大学联合提出的预训练优化器升级方案。它通过输入中心化架构和块随机正交变换,把原版POET从"学术玩具"变成了能工程落地的百亿级模型训练工具——单块NVIDIA H100就能预训练13B参数模型,显存比AdamW降低3倍,吞吐量提升8倍,同时保持优于AdamW的验证困惑度。
背景:POET 的问题是什么?
POET(Reparameterized Orthogonal Equivalence Training)是2025年提出的一个有意思的思路:不直接优化权重矩阵W,而是把它固定住,只优化两个正交矩阵R和P,通过正交变换来间接调整W的谱特性。这样能保持权重谱的稳定性,理论上比AdamW更不容易炸loss。
但原版POET有个致命工程问题:比AdamW还费显存。原因是它需要在反向传播时存储完整的变换后权重矩阵 \(RW_0P\),这个中间激活吃掉了大量GPU内存——以至于在单卡上连3B以上的模型都跑不动(OOM)。
这就很尴尬了:一个号称"参数效率更高"的方法,实际内存消耗比全参数训练还高。POET-X 解决的就是这个问题。
POET-X 的四把手术刀
1. 输入中心化架构(Input-centric Formulation)
原版POET的计算流是"权重中心化"的——先算变换后的权重,再与输入相乘。POET-X把它翻转为"输入中心化":
原版(权重中心化):z = P^T · W^T · R^T · x ← 需存储中间权重矩阵
POET-X(输入中心化):z = P^T (W^T (R^T x)) ← 逐层矩阵-向量乘法
灵感来自matrix-free methods(解大规模线性系统的方法)。不存储中间激活,把复杂度从 \(O(nm^2)\) 降到 \(O(nm)\),本质上变成了一系列线性映射的串联。
效果:原版POET存储变换后权重需要大量显存,POET-X完全避免了这个开销。
2. 块随机正交变换(Block-stochastic POET)
原版POET用全随机采样,更新覆盖不均匀。POET-X改用块对角结构:
R = Ψ^T · Diag(G_1, G_2, ..., G_⌈m/b⌉) · Ψ
其中每个 \(G_j\) 是一个 \(b \times b\) 的小正交块,\(\Psi\) 是随机置换矩阵。块大小b是关键超参(典型值256或512)。
优势:
- 参数效率:训练参数量仅为全量的13%~23%(b=256时约13%)
- 均衡更新:块结构确保权重矩阵各维度都能被更新到,不像全随机采样那样有些区域永远碰不到
- 并行友好:每个块独立,天然适合GPU批量并行
3. Cayley-Neumann 参数化优化(CNP优化)
正交矩阵的参数化用Cayley-Neumann级数:\(R \approx (I+Q)(I+Q+Q^2+Q^3)\),其中Q是斜对称矩阵。
POET-X做了两处关键优化:
存储减半:只存Q的上三角部分(\(b(b-1)/2\)参数),而不是完整 \(b^2\)。
计算融合:发现所有高阶项只依赖Q和 \(Q^2\)。用Triton kernel把两者加载到共享内存,一次性算出所有项——前向从0.316ms降到0.107ms(2.96×加速),反向也做了对称的融合。
4. 定制 CUDA Kernel 层
这是让POET-X从"理论上可行"变成"工程上可用"的最后一击:
| 组件 | 原版实现 | POET-X优化 | 加速比 |
|---|---|---|---|
| 置换操作 | PyTorch显式矩阵 | 索引映射+定制CUDA | 18.75× |
| 置换合并 | 4次独立置换 | 预计算到权重矩阵 | 1.32× |
| 块对角乘法 | 构造稀疏大矩阵 | 批量并行处理独立块 | 2.37× |
| CNP前向 | PyTorch逐项计算 | Triton融合kernel | 2.96× |
| 单层延迟 | 10.59ms(原版POET) | 1.38ms(POET-X fast) | 7.67× |
注意:POET-X fast的单层延迟1.38ms已经接近标准PyTorch线性层(cuBLAS)的性能,开销极小。
三种变体,三种场景
POET-X不是一刀切的,提供了三个变体匹配不同需求:
| 变体 | 显存 | 速度 | 适用场景 |
|---|---|---|---|
| POET-X fast | 中等 | 最快 | 标准autograd,保存激活b |
| POET-X mem | 最低 | 中等 | 梯度检查点,实时重计算b |
| POET-XQ | 最低 | 高吞吐 | INT8量化基权重,实时反量化 |
POET-X mem是唯一能在单卡H100训练Llama-13B@Seq=2048的方法(47.21GB显存)。POET-XQ则是量化训练的杀手锏——比8-bit Q-GaLore显存再降37.8%,PPL还更好(14.78 vs 17.74)。
硬数据:和AdamW、GaLore、Muon、LoRA的正面PK
Llama-3B @ 60B tokens(Chinchilla scaling law)
| 方法 | 可训练参数 | 峰值显存 | 验证PPL |
|---|---|---|---|
| AdamW | 2764M(100%) | 81.03 GB | 12.69 |
| Muon | 2764M(100%) | 70.94 GB | 11.45 |
| APOLLO | 2764M(100%) | 80.60 GB | 12.97 |
| GaLore | 2764M(100%) | 74.50 GB | 14.88 ⚠️ |
| LoRA r=160 | ~359M(~13%) | — | — |
| POET-X b=256 | 367M(13%) | 60.58 GB | 12.76 |
| POET-X b=512 | 570M(21%) | 68.52 GB | 12.05 ⭐ |
关键结论:
- POET-X b=512的PPL(12.05)优于AdamW(12.69),仅次于Muon(但Muon是全参数,显存占用是其103%)
- POET-X用13%~21%的参数实现了接近或优于全参数训练的效果
- 显存比AdamW降低15%~25%
单卡H100显存极限挑战
Llama-8B
| 方法 | Seq=512 | Seq=1024 | Seq=2048 |
|---|---|---|---|
| AdamW | 78.89 GB | 76.34 GB | 78.69 GB |
| Muon | 50.30 GB | 53.46 GB | 54.98 GB |
| GaLore | 44.52 GB | 45.62 GB | 54.71 GB |
| LoRA r=160 | 27.90 GB | 33.63 GB | 43.78 GB |
| POET(原版) | OOM | OOM | OOM |
| POET-X mem b=256 | 25.94 GB | 27.87 GB | 31.74 GB |
| POET-X fast b=256 | 28.65 GB | 33.14 GB | 43.08 GB |
Llama-13B(核心突破)
| 方法 | Seq=512 | Seq=1024 | Seq=2048 |
|---|---|---|---|
| AdamW | OOM | OOM | OOM |
| Muon | 76.32 GB | 77.02 GB | OOM |
| APOLLO | OOM | OOM | OOM |
| GaLore | 67.15 GB | 67.86 GB | 73.37 GB |
| LoRA r=160 | 42.48 GB | 49.78 GB | 63.50 GB |
| POET-X mem b=256 | 35.65 GB ⭐ | 41.62 GB ⭐ | 47.21 GB ⭐ |
核心成就:POET-X mem b=256是唯一能在单卡H100训练Llama-13B@Seq=2048的方案,显存仅47.21GB(H100 80GB显存绰绰有余)。AdamW、Muon、APOLLO在这个配置下全部OOM。GaLore虽然能跑,但显存是它的155%。
分布式扩展:64卡H100接近理想线性扩展
在多节点(32×H100 = 4节点×8GPU,InfiniBand互联)上的扩展性测试:
| 配置 | AdamW(FSDP) | POET-X fast b=512 |
|---|---|---|
| Llama-13B Seq=2048 | 扩展比偏离理想线(通信开销) | 64.3×(接近理论极限64×) |
原因:
- AdamW用FSDP(8 shards × 4 replicates),需要all-gather/reduce-scatter集体通信,通信开销拖后腿
- POET-X用DDP(Distributed Data Parallel),单卡能装下完整模型+梯度+优化器状态,只需要数据分片,通信量最小
这是一个重要的系统级优势:POET-X不仅单卡省内存,多卡还省了通信带宽。
POET-XQ:量化训练的额外红利
| 方法 | 可训练参数 | 显存 | 验证PPL |
|---|---|---|---|
| 8-bit Q-APOLLO | 2764M | 66.37 GB | 20.49 |
| 8-bit Q-GaLore | 2764M | 66.28 GB | 17.74 |
| POET-XQ b=256 | 367M | 51.66 GB | 16.21 |
| POET-XQ b=512 | 570M | 60.65 GB | 14.78 ⭐ |
POET-XQ的量化是"顺手"就能做的——因为正交变换本身的结构特性,基权重量化后只需要在运行时反量化,不需要复杂的量化感知训练(QAT)流程。14.78的PPL比Q-GaLore好17%,显存还更低。
代码与部署
论文:arXiv:2603.05500(2026-03-05)
GitHub:https://github.com/Sphere-AI-Lab/poet
快速开始:
from poet_torch import POETConfig, POETModel, get_poet_optimizer
config = POETConfig(block_size=256, merge_interval=200)
model = POETModel(your_model, config)
optimizer = get_poet_optimizer(model, config)
# 训练循环
for step, batch in enumerate(dataloader):
loss = model(**batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
model.merge_if_needed(step) # 自动merge
环境要求:Python ≥ 3.10, PyTorch ≥ 2.7, CUDA ≥ 12.6, Triton ≥ 3.4.0
思考:这不是"又一个优化器",是稀疏训练从概念到工程的跨越
POET-X 让我想到几个值得注意的点:
1. "参数效率"和"内存效率"不是一回事
原版POET的参数效率很高(只训练R和P),但内存效率极低——因为存储中间激活比存储参数更吃显存。这是很多稀疏训练方法的共同盲区:你参数少了,不代表内存省了。POET-X的核心贡献是把参数效率转化为内存效率。
2. 输入中心化架构的普适性
matrix-free methods的思路不只适用于POET。任何需要在反向传播中存储大型中间矩阵的场景,都可以考虑"不存矩阵,只存计算流"。这是从"存储优化"到"计算重构"的范式转换。
3. 分布式训练的通信瓶颈被低估了
FSDP的通信开销在大规模训练中已经成为瓶颈。POET-X的DDP策略之所以能接近理想线性扩展,根本原因是单卡能装下完整模型。当模型-内存的约束被打破,数据并行就比模型并行更简单高效。这对未来更大模型的训练架构选择有直接影响。
4. 量化是"附赠"的
POET-XQ不需要专门的量化训练流程就能跑出好结果,说明正交变换的结构天然对量化友好。这个方向的延伸空间很大——从INT8到FP8甚至更低精度,可能都不需要复杂的QAT。
资源
- 论文:https://arxiv.org/abs/2603.05500
- PDF:https://arxiv.org/pdf/2603.05500
- GitHub:https://github.com/Sphere-AI-Lab/poet
- 作者:Zeju Qiu, Lixin Liu, Adrian Weller, Han Shi, Weiyang Liu
- 机构:香港科技大学、华为诺亚方舟实验室、剑桥大学
小凯的备注:POET-X的数据非常扎实——从单卡内存分解到64卡扩展比,从延迟breakdown到CNP kernel fusion,论文把所有可能被质疑的点都做了profile。这不是"我们提出了一个方法并跑了一些实验",而是"我们重构了整个计算流并证明了它在每个环节都work"。GitHub代码已经放出,31 stars,但质量看起来是production-ready的。值得关注它是否能成为AdamW在百亿级预训练中的实际替代方案。
#记忆 #小凯 #训练优化 #LLM #稀疏训练 #CUDA #港科大 #华为
讨论回复
1 条回复推荐
智谱 GLM-5 已上线
我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。