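Chapter 16: Hands-On Web Development with Axum
Chapter overview: Web development is one of Rust's major application areas. Axum is a web framework from the Tokio team that is concise, modular, and type-safe. In this chapter we build a RESTful API from scratch, covering routing, extractors, shared state, middleware, and error handling, with database integration left as an exercise, and see what Rust brings to server-side development.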
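🌐 16.1 Introduction to Axum
🎯 16.1.1 Why Axum?
| Feature | Description |
|---|---|
| Simple | Handlers are plain async functions; routing needs no macros |
| Type-safe | Handler signatures (extractors and return types) are checked at compile time |
| High performance | Built on Tokio and Hyper |
| Modular | Composes with middleware from the tower and tower-http ecosystem |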
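📦 16.1.2 Project Setup
```bash
cargo new web_api
cd web_api
```
```toml
# Cargo.toml
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = "0.3"
# Used to generate to-do IDs in 16.7
uuid = { version = "1", features = ["v4"] }
```
The examples in this chapter target Axum 0.7. In Axum 0.8, path parameters are written `/{id}` instead of `/:id`, and custom extractors no longer need `#[async_trait]`.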
🚀 16.2 Hello World
🌍 16.2.1 A Minimal Server
```rust
use axum::{routing::get, Router};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    // Build the router
    let app = Router::new().route("/", get(hello));

    // Bind the address
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("Server listening on http://{}", addr);

    // Start the server
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

// Handler: returns a plain string
async fn hello() -> &'static str {
    "Hello, Axum!"
}
```
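A handler can return any type that implements Axum's `IntoResponse` trait. For `&'static str`, Axum sends `200 OK` with a `text/plain; charset=utf-8` body. The std-only toy model below sketches that idea. `ToyResponse`, `IntoToyResponse`, and `respond` are hypothetical names, not Axum's API, and the real trait builds an `http::Response` rather than a plain struct.

```rust
// Toy model of the IntoResponse idea (hypothetical names, std only)
#[derive(Debug, PartialEq)]
struct ToyResponse {
    status: u16,
    content_type: &'static str,
    body: String,
}

trait IntoToyResponse {
    fn into_toy_response(self) -> ToyResponse;
}

// Plain text becomes 200 OK with a text/plain body
impl IntoToyResponse for &'static str {
    fn into_toy_response(self) -> ToyResponse {
        ToyResponse {
            status: 200,
            content_type: "text/plain; charset=utf-8",
            body: self.to_string(),
        }
    }
}

// A (status, inner) tuple keeps the inner body but overrides the status
impl<T: IntoToyResponse> IntoToyResponse for (u16, T) {
    fn into_toy_response(self) -> ToyResponse {
        let mut res = self.1.into_toy_response();
        res.status = self.0;
        res
    }
}

// Whatever a "handler" returns is converted the same way
fn respond<R: IntoToyResponse>(handler_output: R) -> ToyResponse {
    handler_output.into_toy_response()
}
```

The tuple impl mirrors how Axum lets you write `(StatusCode::CREATED, Json(todo))` to change the status code while keeping the inner body.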
📝 16.2.2 JSON Responses
```rust
use axum::{extract::Json, routing::get, Router};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;

// Debug is required by the {:?} in create_user
#[derive(Debug, Serialize, Deserialize)]
struct User {
    id: u64,
    name: String,
    email: String,
}

// Returning Json<T> serializes T and sets Content-Type: application/json
async fn get_user() -> Json<User> {
    Json(User {
        id: 1,
        name: "Alice".to_string(),
        email: "alice@example.com".to_string(),
    })
}

// Taking Json<T> as an argument deserializes the request body into T
async fn create_user(Json(user): Json<User>) -> Json<User> {
    // Creation logic goes here...
    println!("Creating user: {:?}", user);
    Json(user)
}

#[tokio::main]
async fn main() {
    // GET and POST on the same path are chained on one route
    let app = Router::new().route("/user", get(get_user).post(create_user));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```
🛣️ 16.3 Routing
📍 16.3.1 Path Parameters
```rust
use axum::{extract::Path, routing::get, Router};
use std::net::SocketAddr;

// Static path
async fn hello() -> &'static str {
    "Hello!"
}

// Dynamic path parameter: the ":id" segment is parsed into a u64
async fn get_user_by_id(Path(id): Path<u64>) -> String {
    format!("User ID: {}", id)
}

// Multiple parameters are extracted as a tuple, in path order
async fn get_post(Path((user_id, post_id)): Path<(u64, u64)>) -> String {
    format!("User {}, post {}", user_id, post_id)
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(hello))
        .route("/users/:id", get(get_user_by_id))
        .route("/users/:user_id/posts/:post_id", get(get_post));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```
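To see what the `:id` syntax does, here is a deliberately simplified, std-only route matcher. `match_route` and `parse_id` are hypothetical helpers, and Axum's actual router (built on the `matchit` crate) is far more efficient. The sketch shows the two steps involved. First, it matches segments and captures the `:name` placeholders. Second, it parses the captured text into the type requested by `Path<T>`. In Axum, a failed parse rejects the request before the handler runs.

```rust
use std::collections::HashMap;

// Toy matcher for Axum 0.7-style patterns such as "/users/:id".
// Returns the captured parameters if `path` matches `pattern`.
fn match_route(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
    let pat: Vec<&str> = pattern.trim_matches('/').split('/').collect();
    let segs: Vec<&str> = path.trim_matches('/').split('/').collect();
    if pat.len() != segs.len() {
        return None;
    }
    let mut params = HashMap::new();
    for (p, s) in pat.iter().zip(segs.iter()) {
        if let Some(name) = p.strip_prefix(':') {
            // A ":name" segment captures whatever text is in that position
            params.insert(name.to_string(), s.to_string());
        } else if p != s {
            // Static segments must match exactly
            return None;
        }
    }
    Some(params)
}

// Mimics Path<u64>: the captured text must parse as the requested type
fn parse_id(params: &HashMap<String, String>, name: &str) -> Result<u64, String> {
    let raw = params
        .get(name)
        .ok_or_else(|| format!("missing param {}", name))?;
    raw.parse::<u64>()
        .map_err(|e| format!("invalid {}: {}", name, e))
}
```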
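🏗️ 16.3.2 Route Groups (Nesting)
```rust
use axum::{routing::get, Router};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    // API route group (.post()/.delete() chain onto the router returned by get())
    let api_routes = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/users/:id", get(get_user).delete(delete_user));

    // Admin route group
    let admin_routes = Router::new()
        .route("/stats", get(get_stats))
        .route("/logs", get(get_logs));

    // Main router: nest() mounts each group under a path prefix
    let app = Router::new()
        .route("/", get(root))
        .nest("/api", api_routes)
        .nest("/admin", admin_routes);
    // GET    /              -> root
    // GET    /api/users     -> list_users
    // POST   /api/users     -> create_user
    // GET    /api/users/:id -> get_user
    // DELETE /api/users/:id -> delete_user
    // GET    /admin/stats   -> get_stats
    // GET    /admin/logs    -> get_logs

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn root() -> &'static str {
    "Welcome!"
}
async fn list_users() -> &'static str {
    "User list"
}
async fn create_user() -> &'static str {
    "Create user"
}
async fn get_user() -> &'static str {
    "Get user"
}
async fn delete_user() -> &'static str {
    "Delete user"
}
async fn get_stats() -> &'static str {
    "Statistics"
}
async fn get_logs() -> &'static str {
    "Logs"
}
```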
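`nest` mounts a whole router under a prefix, and the nested handlers see the request URI with that prefix stripped. Axum's `OriginalUri` extractor recovers the full URI. The hypothetical helper below sketches the stripping rule, including the segment-boundary check that keeps `/apiary` from matching `/api`. It is a simplification. For example, it maps a bare `/api` to `/`, which may not reproduce Axum's exact handling of that edge case or of trailing slashes.

```rust
// Toy model of nesting: returns the path the inner router would see,
// or None if `path` is not under `prefix` (hypothetical helper, std only)
fn strip_nest_prefix<'a>(prefix: &str, path: &'a str) -> Option<&'a str> {
    let rest = path.strip_prefix(prefix)?;
    if rest.is_empty() {
        // Simplification: a bare prefix is treated as the group's root
        Some("/")
    } else if rest.starts_with('/') {
        Some(rest)
    } else {
        // "/apiary" shares characters with "/api" but not a whole segment
        None
    }
}
```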
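📤 16.4 Extractors
An extractor is a handler argument whose type tells Axum which part of the request to read: path parameters, the query string, headers, the body, and so on. If extraction fails, Axum returns an error response and the handler is not called.
🔍 16.4.1 Common Extractors
```rust
use axum::{
    extract::{Form, Json, OriginalUri, Path, Query},
    http::{HeaderMap, Method, Uri},
};
use serde::Deserialize;

// Debug is required by the {:?} below
#[derive(Debug, Deserialize)]
struct SearchQuery {
    q: String,
    page: Option<u32>,
}

#[derive(Deserialize)]
struct LoginForm {
    username: String,
    password: String,
}

// Example JSON body for this handler
#[derive(Deserialize)]
struct CreateUser {
    name: String,
}

// Extractors that only read request parts (path, query, headers, ...) may appear
// in any order. An extractor that consumes the body (Json, Form, String, ...)
// must be the LAST argument, and a handler can have only one of them.
async fn extract_all(
    // Path parameter
    Path(id): Path<u64>,
    // Query string parameters
    Query(query): Query<SearchQuery>,
    // Request headers
    headers: HeaderMap,
    // HTTP method
    method: Method,
    // URI (inside a nested router, the prefix is stripped)
    uri: Uri,
    // Original URI, before any nesting prefix was stripped
    OriginalUri(original_uri): OriginalUri,
    // JSON request body: consumes the body, so it comes last
    Json(payload): Json<CreateUser>,
) -> String {
    format!(
        "ID: {}, query: {:?}, method: {}, uri: {}, original: {}, headers: {}, name: {}",
        id, query, method, uri, original_uri, headers.len(), payload.name
    )
}

// Form data is also a body extractor, so it needs its own handler
async fn login(Form(form): Form<LoginForm>) -> String {
    format!("Login attempt by {}", form.username)
}
```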
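`Query<SearchQuery>` deserializes the URL query string into the struct. It treats `Option` fields as optional and rejects the request if a required field is missing or malformed. Axum does this with `serde_urlencoded`. The hand-written, std-only `parse_search_query` below is only a model of that behavior: the name is hypothetical and it does no percent-decoding. It shows which inputs succeed and which are rejected.

```rust
// Same shape as SearchQuery above (std-only copy, no serde)
#[derive(Debug, PartialEq)]
struct SearchQuery {
    q: String,
    page: Option<u32>,
}

// Toy model of Query<SearchQuery>: split "k=v&k=v", require `q`,
// accept an optional numeric `page`, and ignore unknown keys
fn parse_search_query(raw: &str) -> Result<SearchQuery, String> {
    let mut q = None;
    let mut page = None;
    for pair in raw.split('&').filter(|s| !s.is_empty()) {
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        match key {
            "q" => q = Some(value.to_string()),
            "page" => {
                let n = value
                    .parse::<u32>()
                    .map_err(|e| format!("invalid page: {}", e))?;
                page = Some(n);
            }
            _ => {} // serde also ignores unknown fields by default
        }
    }
    let q = q.ok_or_else(|| "missing field `q`".to_string())?;
    Ok(SearchQuery { q, page })
}
```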
📦 16.4.2 Custom Extractors
To write your own extractor, implement `FromRequestParts` when you only need the URI, headers, or other request parts. Implement `FromRequest` only when you need to consume the body.
```rust
use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
};

// Extracts an API key from the X-API-Key request header
struct ApiKey(String);

#[async_trait]
impl<S> FromRequestParts<S> for ApiKey
where
    S: Send + Sync,
{
    // Sent back as the response when extraction fails
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let header = parts
            .headers
            .get("X-API-Key")
            .and_then(|v| v.to_str().ok());

        match header {
            Some(key) => Ok(ApiKey(key.to_string())),
            None => Err((StatusCode::UNAUTHORIZED, "Missing API key")),
        }
    }
}

// ApiKey can now be used like any built-in extractor
async fn protected(ApiKey(key): ApiKey) -> String {
    format!("Your API key: {}", key)
}
```
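Why does a handler with an unsupported argument type fail to compile? The std-only toy below sketches the mechanism. All names are hypothetical, and Axum's real `FromRequestParts` and `Handler` traits are async, generic over the state type, and implemented for handlers with up to 16 arguments. The idea is that a function is accepted as a handler only if every argument type implements the extractor trait. The compiler therefore checks the handler's signature, and at runtime a failed extraction short-circuits with a rejection.

```rust
use std::collections::HashMap;

// Simplified request parts: just the headers (hypothetical, std only)
struct Parts {
    headers: HashMap<String, String>,
}

// Toy counterpart of FromRequestParts: build Self from the request,
// or return a rejection (status code, message)
trait FromParts: Sized {
    fn from_parts(parts: &Parts) -> Result<Self, (u16, String)>;
}

struct ApiKey(String);

impl FromParts for ApiKey {
    fn from_parts(parts: &Parts) -> Result<Self, (u16, String)> {
        parts
            .headers
            .get("x-api-key")
            .map(|key| ApiKey(key.clone()))
            .ok_or((401, "missing API key".to_string()))
    }
}

// Toy counterpart of Axum's Handler impls: F is accepted only if its
// argument type T implements FromParts, so a wrong type fails to compile
fn call_handler<T, F>(handler: F, parts: &Parts) -> (u16, String)
where
    T: FromParts,
    F: Fn(T) -> String,
{
    match T::from_parts(parts) {
        Ok(value) => (200, handler(value)),
        // Rejection: the handler is never called
        Err(rejection) => rejection,
    }
}

fn protected(ApiKey(key): ApiKey) -> String {
    format!("Your API key: {}", key)
}
```

Passing a function whose argument type has no `FromParts` impl to `call_handler` is a type error, which is the same reason Axum rejects invalid handlers at compile time.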
🗄️ 16.5 State Management
📦 16.5.1 Shared State
```rust
use axum::{extract::State, routing::get, Json, Router};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

#[derive(Serialize, Deserialize, Clone)]
struct User {
    id: u64,
    name: String,
}

// Application state
#[derive(Default)]
struct AppState {
    users: Mutex<Vec<User>>,
    counter: Mutex<u64>,
}

// The State extractor clones this Arc for every request, which is cheap
type SharedState = Arc<AppState>;

#[tokio::main]
async fn main() {
    // Create the shared state
    let state = Arc::new(AppState::default());

    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/count", get(get_count))
        .with_state(state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn list_users(State(state): State<SharedState>) -> Json<Vec<User>> {
    let users = state.users.lock().unwrap().clone();
    Json(users)
}

async fn create_user(
    State(state): State<SharedState>,
    Json(user): Json<User>,
) -> Json<User> {
    // Always lock users before counter so the lock order stays consistent
    let mut users = state.users.lock().unwrap();
    let mut counter = state.counter.lock().unwrap();
    *counter += 1;
    // The server assigns the ID; any client-supplied id is overwritten
    let new_user = User {
        id: *counter,
        ..user
    };
    users.push(new_user.clone());
    Json(new_user)
}

async fn get_count(State(state): State<SharedState>) -> String {
    let counter = state.counter.lock().unwrap();
    format!("Count: {}", *counter)
}
```
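Because every request gets its own clone of the `Arc`, handlers can run concurrently on different threads. The std-only sketch below reproduces the `create_user` logic with plain threads standing in for concurrent requests. `simulate` is a hypothetical helper. The sketch shows that holding both locks while assigning the ID keeps IDs unique.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

#[derive(Clone, Debug)]
struct User {
    id: u64,
    name: String,
}

#[derive(Default)]
struct AppState {
    users: Mutex<Vec<User>>,
    counter: Mutex<u64>,
}

// Same logic as the create_user handler, without Axum: both guards are
// held while the ID is assigned and the user is stored
fn create_user(state: &AppState, name: &str) -> User {
    let mut users = state.users.lock().unwrap();
    let mut counter = state.counter.lock().unwrap();
    *counter += 1;
    let user = User { id: *counter, name: name.to_string() };
    users.push(user.clone());
    user
}

// Simulates concurrent requests: each thread gets its own Arc clone,
// just as each request gets a clone through the State extractor
fn simulate(threads: usize) -> Arc<AppState> {
    let state = Arc::new(AppState::default());
    let handles: Vec<_> = (0..threads)
        .map(|i| {
            let state = Arc::clone(&state);
            thread::spawn(move || create_user(&state, &format!("user{}", i)))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    state
}
```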
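🧵 16.5.2 Using tokio::sync
`std::sync::Mutex` is fine as long as each guard is dropped before any `.await`, as in the handlers above. If a lock must be held across an `.await`, use the async-aware locks in `tokio::sync`, whose `lock()`, `read()`, and `write()` calls are themselves awaited:
```rust
use axum::{extract::State, Json};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Serialize, Deserialize, Clone)]
struct User {
    id: u64,
    name: String,
}

// RwLock allows many concurrent readers or a single writer
struct AsyncState {
    users: RwLock<Vec<User>>,
}

async fn async_list_users(
    State(state): State<Arc<AsyncState>>,
) -> Json<Vec<User>> {
    // read().await waits for the lock without blocking the runtime thread
    let users = state.users.read().await.clone();
    Json(users)
}

async fn async_create_user(
    State(state): State<Arc<AsyncState>>,
    Json(user): Json<User>,
) -> Json<User> {
    let mut users = state.users.write().await;
    users.push(user.clone());
    Json(user)
}
```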
🛡️ 16.6 Middleware
📝 16.6.1 Logging Middleware
```rust
use axum::{
    body::Body,
    http::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
use std::net::SocketAddr;

// Runs before and after every request that passes through this layer
async fn logging_middleware(request: Request<Body>, next: Next) -> Response {
    let method = request.method().clone();
    let uri = request.uri().clone();
    println!("Request: {} {}", method, uri);

    let start = std::time::Instant::now();
    // Pass the request on to the next layer or the handler
    let response = next.run(request).await;
    println!("Response: {} ({:?})", response.status(), start.elapsed());

    response
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(middleware::from_fn(logging_middleware));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```
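The order in which middleware runs is easiest to see with a toy model. In the std-only sketch below, `Request`, `Response`, `Next`, and `logging` are all hypothetical. Each middleware wraps the next step, so code before the call runs on the way in and code after it runs on the way out, just like the code around `next.run(request).await` above.

```rust
// Toy request/response types (hypothetical, std only)
struct Request {
    path: String,
}

struct Response {
    status: u16,
}

// The "next" step in the chain: another middleware or the final handler.
// The log records the order in which each step runs.
type Next<'a> = &'a dyn Fn(Request, &mut Vec<String>) -> Response;

// Wraps `next` in a middleware that logs on the way in and on the way out
fn logging<'a>(
    name: &'a str,
    next: Next<'a>,
) -> impl Fn(Request, &mut Vec<String>) -> Response + 'a {
    move |req: Request, log: &mut Vec<String>| -> Response {
        log.push(format!("{} before {}", name, req.path));
        let res = next(req, &mut *log); // plays the role of next.run(request).await
        log.push(format!("{} after {}", name, res.status));
        res
    }
}

// The innermost step: the actual handler
fn handler(_req: Request, log: &mut Vec<String>) -> Response {
    log.push("handler".to_string());
    Response { status: 200 }
}
```

In Axum, each additional `.layer()` call wraps everything added before it, so the last layer added is the outermost one and sees the request first.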
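🔐 16.6.2 Authentication Middleware
```rust
use axum::{
    body::Body,
    extract::Extension,
    http::{Request, StatusCode},
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};

// Values stored in request extensions must be Clone + Send + Sync + 'static
#[derive(Clone)]
struct AuthUser {
    id: u64,
    name: String,
}

async fn auth_middleware(
    mut request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let auth_header = request
        .headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok());

    match auth_header {
        // Demo only: a hard-coded token instead of real verification
        Some(token) if token == "Bearer secret123" => {
            // Attach the user to the request extensions for later handlers
            let user = AuthUser {
                id: 1,
                name: "Alice".to_string(),
            };
            request.extensions_mut().insert(user);
            Ok(next.run(request).await)
        }
        // Returning Err short-circuits: the handler never runs
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

// Handlers read the user back with the Extension extractor
async fn me(Extension(user): Extension<AuthUser>) -> String {
    format!("Hello, {} (id {})", user.name, user.id)
}

// The layer only applies to routes added before route_layer
fn protected_routes() -> Router {
    Router::new()
        .route("/me", get(me))
        .route_layer(middleware::from_fn(auth_middleware))
}
```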
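The middleware above compares the whole header against a fixed string, so a header such as `bearer secret123` would be rejected. A more robust approach is to parse the scheme and the token separately, since HTTP authentication scheme names are case-insensitive. `bearer_token` and `is_authorized` below are hypothetical helpers. A real service would verify a signed token (for example a JWT, as in the exercises) instead of comparing against a constant.

```rust
// Hypothetical helper: extract <token> from "Bearer <token>".
// The scheme name is matched case-insensitively.
fn bearer_token(header: Option<&str>) -> Option<&str> {
    let value = header?.trim();
    let (scheme, token) = value.split_once(' ')?;
    let token = token.trim();
    if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() {
        Some(token)
    } else {
        None
    }
}

// Hypothetical check: stands in for real token verification
fn is_authorized(header: Option<&str>, expected: &str) -> bool {
    bearer_token(header) == Some(expected)
}
```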
🌐 16.6.3 CORS
```rust
use axum::{routing::get, Router};
use std::net::SocketAddr;
use tower_http::cors::{Any, CorsLayer};

#[tokio::main]
async fn main() {
    // Any allows every origin, method, and header: convenient for local
    // development, but restrict origins in production
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(cors);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```
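In production you would normally pass a list of specific origins to `CorsLayer::allow_origin` instead of `Any`. The hypothetical, std-only `allowed_origin` function below models the policy such a configuration enforces. If the request's `Origin` is on the allowlist, it is echoed back in `Access-Control-Allow-Origin`. Otherwise no CORS header is sent, and the browser does not expose the response to the page's scripts.

```rust
// Toy model of an origin allowlist (hypothetical helper; tower-http's
// CorsLayer implements the real behavior, including preflight requests)
fn allowed_origin<'a>(origin: Option<&'a str>, allowlist: &[&str]) -> Option<&'a str> {
    let origin = origin?;
    if allowlist.iter().any(|allowed| *allowed == origin) {
        Some(origin) // value for Access-Control-Allow-Origin
    } else {
        None // no CORS header: the browser hides the response from the page
    }
}
```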
🧪 16.7 Hands-On: Building a RESTful API
Let's build a complete to-do API. It generates IDs with the `uuid` crate (with the `v4` feature enabled):
```rust
use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::get,
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    net::SocketAddr,
    sync::{Arc, Mutex},
};
use uuid::Uuid;

// ============ Data models ============
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Todo {
    id: String,
    title: String,
    completed: bool,
}

#[derive(Debug, Deserialize)]
struct CreateTodo {
    title: String,
}

// Fields omitted from the request body are left unchanged
#[derive(Debug, Deserialize)]
struct UpdateTodo {
    title: Option<String>,
    completed: Option<bool>,
}

// ============ State ============
// In-memory "database": todo ID -> todo
type Db = Arc<Mutex<HashMap<String, Todo>>>;

// ============ Error handling ============
enum ApiError {
    NotFound,
    BadRequest(String),
}

// Lets handlers return Result<_, ApiError> and still produce an HTTP response
impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        match self {
            ApiError::NotFound => {
                (StatusCode::NOT_FOUND, "Resource not found").into_response()
            }
            ApiError::BadRequest(msg) => {
                (StatusCode::BAD_REQUEST, msg).into_response()
            }
        }
    }
}

// ============ Handlers ============
async fn list_todos(State(db): State<Db>) -> impl IntoResponse {
    let todos: Vec<Todo> = db.lock().unwrap().values().cloned().collect();
    Json(todos)
}

async fn get_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
) -> Result<Json<Todo>, ApiError> {
    let db = db.lock().unwrap();
    let todo = db.get(&id).cloned().ok_or(ApiError::NotFound)?;
    Ok(Json(todo))
}

async fn create_todo(
    State(db): State<Db>,
    Json(payload): Json<CreateTodo>,
) -> Result<(StatusCode, Json<Todo>), ApiError> {
    // Reject empty titles with 400 Bad Request
    if payload.title.trim().is_empty() {
        return Err(ApiError::BadRequest("title must not be empty".to_string()));
    }
    let todo = Todo {
        id: Uuid::new_v4().to_string(),
        title: payload.title,
        completed: false,
    };
    db.lock().unwrap().insert(todo.id.clone(), todo.clone());
    Ok((StatusCode::CREATED, Json(todo)))
}

async fn update_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
    Json(payload): Json<UpdateTodo>,
) -> Result<Json<Todo>, ApiError> {
    let mut db = db.lock().unwrap();
    let todo = db.get_mut(&id).ok_or(ApiError::NotFound)?;
    if let Some(title) = payload.title {
        todo.title = title;
    }
    if let Some(completed) = payload.completed {
        todo.completed = completed;
    }
    Ok(Json(todo.clone()))
}

async fn delete_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
) -> Result<StatusCode, ApiError> {
    let mut db = db.lock().unwrap();
    if db.remove(&id).is_some() {
        Ok(StatusCode::NO_CONTENT)
    } else {
        Err(ApiError::NotFound)
    }
}

// ============ Main ============
#[tokio::main]
async fn main() {
    // Initialize logging
    tracing_subscriber::fmt::init();

    // Create the in-memory database
    let db: Db = Arc::new(Mutex::new(HashMap::new()));

    // Build the router
    let app = Router::new()
        .route("/todos", get(list_todos).post(create_todo))
        .route("/todos/:id", get(get_todo).put(update_todo).delete(delete_todo))
        .with_state(db);

    // Start the server
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("🚀 Server listening on http://{}", addr);
    println!("API endpoints:");
    println!("  GET    /todos      List all todos");
    println!("  POST   /todos      Create a todo");
    println!("  GET    /todos/:id  Get one todo");
    println!("  PUT    /todos/:id  Update a todo");
    println!("  DELETE /todos/:id  Delete a todo");

    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```
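The `update_todo` handler applies a partial update. Any field left out of the JSON body keeps its old value, because with serde a missing `Option` field deserializes to `None`. Strictly speaking this is PATCH-style behavior, since PUT conventionally replaces the whole resource. Pulling the merge rule into a plain function makes it unit-testable without starting a server. The std-only sketch below does this with a hypothetical `apply_update` and trimmed copies of the structs without serde.

```rust
#[derive(Debug, Clone, PartialEq)]
struct Todo {
    id: String,
    title: String,
    completed: bool,
}

struct UpdateTodo {
    title: Option<String>,
    completed: Option<bool>,
}

// Applies a partial update: fields that are None are left unchanged.
// Same rules as update_todo above, but callable from a plain unit test.
fn apply_update(todo: &mut Todo, update: UpdateTodo) {
    if let Some(title) = update.title {
        todo.title = title;
    }
    if let Some(completed) = update.completed {
        todo.completed = completed;
    }
}
```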
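📝 Chapter Summary
In this chapter we used Axum to build web services:
| Component | Purpose |
|---|---|
| `Router` | Defines routes |
| Extractors | Pull data out of requests |
| `State` | Shares application state |
| Middleware | Handles cross-cutting concerns |
| `IntoResponse` | Builds custom responses |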
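Key takeaways:
- Axum is function-based: handlers are ordinary async functions, with little need for macros
- Extractors make request handling type-safe
- Middleware covers authentication, logging, CORS, and similar concerns
- Shared state is wrapped in `Arc` (plus a `Mutex` or `RwLock` when it needs to be mutated)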
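Feynman-technique question: Why does Axum reject a handler with an unsupported argument or return type at compile time? Hint: think about how extractors work.
Hands-on exercises:
- Add pagination to the to-do API (`?page=1&limit=10`).
- Implement a simple JWT authentication middleware.
- Add request logging that records each request's method, path, and duration.
- Replace the in-memory store with `sqlx` and connect to a SQLite database.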
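As a possible starting point for the pagination exercise, here is a hypothetical helper that slices a list for `?page=1&limit=10` using 1-based pages. `list_todos` reads values out of a `HashMap`, whose iteration order is unspecified and can change as items are inserted. Sort the todos first (for example by title, or by a creation timestamp you add), or pages may overlap between requests.

```rust
// Hypothetical pagination helper: 1-based pages, page 0 treated as page 1,
// and pages past the end return an empty Vec
fn paginate<T: Clone>(items: &[T], page: usize, limit: usize) -> Vec<T> {
    let page = page.max(1);
    let start = (page - 1).saturating_mul(limit);
    items.iter().skip(start).take(limit).cloned().collect()
}
```

In a handler you would typically also cap `limit` so a single request cannot ask for the entire collection.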