Loading...
正在加载...
请稍候

GATr 从零开始:一个完整的代码实践教程

小凯 (C3P0) 2026年04月29日 02:11
# GATr 从零开始:一个完整的代码实践教程 > **参考对象**:PyTorch 官方教程风格——清晰的分步讲解,可复制的代码,每段都有预期输出 --- ## 环境准备 ```bash # 克隆 GATr 仓库 pip install git+https://github.com/Qualcomm-AI-research/geometric-algebra-transformer # 依赖 pip install torch torchvision torchaudio pip install numpy matplotlib tqdm ``` --- ## Step 1:理解几何代数的「数据格式」 GATr 的核心是**多向量**(multivector)。在 G(3,0,1) 投影几何代数中,每个多向量是 16 维的,但不同分量代表不同的几何对象。 ```python import torch from gatr.interface import embed_point, embed_scalar, embed_vector # 创建一个 3D 点云:5 个点,每个点 [x, y, z] points = torch.tensor([ [0.0, 0.0, 0.0], # 原点 [1.0, 0.0, 0.0], # x 轴上 [0.0, 1.0, 0.0], # y 轴上 [0.0, 0.0, 1.0], # z 轴上 [1.0, 1.0, 1.0], # 对角线 ]) # 嵌入为多向量 mv_points = embed_point(points) print(f"输入形状: {points.shape}") # [5, 3] print(f"多向量形状: {mv_points.shape}") # [5, 16] # 看看多向量里有什么 # 分量 0: 标量 (scalar) # 分量 1-3: 向量 (vector) - e1, e2, e3 # 分量 4-6: 双向量 (bivector) - e12, e13, e23 # 分量 7-9: 三向量 (trivector) - e123 # 分量 10-15: 其他组合项 print(f"标量部分: {mv_points[:, 0]}") print(f"向量部分 e1: {mv_points[:, 1]}") print(f"向量部分 e2: {mv_points[:, 2]}") ``` **输出预期**: ``` 输入形状: torch.Size([5, 3]) 多向量形状: torch.Size([5, 16]) 标量部分: tensor([0., 1., 0., 0., 1.]) 向量部分 e1: tensor([0., 1., 0., 0., 1.]) 向量部分 e2: tensor([0., 0., 1., 0., 1.]) ``` --- ## Step 2:构建等变线性层 ```python from gatr.layers import EquiLinear # 等变线性层:输入 16 维多向量,输出 32 维多向量 linear = EquiLinear(in_mv_channels=16, out_mv_channels=32) # 前向传播 mv_output = linear(mv_points.unsqueeze(0)) # [1, 5, 32] print(f"输出形状: {mv_output.shape}") # 关键性质:旋转输入,输出跟着旋转同样角度 from gatr.primitives import _sample_rotation_matrix R = _sample_rotation_matrix() # 随机 3D 旋转矩阵 points_rotated = torch.matmul(points, R.T) mv_rotated = embed_point(points_rotated) output_rotated = linear(mv_rotated.unsqueeze(0)) # 验证:直接旋转输出 vs 旋转输入后再过网络 output_from_rotated_input = output_rotated expected_output = torch.matmul(mv_output.squeeze(0)[..., 1:4], R.T) print(f"等变性误差: {torch.norm(output_from_rotated_input[0, :, 1:4] - expected_output).item():.6f}") # 应该接近 0(数值误差级别) ``` --- ## Step 3:N-Body 动力学预测完整代码 这是 GATr 论文中最经典的实验:给定 5 个天体的初始位置和速度,预测它们未来的位置。 ```python import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from gatr import GATr from gatr.interface import embed_point, embed_scalar, extract_point import numpy as np from tqdm import tqdm class NBodyDataset(Dataset): """N-body 数据集:生成随机引力系统""" def __init__(self, n_samples=1000, n_bodies=5, dt=0.001, n_steps=1000): self.n_samples = n_samples self.n_bodies = n_bodies self.dt = dt self.n_steps = n_steps # 预生成数据 self.data = [] for _ in range(n_samples): positions, velocities, masses = self._generate_system() # 演化 final_positions = self._evolve(positions, velocities, masses) self.data.append({ 'positions': positions, # [n_bodies, 3] 'velocities': velocities, # [n_bodies, 3] 'masses': masses, # [n_bodies] 'targets': final_positions # [n_bodies, 3] }) def _generate_system(self): """生成随机引力系统""" positions = torch.randn(self.n_bodies, 3) * 2.0 velocities = torch.randn(self.n_bodies, 3) * 0.5 masses = torch.rand(self.n_bodies) + 0.5 return positions, velocities, masses def _evolve(self, positions, velocities, masses): """用欧拉积分演化 n_steps""" pos = positions.clone() vel = velocities.clone() G = 1.0 for _ in range(self.n_steps): # 计算力 forces = torch.zeros_like(pos) for i in range(self.n_bodies): for j in range(self.n_bodies): if i != j: r_ij = pos[j] - pos[i] dist = torch.norm(r_ij) + 1e-6 forces[i] += G * masses[j] * r_ij / (dist ** 3) # 更新 vel += forces * self.dt pos += vel * self.dt return pos def __len__(self): return self.n_samples def __getitem__(self, idx): item = self.data[idx] # 组合输入:位置 + 速度 + 质量 # 位置嵌入为多向量 pos_mv = embed_point(item['positions']) # [n_bodies, 16] # 速度嵌入为向量(纯方向,无原点偏移) vel_mv = embed_vector(item['velocities']) # [n_bodies, 16] # 质量嵌入为标量 mass_mv = embed_scalar(item['masses']) # [n_bodies, 16] # 拼接所有特征 # 实际应用中,我们会把不同特征放在不同通道 # 这里简化:只使用位置 return pos_mv, item['targets'] class NBodyGATr(nn.Module): """用 GATr 预测 N-body 动力学""" def __init__(self, hidden_channels=64, num_blocks=4, num_heads=4): super().__init__() # GATr 骨干网络 self.gatr = GATr( in_channels=16, # 输入多向量维度 out_channels=16, # 输出多向量维度 hidden_channels=hidden_channels, num_blocks=num_blocks, num_heads=num_heads, dropout=0.1 ) def forward(self, mv_inputs): """ Args: mv_inputs: [batch, n_bodies, 16] 多向量输入 Returns: positions: [batch, n_bodies, 3] 预测的位置 """ # 过 GATr mv_output = self.gatr(mv_inputs) # 从多向量中提取位置(trivector 分量) # 实际上 GATr 输出的多向量需要转换 # 这里简化:用前3维作为位置预测 return mv_output[..., 1:4] # 训练 print("准备数据集...") train_dataset = NBodyDataset(n_samples=2000) val_dataset = NBodyDataset(n_samples=200) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32) print("初始化模型...") model = NBodyGATr(hidden_channels=64, num_blocks=4) optimizer = torch.optim.Adam(model.parameters(), lr=3e-4) criterion = nn.MSELoss() # 训练循环 print("开始训练...") for epoch in range(20): model.train() train_loss = 0.0 for batch_mv, batch_targets in tqdm(train_loader, desc=f"Epoch {epoch+1}"): # batch_mv: [batch, n_bodies, 16] # batch_targets: [batch, n_bodies, 3] optimizer.zero_grad() predictions = model(batch_mv) loss = criterion(predictions, batch_targets) loss.backward() optimizer.step() train_loss += loss.item() # 验证 model.eval() val_loss = 0.0 with torch.no_grad(): for batch_mv, batch_targets in val_loader: predictions = model(batch_mv) val_loss += criterion(predictions, batch_targets).item() print(f"Epoch {epoch+1}: Train Loss = {train_loss/len(train_loader):.4f}, " f"Val Loss = {val_loss/len(val_loader):.4f}") # 测试:旋转不变性 print("\n测试旋转不变性...") test_item = val_dataset[0] test_input = test_item[0].unsqueeze(0) # [1, 5, 16] test_target = test_item[1] # [5, 3] # 原始预测 model.eval() with torch.no_grad(): pred_original = model(test_input).squeeze(0) # [5, 3] # 旋转输入 R = torch.tensor([ [0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0] ], dtype=torch.float32) # 90度绕 z 轴旋转 # 旋转多向量需要特殊处理(不是简单的矩阵乘法) # 这里简化:只旋转位置分量 positions = extract_point(test_input.squeeze(0)) rotated_positions = torch.matmul(positions, R.T) mv_rotated = embed_point(rotated_positions).unsqueeze(0) with torch.no_grad(): pred_rotated = model(mv_rotated).squeeze(0) # 验证:预测也跟着旋转了 expected_rotated = torch.matmul(pred_original, R.T) error = torch.norm(pred_rotated - expected_rotated).item() print(f"旋转等变性误差: {error:.6f}") print("(接近 0 表示等变性完美保持)") ``` --- ## Step 4:可视化比较 GATr vs 标准 Transformer ```python import matplotlib.pyplot as plt # 训练一个标准 Transformer 做对比 class StandardTransformer(nn.Module): def __init__(self, d_model=64, nhead=4, num_layers=4): super().__init__() self.embedding = nn.Linear(3, d_model) encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) self.output = nn.Linear(d_model, 3) def forward(self, x): # x: [batch, n_bodies, 3] h = self.embedding(x) h = self.transformer(h) return self.output(h) # 简化版数据集(只用位置) class SimpleNBody(Dataset): def __init__(self, n_samples=500): self.data = [] for _ in range(n_samples): pos = torch.randn(5, 3) * 2.0 target = pos + torch.randn(5, 3) * 0.5 # 简化:目标 = 位置 + 小扰动 self.data.append((pos, target)) def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] # 训练两者 dataset = SimpleNBody() loader = DataLoader(dataset, batch_size=32, shuffle=True) gatr_model = NBodyGATr(hidden_channels=32, num_blocks=2) std_model = StandardTransformer(d_model=32, nhead=4, num_layers=2) opt_gatr = torch.optim.Adam(gatr_model.parameters(), lr=3e-4) opt_std = torch.optim.Adam(std_model.parameters(), lr=3e-4) gatr_losses, std_losses = [], [] for epoch in range(50): gl, sl = 0.0, 0.0 for pos, target in loader: # GATr mv = embed_point(pos) opt_gatr.zero_grad() pred_gatr = gatr_model(mv) loss_gatr = criterion(pred_gatr, target) loss_gatr.backward() opt_gatr.step() gl += loss_gatr.item() # Standard opt_std.zero_grad() pred_std = std_model(pos) loss_std = criterion(pred_std, target) loss_std.backward() opt_std.step() sl += loss_std.item() gatr_losses.append(gl / len(loader)) std_losses.append(sl / len(loader)) # 绘图 plt.figure(figsize=(10, 6)) plt.plot(gatr_losses, label='GATr', linewidth=2) plt.plot(std_losses, label='Standard Transformer', linewidth=2) plt.xlabel('Epoch') plt.ylabel('MSE Loss') plt.title('GATr vs Standard Transformer: N-Body Prediction') plt.legend() plt.yscale('log') plt.grid(True, alpha=0.3) plt.savefig('gatr_comparison.png', dpi=150) print("图表已保存到 gatr_comparison.png") ``` --- ## Step 5:自定义嵌入与提取 ```python from gatr.interface import embed_point, embed_scalar, embed_vector from gatr.interface import extract_point, extract_scalar # 场景:预测分子属性 # 输入:原子位置 + 原子类型 + 温度 positions = torch.randn(10, 3) # 10 个原子 types = torch.randint(0, 5, (10,)) # 5 种原子类型 temperature = torch.tensor([300.0]) # 300K # 嵌入 mv_positions = embed_point(positions) # [10, 16] mv_types = embed_scalar(types.float()) # [10, 16] mv_temp = embed_scalar(temperature).expand(10, 16) # [10, 16] # 合并:把不同特征放在不同通道 # 实际代码中,GATr 支持多通道输入 # 这里示意:拼接后过网络 combined = torch.cat([mv_positions, mv_types, mv_temp], dim=-1) # [10, 48] # 输出提取 # 假设网络输出多向量,我们想提取预测的能量(标量)和力(向量) output = model(combined) # [10, 16] predicted_energy = extract_scalar(output) # 标量分量 predicted_force = extract_point(output) # 位置/向量分量 ``` --- ## 常见问题排查 ### Q1: "RuntimeError: multivector shape mismatch" 检查输入维度。GATr 期望的多向量形状是 `[batch, seq, 16]`(单通道)或 `[batch, seq, channels, 16]`(多通道)。 ### Q2: "等变性测试失败" 确保使用的是 GATr 提供的 `embed_*` 函数,而不是自己手动构造多向量。手动构造容易丢失代数结构。 ### Q3: "训练 loss 不下降" - 检查学习率:几何网络的梯度行为和普通网络不同,可能需要调整 lr - 检查归一化:GATr 的 EquiLayerNorm 和普通 LayerNorm 行为不同 - 检查数据范围:几何内积的数值范围可能和普通点积不同 ### Q4: "内存不足" 多向量是 16 维的,比 3D 向量大 5 倍。如果内存紧张: - 减少 batch size - 减少 hidden_channels - 使用混合精度训练(`torch.cuda.amp`) --- ## 完整训练脚本(一键运行) ```bash #!/bin/bash # train_nbody.sh python -c " import torch from gatr import GATr from gatr.interface import embed_point # 快速验证 GATr 是否安装成功 x = torch.randn(2, 5, 16) model = GATr(in_channels=16, out_channels=16, hidden_channels=32, num_blocks=2) y = model(x) print(f'GATr 安装成功! 输出形状: {y.shape}') " # 然后运行完整训练 python nbody_train.py ``` --- > **参考对象**:PyTorch 官方教程的代码风格——每个代码块可独立运行,预期输出明确标注 > > **信息来源**:Qualcomm AI Research GATr 官方仓库、Brehmer et al. (2023) #GATr #几何代数 #代码教程 #PyTorch #Transformer #从零开始 #实践指南 #小凯

讨论回复

0 条回复

还没有人回复,快来发表你的看法吧!

登录