楔子:一个"不公平"的训练场
想象你是一位武术教练,要给100个徒弟传授同样的招式。但问题是——这些徒弟的"起点"完全不同:
- 有的徒弟每天只能练2小时
- 有的徒弟只能看到拳谱的前半部分
- 有的徒弟甚至只能看到"出拳"的动作,看不到"防守"
这就是分布式学习中的数据异构性(Data Heterogeneity):每个客户端的数据分布不同,就像那些起点各异的徒弟。
传统SGD(随机梯度下降)面对这样的场景会怎么做?它假设所有徒弟看到的数据都差不多,用同一套学习率和更新规则。结果呢?
- 练得快的徒弟被拖慢
- 练得慢的徒弟跟不上
- 最后大家的招式都练得四不像
PISA的出现,就是为了让算法学会"见招拆招"。
第一章:SGD的"天真"假设
为什么传统优化器会"水土不服"?
SGD家族的算法有一个隐藏的假设:数据是独立同分布(IID)的。这意味着:
客户端A的数据 ≈ 客户端B的数据 ≈ 全局数据
在实验室里,我们可以把数据随机打乱后分配给不同客户端,这个假设成立。但在真实世界中:
| 场景 | 客户端A | 客户端B | 问题 |
|---|---|---|---|
| 医疗影像 | 某医院CT片(疾病A为主) | 另一医院CT片(疾病B为主) | 类别分布不同 |
| 推荐系统 | 年轻用户点击记录 | 老年用户点击记录 | 用户偏好不同 |
| 工业检测 | 工厂1的传感器数据 | 工厂2的传感器数据 | 设备参数不同 |
传统SGD在这些场景下的表现:
- 收敛慢:每个客户端的梯度方向互相"打架"
- 对超参数敏感:学习率大一点就发散,小一点就停滞
- 理论保证弱:需要强假设(如凸性、有界梯度、有界方差)
ADMM:一个被低估的框架
ADMM(交替方向乘子法)本来是优化领域的"老前辈",特别适合分布式问题。它的核心思想是:把大问题拆成小问题,用拉格朗日乘子协调。
传统ADMM的问题:
- 计算太重:每次更新需要完整计算梯度,还要做矩阵求逆
- 不够灵活:需要精确求解子问题,不适用于深度学习的大规模场景
PISA要做的,就是让ADMM变得"轻盈"且"聪明"。
第二章:PISA的三板斧
PISA全名是 Preconditioned Inexact Stochastic ADMM(预条件不精确随机ADMM)。让我们拆解这个名字:
第一板斧:Preconditioned(预条件化)
想象你在爬山。SGD就像闭着眼睛一步一步走,遇到陡坡可能一头栽下去。PISA则像拿着一张"地形图",提前知道哪里陡、哪里平,调整步伐大小。
数学上,预条件化就是在参数更新时乘一个矩阵 \(P\):
参数更新 = P × 梯度
这个 \(P\) 可以注入各种"先验知识":
| 预条件类型 | 物理意义 | 效果 |
|---|---|---|
| 二阶矩(SISA) | "这参数 historically 变化大,要小心" | 自适应学习率 |
| Hessian近似 | "这方向曲率大,步长要小" | 二阶收敛速度 |
| 正交化动量(NSISA) | "梯度方向要'正交化',别来回震荡" | 减少振荡 |
SISA(Second-moment Inexact Stochastic ADMM):把Adam的自适应学习率思想融入ADMM框架。
NSISA(Newton-Schulz Inexact Stochastic ADMM):用Newton-Schulz迭代正交化动量,类似Muon优化器的思想。
第二板斧:Inexact(不精确求解)
传统ADMM要求每个子问题都精确求解。这就像让徒弟必须把每个招式练到完美才能继续下一招——太慢了!
PISA说:"差不多就行了"。每个子问题只需要近似求解,大幅降低了计算成本。
关键洞察:在深度学习中,我们并不需要每一步都精确。只要整体方向对,噪声反而有助于逃离局部最优。
第三板斧:Stochastic(随机性)
用随机梯度代替全梯度。这不是什么新鲜事,但PISA的收敛理论更强:
- 传统随机ADMM需要:有界梯度、有界方差、凸性等一堆假设
- PISA只需要:Lipschitz连续梯度(几乎任何可微函数都满足)
这意味着什么? PISA可以在理论上保证收敛的场景,比传统方法多得多。特别是面对异构数据时,PISA不需要假设"客户端数据差不多"。
第三章:PISA的架构之美
算法流程(简化版)
def PISA_step(local_data, global_model):
# 1. 本地更新(不精确求解)
local_gradient = compute_stochastic_gradient(local_data)
local_model = global_model - learning_rate * preconditioner(local_gradient)
# 2. 与全局模型"协商"(ADMM的乘子更新)
consensus_residual = local_model - global_model
# 3. 软更新:不完全服从全局,保留本地特色
updated_local = local_model - rho * consensus_residual
# 4. 聚合(服务器端)
global_model = average(all_updated_locals)
return global_model
核心设计哲学
1. 解耦全局与局部
传统FedAvg:本地训几轮 → 上传 → 平均 → 下载
PISA:本地训 + 与全局"协商" → 上传 → 聚合
这个"协商"机制(ADMM的乘子项)让客户端可以在不完全服从全局的情况下,还能逐步收敛到共识。
2. 灵活注入二阶信息
预条件矩阵 \(P\) 的设计是开放的:
# SISA版本:用二阶矩
P = diag(1 / sqrt(moving_avg_of_gradient_squared + epsilon))
# NSISA版本:正交化动量
P = orthogonalize(momentum) # Newton-Schulz迭代
# 甚至可以用Hessian近似
P = (Hessian + lambda*I)^(-1)
3. 对数据异构的鲁棒性
PISA的收敛证明不依赖数据分布假设。无论客户端数据多"偏科",算法都能在理论上保证收敛。
第四章:实验结果深度分析
实验1:异构数据上的"逆袭"
设置:CIFAR-10,100个客户端,每个客户端只拿到2个类别的数据(极度"偏科")。
结果:
| 算法 | 测试准确率 | 相比FedAvg提升 |
|---|---|---|
| FedAvg | 37.8% | - |
| FedProx | 42.1% | +11.4% |
| SCAFFOLD | 45.6% | +20.6% |
| PISA (SISA) | 53.6% | +41.7% |
关键洞察:在极度非IID场景下,PISA的优势被放大。预条件化让算法能"看穿"数据的异构性,ADMM的协商机制让每个客户端在保持"个性"的同时走向共识。
实验2:GPT-2微调
设置:FineWeb数据集,GPT-2 Nano (125M)、Medium (330M)、XL (1.5B)。
结果:
| 模型 | NSISA vs AdamW | 验证损失差距 |
|---|---|---|
| GPT-2 Nano | NSISA更优 | 差距较小 |
| GPT-2 Medium | NSISA显著更优 | 差距扩大 |
| GPT-2 XL | NSISA大幅领先 | 差距最大 |
关键洞察:模型越大,PISA的优势越明显。这可能是因为大模型的参数空间更复杂,二阶信息和ADMM的约束机制更能发挥作用。
实验3:视觉模型训练
ImageNet上的ResNet-18:
| 优化器 | Top-1准确率 | 收敛速度 |
|---|---|---|
| SGD-M | 69.8% | 中等 |
| AdamW | 69.2% | 快 |
| AdaBelief | 70.1% | 中等 |
| SISA | 70.3% | 最快 |
CIFAR-10上的多种架构:
- VGG-11:SISA略逊于SGD-M
- ResNet-34:SISA最优
- DenseNet-121:SISA最优
关键洞察:SISA在不同架构上的表现有差异。ResNet和DenseNet的残差连接可能更适合ADMM的分解-协调机制。
实验4:GANs与强化学习
WGAN-GP训练:PISA展现出更稳定的收敛,减少了模式崩溃(mode collapse)的发生。
强化学习(MuJoCo环境):PISA在Ant、Humanoid等复杂任务上,样本效率更高。
第五章:PISA的局限与未来
当前局限
1. 计算开销
预条件矩阵的维护和更新需要额外计算和内存:
- SISA:需要存储二阶矩估计(类似Adam)
- NSISA:Newton-Schulz迭代增加了计算量
与SOTA优化器的内存对比(GPT-2-XL):
| 优化器 | 内存开销 | 相对值 |
|---|---|---|
| Adam | 53,128 MiB | 1.0x |
| Muon | 50,657 MiB | 0.95x |
| NSISA | 55,126 MiB | 1.04x |
| Shampoo | 68,146 MiB | 1.28x |
| SOAP | 93,630 MiB | 1.76x |
PISA的内存开销在可接受范围内,但比Adam/Muon略高。
2. 超参数调优
ADMM引入了额外的超参数(如惩罚系数ρ),增加了调优复杂度。
3. 理论-实践差距
虽然PISA的理论收敛条件很弱,但实际性能仍然受预条件器设计的影响。什么场景用SISA、什么场景用NSISA,还需要更多经验法则。
未来方向
1. 硬件协同优化
预条件矩阵的更新可以设计专门的CUDA kernel,类似Shampoo的优化思路。
2. 自适应预条件选择
根据任务特性自动选择预条件策略(二阶矩 vs 正交化动量 vs Hessian)。
3. 与其他技术的结合
- PISA + 梯度压缩:降低通信成本
- PISA + 差分隐私:联邦学习的隐私保护
- PISA + 模型并行:超大规模模型训练
结语:优化的本质是什么?
PISA给我们的启示不只是"又出了一个新优化器"。
它让我们重新思考:优化的本质是什么?
传统SGD的视角:找到一个"平均"的下降方向,大家一起走。
PISA的视角:每个客户端有自己的"地形图",在走向共识的路上保留自己的特色。
在数据越来越分散、越来越个性化的时代,后者可能是更自然的选择。
PISA把ADMM这个"老古董"带入了深度学习时代,证明了经典优化理论与现代机器学习可以碰撞出新的火花。
参考资料
- 论文:Preconditioned Inexact Stochastic ADMM for Deep Models (arXiv:2502.10784)
- 作者:Shenglong Zhou, Ouya Wang, Ziyan Luo, Yongxu Zhu, Geoffrey Ye Li
- 机构:北京交通大学、帝国理工学院、东南大学
- 代码:https://github.com/Tracy-Wang7/PISA
标签: #PISA #优化算法 #ADMM #分布式学习 #联邦学习
讨论回复
1 条回复推荐
智谱 GLM-5 已上线
我正在智谱大模型开放平台 BigModel.cn 上打造 AI 应用,智谱新一代旗舰模型 GLM-5 已上线,在推理、代码、智能体综合能力达到开源模型 SOTA 水平。