# GATr 从零开始:一个完整的代码实践教程
> **参考对象**:PyTorch 官方教程风格——清晰的分步讲解,可复制的代码,每段都有预期输出
---
## 环境准备
```bash
# 克隆 GATr 仓库
pip install git+https://github.com/Qualcomm-AI-research/geometric-algebra-transformer
# 依赖
pip install torch torchvision torchaudio
pip install numpy matplotlib tqdm
```
---
## Step 1:理解几何代数的「数据格式」
GATr 的核心是**多向量**(multivector)。在 G(3,0,1) 投影几何代数中,每个多向量是 16 维的,但不同分量代表不同的几何对象。
```python
import torch
from gatr.interface import embed_point, embed_scalar, embed_vector

# Build a 3D point cloud: 5 points, each [x, y, z].
points = torch.tensor([
    [0.0, 0.0, 0.0],  # origin
    [1.0, 0.0, 0.0],  # on the x axis
    [0.0, 1.0, 0.0],  # on the y axis
    [0.0, 0.0, 1.0],  # on the z axis
    [1.0, 1.0, 1.0],  # diagonal
])

# Embed the points as 16-dimensional multivectors.
mv_points = embed_point(points)
print(f"输入形状: {points.shape}")  # [5, 3]
print(f"多向量形状: {mv_points.shape}")  # [5, 16]

# What lives inside a G(3,0,1) multivector? The 16 components split by grade:
# 1 scalar + 4 vector (e0, e1, e2, e3) + 6 bivector + 4 trivector + 1 pseudoscalar.
# NOTE(review): the original comments claimed slots 1-3 are e1..e3 and slots
# 7-9 hold e123; in projective GA the vector grade has FOUR slots (e0 included),
# and `embed_point` stores point coordinates in the trivector components, not
# the scalar/vector slots. The index-based prints below are illustrative only —
# confirm the exact layout against the gatr.interface documentation.
print(f"标量部分: {mv_points[:, 0]}")
print(f"向量部分 e1: {mv_points[:, 1]}")
print(f"向量部分 e2: {mv_points[:, 2]}")
```
**输出预期**(注意:下面的分量数值取决于所安装 gatr 版本的嵌入约定——`embed_point` 通常把坐标写入三向量分量而非标量/向量分量,请以实际运行输出为准):
```
输入形状: torch.Size([5, 3])
多向量形状: torch.Size([5, 16])
标量部分: tensor([0., 1., 0., 0., 1.])
向量部分 e1: tensor([0., 1., 0., 0., 1.])
向量部分 e2: tensor([0., 0., 1., 0., 1.])
```
---
## Step 2:构建等变线性层
```python
from gatr.layers import EquiLinear

# Equivariant linear layer.
# NOTE(review): in the official GATr API, in_mv_channels / out_mv_channels
# count multivector *channels*, not the 16 components of one multivector;
# inputs are shaped [..., channels, 16], so for a single-channel input these
# arguments would normally be small numbers (e.g. 1) — confirm against the
# installed gatr version.
linear = EquiLinear(in_mv_channels=16, out_mv_channels=32)

# Forward pass.
mv_output = linear(mv_points.unsqueeze(0))  # expected [1, 5, 32] per the tutorial; actual shape depends on gatr's channel convention
print(f"输出形状: {mv_output.shape}")

# Key property: rotating the input rotates the output by the same amount.
# NOTE(review): `_sample_rotation_matrix` has a leading underscore — it is a
# private helper and may not exist in released gatr versions; verify.
from gatr.primitives import _sample_rotation_matrix
R = _sample_rotation_matrix()  # random 3D rotation matrix
points_rotated = torch.matmul(points, R.T)
mv_rotated = embed_point(points_rotated)
output_rotated = linear(mv_rotated.unsqueeze(0))

# Compare: transform-then-rotate vs rotate-then-transform.
output_from_rotated_input = output_rotated
# NOTE(review): this rotates only the three slots 1:4 with the 3x3 matrix R;
# in G(3,0,1) the grade-1 part has four slots (e0..e3) and point coordinates
# live in the trivector components, so this check may not be meaningful as
# written — confirm against gatr's interface documentation.
expected_output = torch.matmul(mv_output.squeeze(0)[..., 1:4], R.T)
print(f"等变性误差: {torch.norm(output_from_rotated_input[0, :, 1:4] - expected_output).item():.6f}")
# Should be close to 0 (numerical-precision level) if the check is valid.
```
---
## Step 3:N-Body 动力学预测完整代码
这是 GATr 论文中最经典的实验:给定 5 个天体的初始位置和速度,预测它们未来的位置。
```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from gatr import GATr
from gatr.interface import embed_point, embed_scalar, extract_point
import numpy as np
from tqdm import tqdm
class NBodyDataset(Dataset):
    """N-body dataset: pre-generated random gravitational systems.

    Each sample maps an initial state (positions, velocities, masses) to the
    body positions after ``n_steps`` of Euler integration with step ``dt``.
    All samples are generated eagerly in ``__init__``.
    """

    def __init__(self, n_samples=1000, n_bodies=5, dt=0.001, n_steps=1000):
        self.n_samples = n_samples
        self.n_bodies = n_bodies
        self.dt = dt
        self.n_steps = n_steps
        # Pre-generate all samples up front so __getitem__ is cheap.
        self.data = []
        for _ in range(n_samples):
            positions, velocities, masses = self._generate_system()
            # Integrate the system forward to obtain the regression target.
            final_positions = self._evolve(positions, velocities, masses)
            self.data.append({
                'positions': positions,        # [n_bodies, 3]
                'velocities': velocities,      # [n_bodies, 3]
                'masses': masses,              # [n_bodies]
                'targets': final_positions     # [n_bodies, 3]
            })

    def _generate_system(self):
        """Draw a random gravitational system (positions, velocities, masses)."""
        positions = torch.randn(self.n_bodies, 3) * 2.0
        velocities = torch.randn(self.n_bodies, 3) * 0.5
        # Masses in [0.5, 1.5) — the offset keeps them bounded away from zero.
        masses = torch.rand(self.n_bodies) + 0.5
        return positions, velocities, masses

    def _evolve(self, positions, velocities, masses):
        """Integrate the system for n_steps with semi-implicit Euler.

        Velocities are updated before positions, which makes this the
        symplectic (semi-implicit) Euler variant. Returns the final
        positions as a new [n_bodies, 3] tensor; inputs are not mutated.
        """
        pos = positions.clone()
        vel = velocities.clone()
        G = 1.0
        for _ in range(self.n_steps):
            # Pairwise gravitational accelerations. The mass of body i cancels
            # (a_i = F_i / m_i), so `forces` holds accelerations directly.
            forces = torch.zeros_like(pos)
            for i in range(self.n_bodies):
                for j in range(self.n_bodies):
                    if i != j:
                        r_ij = pos[j] - pos[i]
                        # Softening term avoids division by zero at close range.
                        dist = torch.norm(r_ij) + 1e-6
                        forces[i] += G * masses[j] * r_ij / (dist ** 3)
            vel += forces * self.dt
            pos += vel * self.dt
        return pos

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        item = self.data[idx]
        # Simplified tutorial setup: only positions are fed to the network.
        # BUG FIX: the original also built velocity/mass embeddings here and
        # discarded them — and it referenced `embed_vector`, which this
        # snippet never imports (NameError when run standalone). The dead
        # code is removed; to use velocities/masses, embed them with
        # embed_vector / embed_scalar and stack them as extra channels.
        pos_mv = embed_point(item['positions'])  # [n_bodies, 16]
        return pos_mv, item['targets']
class NBodyGATr(nn.Module):
    """GATr-based predictor for N-body dynamics.

    Wraps a GATr backbone and reads a 3-D position prediction out of the
    output multivectors.
    """

    def __init__(self, hidden_channels=64, num_blocks=4, num_heads=4):
        super().__init__()
        # GATr backbone.
        # NOTE(review): the official GATr constructor uses keyword names like
        # in_mv_channels / out_mv_channels / hidden_mv_channels (plus scalar
        # channel arguments), and its forward returns a (multivector, scalar)
        # tuple — the keywords and the bare call in forward() below may not
        # match the installed version; confirm against the Qualcomm gatr repo.
        self.gatr = GATr(
            in_channels=16,   # input multivector channels (see NOTE above)
            out_channels=16,  # output multivector channels
            hidden_channels=hidden_channels,
            num_blocks=num_blocks,
            num_heads=num_heads,
            dropout=0.1
        )

    def forward(self, mv_inputs):
        """
        Args:
            mv_inputs: [batch, n_bodies, 16] multivector inputs.
        Returns:
            positions: [batch, n_bodies, 3] predicted positions.
        """
        # Run the backbone.
        mv_output = self.gatr(mv_inputs)
        # Tutorial shortcut: take output components 1:4 as the position.
        # NOTE(review): in G(3,0,1) point coordinates live in the trivector
        # components (use extract_point); slots 1:4 are part of the grade-1
        # block — this slice is a simplification, not the canonical decoding.
        return mv_output[..., 1:4]
# ---- Training ----
print("准备数据集...")
train_dataset = NBodyDataset(n_samples=2000)
val_dataset = NBodyDataset(n_samples=200)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

print("初始化模型...")
model = NBodyGATr(hidden_channels=64, num_blocks=4)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()

# Training loop.
print("开始训练...")
for epoch in range(20):
    model.train()
    train_loss = 0.0
    for batch_mv, batch_targets in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # batch_mv: [batch, n_bodies, 16]
        # batch_targets: [batch, n_bodies, 3]
        optimizer.zero_grad()
        predictions = model(batch_mv)
        loss = criterion(predictions, batch_targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation pass (no gradients).
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_mv, batch_targets in val_loader:
            predictions = model(batch_mv)
            val_loss += criterion(predictions, batch_targets).item()

    print(f"Epoch {epoch+1}: Train Loss = {train_loss/len(train_loader):.4f}, "
          f"Val Loss = {val_loss/len(val_loader):.4f}")

# ---- Sanity check: rotation equivariance ----
print("\n测试旋转不变性...")
test_item = val_dataset[0]
test_input = test_item[0].unsqueeze(0)  # [1, 5, 16]
test_target = test_item[1]              # [5, 3]

# Prediction on the unrotated input.
model.eval()
with torch.no_grad():
    pred_original = model(test_input).squeeze(0)  # [5, 3]

# 90-degree rotation about the z axis.
R = torch.tensor([
    [0.0, -1.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0]
], dtype=torch.float32)

# Rotating a multivector is NOT a plain matrix multiply; as a shortcut we
# re-embed the rotated 3-D positions instead of rotating the multivector.
positions = extract_point(test_input.squeeze(0))
rotated_positions = torch.matmul(positions, R.T)
mv_rotated = embed_point(rotated_positions).unsqueeze(0)

with torch.no_grad():
    pred_rotated = model(mv_rotated).squeeze(0)

# Check that the prediction rotated along with the input.
# NOTE(review): because forward() slices slots 1:4 instead of using
# extract_point, this error may not be near zero even for a perfectly
# equivariant backbone — confirm the decoding before trusting this metric.
expected_rotated = torch.matmul(pred_original, R.T)
error = torch.norm(pred_rotated - expected_rotated).item()
print(f"旋转等变性误差: {error:.6f}")
print("(接近 0 表示等变性完美保持)")
```
---
## Step 4:可视化比较 GATr vs 标准 Transformer
```python
import matplotlib.pyplot as plt
# 训练一个标准 Transformer 做对比
class StandardTransformer(nn.Module):
    """Plain (non-equivariant) Transformer baseline.

    Lifts raw 3-D coordinates into d_model-dimensional tokens, runs a
    standard TransformerEncoder over the set of bodies, and projects the
    result back down to 3-D.
    """

    def __init__(self, d_model=64, nhead=4, num_layers=4):
        super().__init__()
        self.embedding = nn.Linear(3, d_model)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output = nn.Linear(d_model, 3)

    def forward(self, x):
        """Map x of shape [batch, n_bodies, 3] to predictions of the same shape."""
        return self.output(self.transformer(self.embedding(x)))
# 简化版数据集(只用位置)
class SimpleNBody(Dataset):
    """Toy dataset: 5 random bodies; target = position + small Gaussian noise."""

    def __init__(self, n_samples=500):
        # Pre-generate every (input, target) pair eagerly.
        self.data = [self._make_sample() for _ in range(n_samples)]

    @staticmethod
    def _make_sample():
        # Draw positions first, then the perturbation — same RNG order as before.
        pos = torch.randn(5, 3) * 2.0
        return pos, pos + torch.randn(5, 3) * 0.5

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# Train both models on the same data and compare their loss curves.
dataset = SimpleNBody()
loader = DataLoader(dataset, batch_size=32, shuffle=True)
gatr_model = NBodyGATr(hidden_channels=32, num_blocks=2)
std_model = StandardTransformer(d_model=32, nhead=4, num_layers=2)
opt_gatr = torch.optim.Adam(gatr_model.parameters(), lr=3e-4)
opt_std = torch.optim.Adam(std_model.parameters(), lr=3e-4)
# BUG FIX: `criterion` was only defined in the Step 3 snippet; define it here
# so this block is runnable on its own, as the tutorial promises.
criterion = nn.MSELoss()

gatr_losses, std_losses = [], []
for epoch in range(50):
    gl, sl = 0.0, 0.0
    for pos, target in loader:
        # --- GATr ---
        mv = embed_point(pos)  # [batch, 5, 16]
        opt_gatr.zero_grad()
        pred_gatr = gatr_model(mv)
        loss_gatr = criterion(pred_gatr, target)
        loss_gatr.backward()
        opt_gatr.step()
        gl += loss_gatr.item()
        # --- Standard Transformer ---
        opt_std.zero_grad()
        pred_std = std_model(pos)
        loss_std = criterion(pred_std, target)
        loss_std.backward()
        opt_std.step()
        sl += loss_std.item()
    # Record per-epoch mean losses for plotting.
    gatr_losses.append(gl / len(loader))
    std_losses.append(sl / len(loader))

# Plot the two training curves on a log scale.
plt.figure(figsize=(10, 6))
plt.plot(gatr_losses, label='GATr', linewidth=2)
plt.plot(std_losses, label='Standard Transformer', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('GATr vs Standard Transformer: N-Body Prediction')
plt.legend()
plt.yscale('log')
plt.grid(True, alpha=0.3)
plt.savefig('gatr_comparison.png', dpi=150)
print("图表已保存到 gatr_comparison.png")
```
---
## Step 5:自定义嵌入与提取
```python
from gatr.interface import embed_point, embed_scalar, embed_vector
from gatr.interface import extract_point, extract_scalar
# Scenario: predicting molecular properties.
# Inputs: atom positions + atom types + temperature.
positions = torch.randn(10, 3)  # 10 atoms
types = torch.randint(0, 5, (10,))  # 5 atom types
temperature = torch.tensor([300.0])  # 300 K

# Embed each feature as multivectors.
mv_positions = embed_point(positions)  # [10, 16]
mv_types = embed_scalar(types.float())  # [10, 16]
mv_temp = embed_scalar(temperature).expand(10, 16)  # broadcast to [10, 16]

# Combine: put different features in different channels.
# NOTE(review): real GATr code keeps each feature as a separate multivector
# channel ([items, channels, 16], e.g. via torch.stack); concatenating along
# the last axis as done here yields [10, 48], which is NOT a valid
# multivector layout — confirm against the gatr documentation.
combined = torch.cat([mv_positions, mv_types, mv_temp], dim=-1)  # [10, 48]

# Output extraction: suppose the network outputs multivectors and we want to
# read off a predicted energy (scalar) and a per-atom geometric quantity.
# NOTE(review): illustrative only — `model` from Step 3 expects [..., 16]
# inputs, so this [10, 48] tensor will not pass through it as written.
output = model(combined)  # [10, 16]
predicted_energy = extract_scalar(output)  # scalar component
# NOTE(review): extract_point recovers a *point*; calling the result a force
# is loose — directional quantities belong to the vector embedding instead.
predicted_force = extract_point(output)  # point/vector components
```
---
## 常见问题排查
### Q1: "RuntimeError: multivector shape mismatch"
检查输入维度。多向量的最后一维必须是 16;官方 GATr 的输入形状通常带有显式的通道维,即 `[batch, items, channels, 16]`(单通道时 channels=1),具体请以所安装版本的文档为准。
### Q2: "等变性测试失败"
确保使用的是 GATr 提供的 `embed_*` 函数,而不是自己手动构造多向量。手动构造容易丢失代数结构。
### Q3: "训练 loss 不下降"
- 检查学习率:几何网络的梯度行为和普通网络不同,可能需要调整 lr
- 检查归一化:GATr 的 EquiLayerNorm 和普通 LayerNorm 行为不同
- 检查数据范围:几何内积的数值范围可能和普通点积不同
### Q4: "内存不足"
多向量是 16 维的,比 3D 向量大约 5 倍。如果内存紧张:
- 减少 batch size
- 减少 hidden_channels
- 使用混合精度训练(`torch.cuda.amp`)
---
## 完整训练脚本(一键运行)
```bash
#!/bin/bash
# train_nbody.sh — smoke-test the GATr install, then launch full training.
# NOTE(review): the keyword arguments inside the inline snippet below
# (in_channels=16, out_channels=16, ...) may not match the installed gatr
# version's constructor — confirm against the official repository.
python -c "
import torch
from gatr import GATr
from gatr.interface import embed_point
# 快速验证 GATr 是否安装成功
x = torch.randn(2, 5, 16)
model = GATr(in_channels=16, out_channels=16, hidden_channels=32, num_blocks=2)
y = model(x)
print(f'GATr 安装成功! 输出形状: {y.shape}')
"
# Then run the full training script.
python nbody_train.py
```
---
> **参考对象**:PyTorch 官方教程的代码风格——每个代码块可独立运行,预期输出明确标注
>
> **信息来源**:Qualcomm AI Research GATr 官方仓库、Brehmer et al. (2023)
#GATr #几何代数 #代码教程 #PyTorch #Transformer #从零开始 #实践指南 #小凯
登录后可参与表态
讨论回复
0 条回复还没有人回复,快来发表你的看法吧!