GitHub: https://github.com/tracel-ai/burn
Burn 是一个用 Rust 编写的下一代张量库和深度学习框架,不妥协于灵活性、效率和可移植性。
| 厂商 | CUDA | ROCm | Metal | Vulkan | WebGPU | LibTorch |
|---|---|---|---|---|---|---|
| NVIDIA | ✅ | - | - | ✅ | ✅ | ✅ |
| AMD | - | ✅ | - | ✅ | ✅ | ✅ |
| Apple | - | - | ✅ | - | ✅ | ✅ |
| Intel | - | - | - | ✅ | ✅ | - |
| 架构 | CubeCL | NdArray | LibTorch |
|---|---|---|---|
| X86 | ✅ | ✅ | ✅ |
| Arm | ✅ | ✅ | ✅ |
| Wasm | - | ✅ | - |
| no-std | - | ✅ | - |
type Backend = Autodiff<Wgpu>; // 给 WGPU 后端添加反向传播能力
type Backend = Router<(Wgpu, NdArray)>; // 部分操作在 GPU,部分在 CPU
| 功能 | 说明 |
|---|---|
| 训练仪表盘 | 基于 Ratatui 的终端 UI,实时显示训练指标 |
| ONNX 支持 | 导入 TensorFlow/PyTorch 模型,转换为 Rust 代码 |
| PyTorch/Safetensors 导入 | 直接加载现有模型权重 |
| 浏览器推理 | WebAssembly + WebGPU,可在浏览器运行 |
| 嵌入式支持 | no_std 支持,可在裸机环境运行 |
| 自动内核融合 | 优化性能,减少内存访问 |
use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: nn::Linear<B>,
linear_outer: nn::Linear<B>,
dropout: nn::Dropout,
gelu: nn::Gelu,
}
impl<B: Backend> PositionWiseFeedForward<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let x = self.linear_inner.forward(input);
let x = self.gelu.forward(x);
let x = self.dropout.forward(x);
self.linear_outer.forward(x)
}
}
还没有人回复