第十六章:Web开发——Axum实战

本章导读:Web 开发是 Rust 的重要应用领域。Axum 是 Tokio 团队开发的 Web 框架,它简洁、模块化、类型安全。本章我们将从零开始构建一个 RESTful API,涵盖路由、中间件、数据库集成等核心主题,感受 Rust 在服务端开发的魅力。


🌐 16.1 Axum 简介

🎯 16.1.1 为什么选择 Axum?

| 特性 | 说明 |
| --- | --- |
| 简洁 | 基于函数的处理器,无需宏 |
| 类型安全 | 利用类型系统在编译期检查处理器签名与提取器 |
| 高性能 | 基于 Tokio 和 Hyper |
| 模块化 | 可与其他 tower 生态组件组合 |

📦 16.1.2 项目初始化

cargo new web_api
cd web_api
# Cargo.toml
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
tower-http = { version = "0.5", features = ["cors", "trace"] }
tracing = "0.1"
tracing-subscriber = "0.3"
uuid = { version = "1", features = ["v4"] }  # 16.7 实战中用于生成待办事项 ID

🚀 16.2 Hello World

🌍 16.2.1 最简服务器

use axum::{
    routing::get,
    Router,
};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    // Listen on the loopback interface, port 3000.
    let addr: SocketAddr = ([127, 0, 0, 1], 3000).into();

    // A single route: `GET /` is answered by `hello`.
    let router = Router::new().route("/", get(hello));

    println!("服务器启动在 http://{addr}");

    // Bind the TCP listener, then hand it to axum to serve `router`.
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, router).await.unwrap();
}

// 处理函数:返回字符串
async fn hello() -> &'static str {
    "Hello, Axum!"
}

📝 16.2.2 JSON 响应

use axum::{
    extract::Json,
    routing::{get, post},
    Router,
};
use serde::{Deserialize, Serialize};

/// A user record exchanged as JSON by the handlers in this example.
///
/// `Debug` is derived so that a `User` can be printed with `{:?}`, which
/// `create_user` does when logging the incoming payload; without it that
/// `println!` does not compile.
#[derive(Debug, Serialize, Deserialize)]
struct User {
    /// Numeric user id.
    id: u64,
    /// Display name.
    name: String,
    /// Contact e-mail address.
    email: String,
}

/// `GET /user`: returns a fixed sample user serialized as JSON.
async fn get_user() -> Json<User> {
    let alice = User {
        id: 1,
        name: String::from("Alice"),
        email: String::from("alice@example.com"),
    };
    Json(alice)
}

/// `POST /user`: logs the submitted user and echoes it back unchanged.
///
/// The `Json` extractor deserializes the request body into a `User` before
/// this function runs.
async fn create_user(Json(new_user): Json<User>) -> Json<User> {
    // Real creation logic (validation, persistence, ...) would go here.
    println!("创建用户: {:?}", new_user);
    Json(new_user)
}

#[tokio::main]
async fn main() {
    // This example's `use` list does not import `SocketAddr`; bring it into
    // scope here so the snippet compiles on its own.
    use std::net::SocketAddr;

    // Registering the same path twice merges the method routers, so
    // `/user` answers both GET and POST.
    let app = Router::new()
        .route("/user", get(get_user))
        .route("/user", post(create_user));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

🛣️ 16.3 路由

📍 16.3.1 路径参数

use axum::{
    extract::Path,
    routing::get,
    Router,
};

/// Static route: always answers with the same short greeting.
async fn hello() -> &'static str {
    let reply: &'static str = "Hello!";
    reply
}

/// Dynamic route: the `:id` path segment is extracted as a `u64`.
async fn get_user_by_id(Path(user_id): Path<u64>) -> String {
    format!("用户ID: {user_id}")
}

/// Route with two parameters: `:user_id` and `:post_id` are extracted, in
/// path order, as a `(u64, u64)` tuple.
async fn get_post(Path(ids): Path<(u64, u64)>) -> String {
    let (user_id, post_id) = ids;
    format!("用户 {user_id} 的文章 {post_id}")
}

#[tokio::main]
async fn main() {
    // This example's `use` list does not import `SocketAddr`; bring it into
    // scope here so the snippet compiles on its own.
    use std::net::SocketAddr;

    // `:name` segments are path parameters, read by the `Path` extractor.
    let app = Router::new()
        .route("/", get(hello))
        .route("/users/:id", get(get_user_by_id))
        .route("/users/:user_id/posts/:post_id", get(get_post));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

🏗️ 16.3.2 路由组

use axum::{
    routing::{get, post, delete},
    Router,
};

#[tokio::main]
async fn main() {
    // API 路由组
    let api_routes = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/users/:id", get(get_user).delete(delete_user));

    // 管理路由组
    let admin_routes = Router::new()
        .route("/stats", get(get_stats))
        .route("/logs", get(get_logs));

    // 主路由
    let app = Router::new()
        .route("/", get(root))
        .nest("/api", api_routes)
        .nest("/admin", admin_routes);

    // GET /           -> root
    // GET /api/users  -> list_users
    // POST /api/users -> create_user
    // GET /admin/stats -> get_stats
}

/// `GET /`: welcome text.
async fn root() -> &'static str {
    "欢迎!"
}

/// `GET /api/users`: placeholder for listing users.
async fn list_users() -> &'static str {
    "用户列表"
}

/// `POST /api/users`: placeholder for creating a user.
async fn create_user() -> &'static str {
    "创建用户"
}

/// `GET /api/users/:id`: placeholder for fetching one user.
async fn get_user() -> &'static str {
    "获取用户"
}

/// `DELETE /api/users/:id`: placeholder for deleting a user.
async fn delete_user() -> &'static str {
    "删除用户"
}

/// `GET /admin/stats`: placeholder for statistics.
async fn get_stats() -> &'static str {
    "统计数据"
}

/// `GET /admin/logs`: placeholder for logs.
async fn get_logs() -> &'static str {
    "日志"
}

📤 16.4 提取器(Extractors)

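提取器从请求中提取数据。它们以处理器参数的形式出现,并按从左到右的顺序执行。注意:除最后一个参数外,其余提取器都只能读取请求头部信息(如 Path、Query、HeaderMap、Method);会消费请求体的提取器(如 Json、Form、String、Bytes)只能有一个,并且必须放在参数列表的最后。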

🔍 16.4.1 常用提取器

use axum::{
    extract::{
        Path, Query, Json, Form,
        Extension, State,
        OriginalUri, ConnectInfo,
    },
    body::Body,
    http::{HeaderMap, Method, Uri},
};
use serde::Deserialize;

/// Query-string parameters for a search request, e.g. `?q=rust&page=2`.
///
/// `Debug` is derived so the whole query can be printed with `{:?}`.
/// Without it, formatting a `SearchQuery` that way does not compile.
#[derive(Debug, Deserialize)]
struct SearchQuery {
    /// Search text. The field has no default, so it must be present.
    q: String,
    /// Page number; `None` when `page` is absent from the query string.
    page: Option<u32>,
}

/// Login credentials as submitted by an HTML form
/// (`application/x-www-form-urlencoded`), read via the `Form` extractor.
#[derive(Deserialize)]
struct LoginForm {
    /// Account name.
    username: String,
    /// Plain-text password as submitted; never log this value.
    password: String,
}

/// Demonstrates several extractors in a single handler.
///
/// axum runs extractors left to right. Every parameter except the last must
/// implement `FromRequestParts`, which reads only the request head. Only the
/// final parameter may implement `FromRequest` and consume the body. `Json`
/// and `Form` both consume the body, so a handler can use at most one of
/// them, and it must come last. The original version listed both in the
/// middle of the parameter list, so it could never be used as a route
/// handler. `Form(form): Form<LoginForm>` can replace the `Json` parameter
/// below when the client sends an HTML form instead of JSON.
///
/// `User` is the struct from the JSON example earlier in this chapter.
async fn extract_all(
    // Path parameter, e.g. the `:id` in `/items/:id`
    Path(id): Path<u64>,
    // Query string, e.g. `?q=rust&page=2`
    Query(query): Query<SearchQuery>,
    // All request headers
    headers: HeaderMap,
    // HTTP method
    method: Method,
    // Request URI
    uri: Uri,
    // Original request URI (differs from `uri` inside `nest`ed routers)
    OriginalUri(original_uri): OriginalUri,
    // JSON request body; the body extractor must be the last parameter
    Json(payload): Json<User>,
) -> String {
    // Format individual fields so this handler does not depend on
    // `SearchQuery` implementing `Debug`.
    format!(
        "ID: {}, 搜索: {}, 页码: {:?}, 方法: {}, URI: {}, 原始 URI: {}, 请求头数量: {}, 用户: {}",
        id,
        query.q,
        query.page,
        method,
        uri,
        original_uri,
        headers.len(),
        payload.name
    )
}

📦 16.4.2 自定义提取器

use axum::{
    async_trait,
    extract::{FromRequestParts, FromRequest},
    http::{request::Parts, Request},
    body::Body,
};

/// Extractor that reads an API key from the `X-API-Key` request header.
///
/// The request is rejected with `401 Unauthorized` and the body
/// "缺少 API Key" when the header is missing or is not valid visible ASCII.
struct ApiKey(String);

#[async_trait]
impl<S> FromRequestParts<S> for ApiKey
where
    S: Send + Sync,
{
    type Rejection = (axum::http::StatusCode, &'static str);

    async fn from_request_parts(
        parts: &mut Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        // `to_str` fails on non-visible-ASCII values; such a header is
        // treated the same as a missing one.
        parts
            .headers
            .get("X-API-Key")
            .and_then(|value| value.to_str().ok())
            .map(|key| ApiKey(key.to_owned()))
            .ok_or((axum::http::StatusCode::UNAUTHORIZED, "缺少 API Key"))
    }
}

/// Handler guarded by the `ApiKey` extractor. It only runs when the header
/// was present, and it echoes the key back.
async fn protected(ApiKey(api_key): ApiKey) -> String {
    format!("你的 API Key: {api_key}")
}

🗄️ 16.5 状态管理

📦 16.5.1 共享状态

use axum::{
    extract::State,
    routing::{get, post},
    Router,
};
use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};

/// A stored user. `Clone` lets handlers copy records out from under the lock.
#[derive(Serialize, Deserialize, Clone)]
struct User {
    /// Id assigned by the server on creation.
    id: u64,
    /// Display name.
    name: String,
}

/// Application state shared by every handler.
///
/// Each field sits behind its own `std::sync::Mutex`. The derived `Default`
/// gives an empty user list and a counter of 0.
#[derive(Default)]
struct AppState {
    /// In-memory user store.
    users: Mutex<Vec<User>>,
    /// Last id handed out; incremented before each new user is stored.
    counter: Mutex<u64>,
}

/// Reference-counted handle to `AppState`; handlers receive it via `State`.
type SharedState = Arc<AppState>;

#[tokio::main]
async fn main() {
    // This example's `use` list does not import `SocketAddr`; bring it into
    // scope here so the snippet compiles on its own.
    use std::net::SocketAddr;

    // One `AppState` behind an `Arc`; `with_state` makes it available to
    // every handler that extracts `State<SharedState>`.
    let state = Arc::new(AppState::default());

    let app = Router::new()
        .route("/users", get(list_users).post(create_user))
        .route("/count", get(get_count))
        .with_state(state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

async fn list_users(State(state): State<SharedState>) -> Json<Vec<User>> {
    let users = state.users.lock().unwrap().clone();
    Json(users)
}

async fn create_user(
    State(state): State<SharedState>,
    Json(user): Json<User>,
) -> Json<User> {
    let mut users = state.users.lock().unwrap();
    let mut counter = state.counter.lock().unwrap();

    *counter += 1;
    let new_user = User {
        id: *counter,
        ..user
    };
    users.push(new_user.clone());

    Json(new_user)
}

/// `GET /count`: reports the last id handed out, i.e. how many users have been created.
async fn get_count(State(state): State<SharedState>) -> String {
    // Copy the value out so the guard is dropped at the end of this statement.
    let current: u64 = *state.counter.lock().unwrap();
    format!("计数: {}", current)
}

🧵 16.5.2 使用 tokio::sync

use tokio::sync::RwLock;

/// Variant of the shared state that uses `tokio::sync::RwLock`, so handlers
/// `.await` the lock instead of blocking the worker thread.
struct AsyncState {
    /// User list: many concurrent readers, or one writer at a time.
    users: RwLock<Vec<User>>,
}

/// Returns a snapshot of all users while holding only a shared (read) lock.
async fn async_list_users(
    State(state): State<Arc<AsyncState>>,
) -> Json<Vec<User>> {
    let guard = state.users.read().await;
    // Clone the vector itself (not the guard), then release the lock.
    let snapshot: Vec<User> = (*guard).clone();
    drop(guard);
    Json(snapshot)
}

/// Appends the submitted user under an exclusive (write) lock and echoes it back.
async fn async_create_user(
    State(state): State<Arc<AsyncState>>,
    Json(user): Json<User>,
) -> Json<User> {
    state.users.write().await.push(user.clone());
    Json(user)
}

🛡️ 16.6 中间件

📝 16.6.1 日志中间件

use axum::{
    middleware::{self, Next},
    response::Response,
    body::Body,
    http::Request,
};

/// Middleware that prints the method and URI of each request and, once the
/// inner service has responded, the status code and the elapsed time.
async fn logging_middleware(
    request: Request<Body>,
    next: Next,
) -> Response {
    // Print straight from borrows; they end before `request` is moved below.
    println!("请求: {} {}", request.method(), request.uri());

    let started_at = std::time::Instant::now();
    let response = next.run(request).await;
    println!("响应: {} ({:?})", response.status(), started_at.elapsed());

    response
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(middleware::from_fn(logging_middleware));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

🔐 16.6.2 认证中间件

use axum::{
    middleware::{self, Next},
    response::{Response, IntoResponse},
    http::{Request, StatusCode},
};

/// Identity of an authenticated caller, stored in the request extensions by
/// `auth_middleware`.
///
/// `Clone` is required. In http 1.x (used by axum 0.7),
/// `Extensions::insert` needs `T: Clone + Send + Sync + 'static`, and the
/// `Extension<AuthUser>` extractor clones the value out. Without the derive,
/// `request.extensions_mut().insert(user)` does not compile.
#[derive(Clone, Debug)]
struct AuthUser {
    /// User id.
    id: u64,
    /// Display name.
    name: String,
}

async fn auth_middleware(
    mut request: Request<Body>,
    next: Next,
) -> Result<Response, StatusCode> {
    let auth_header = request.headers()
        .get("Authorization")
        .and_then(|h| h.to_str().ok());

    match auth_header {
        Some(token) if token == "Bearer secret123" => {
            // 添加用户信息到请求扩展
            let user = AuthUser {
                id: 1,
                name: "Alice".to_string(),
            };
            request.extensions_mut().insert(user);
            Ok(next.run(request).await)
        }
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}

🌐 16.6.3 CORS

use tower_http::cors::{CorsLayer, Any};

#[tokio::main]
async fn main() {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/", get(|| async { "Hello" }))
        .layer(cors);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

🧪 16.7 实战:构建 RESTful API

让我们构建一个完整的待办事项 API(待办事项 ID 由 uuid crate 生成,需在 Cargo.toml 中添加 `uuid = { version = "1", features = ["v4"] }`):

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::{IntoResponse, Json},
    routing::{delete, get, post, put},
    Router,
};
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    net::SocketAddr,
};
use uuid::Uuid;

// ============ 数据模型 ============

/// A todo item as stored in the in-memory database and returned by the API.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Todo {
    /// Server-assigned identifier (a UUID v4 string generated on creation).
    id: String,
    /// Human-readable description of the task.
    title: String,
    /// Whether the task has been done.
    completed: bool,
}

/// JSON body accepted by `POST /todos`. The id and completion state are set by the server.
#[derive(Debug, Deserialize)]
struct CreateTodo {
    /// Title of the new todo.
    title: String,
}

/// JSON body accepted by `PUT /todos/:id`. Every field is optional, and a
/// field left out of the body stays `None`, leaving that part of the todo unchanged.
#[derive(Debug, Deserialize)]
struct UpdateTodo {
    /// New title, if it should change.
    title: Option<String>,
    /// New completion state, if it should change.
    completed: Option<bool>,
}

// ============ 状态 ============

/// In-memory todo store keyed by todo id, shared by all handlers via `State`.
type Db = Arc<Mutex<HashMap<String, Todo>>>;

// ============ 错误处理 ============

/// Errors returned by the todo handlers; each one maps to an HTTP status and
/// a plain-text message.
enum ApiError {
    /// The requested todo does not exist (404).
    NotFound,
    /// The request was malformed (400); carries the message sent to the client.
    BadRequest(String),
}

impl IntoResponse for ApiError {
    fn into_response(self) -> axum::response::Response {
        // Pick status and message first, then build the response in one place.
        let (status, message) = match self {
            ApiError::NotFound => (StatusCode::NOT_FOUND, String::from("资源未找到")),
            ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg),
        };
        (status, message).into_response()
    }
}

// ============ 处理函数 ============

/// `GET /todos`: returns every todo as a JSON array, in `HashMap` iteration order.
async fn list_todos(State(db): State<Db>) -> impl IntoResponse {
    let todos: Vec<Todo> = {
        let store = db.lock().unwrap();
        store.values().cloned().collect()
    };
    Json(todos)
}

async fn get_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
) -> Result<Json<Todo>, ApiError> {
    let db = db.lock().unwrap();
    let todo = db.get(&id).cloned().ok_or(ApiError::NotFound)?;
    Ok(Json(todo))
}

/// `POST /todos`: creates an uncompleted todo with a fresh UUID v4 id and
/// answers `201 Created` with the stored todo as JSON.
async fn create_todo(
    State(db): State<Db>,
    Json(payload): Json<CreateTodo>,
) -> impl IntoResponse {
    let id = Uuid::new_v4().to_string();
    let todo = Todo {
        id: id.clone(),
        title: payload.title,
        completed: false,
    };

    {
        let mut store = db.lock().unwrap();
        store.insert(id, todo.clone());
    }

    (StatusCode::CREATED, Json(todo))
}

async fn update_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
    Json(payload): Json<UpdateTodo>,
) -> Result<Json<Todo>, ApiError> {
    let mut db = db.lock().unwrap();

    let todo = db.get_mut(&id).ok_or(ApiError::NotFound)?;

    if let Some(title) = payload.title {
        todo.title = title;
    }
    if let Some(completed) = payload.completed {
        todo.completed = completed;
    }

    Ok(Json(todo.clone()))
}

/// `DELETE /todos/:id`: removes the todo and answers `204 No Content`, or
/// `ApiError::NotFound` when no todo has that id.
async fn delete_todo(
    State(db): State<Db>,
    Path(id): Path<String>,
) -> Result<StatusCode, ApiError> {
    let removed = db.lock().unwrap().remove(&id);
    match removed {
        Some(_) => Ok(StatusCode::NO_CONTENT),
        None => Err(ApiError::NotFound),
    }
}

// ============ 主函数 ============

#[tokio::main]
async fn main() {
    // 初始化日志
    tracing_subscriber::fmt::init();

    // 创建数据库
    let db: Db = Arc::new(Mutex::new(HashMap::new()));

    // 构建路由
    let app = Router::new()
        .route("/todos", get(list_todos).post(create_todo))
        .route("/todos/:id", get(get_todo).put(update_todo).delete(delete_todo))
        .with_state(db);

    // 启动服务器
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("🚀 服务器启动在 http://{}", addr);
    println!("API 端点:");
    println!("  GET    /todos      获取所有待办事项");
    println!("  POST   /todos      创建待办事项");
    println!("  GET    /todos/:id  获取单个待办事项");
    println!("  PUT    /todos/:id  更新待办事项");
    println!("  DELETE /todos/:id  删除待办事项");

    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

📝 本章小结

本章我们学习了使用 Axum 进行 Web 开发:

| 组件 | 用途 |
| --- | --- |
| Router | 定义路由 |
| 提取器 | 从请求提取数据 |
| State | 共享应用状态 |
| 中间件 | 处理横切关注点 |
| IntoResponse | 自定义响应 |

关键要点:

  • Axum 是基于函数的框架,无需大量宏
  • 提取器让请求数据处理类型安全
  • 中间件用于认证、日志、CORS 等
  • 共享状态用 Arc 包装,以便在各处理器之间廉价共享;需要修改的数据再配合 Mutex 或 RwLock

费曼技巧提问:为什么 Axum 的处理器签名能在编译期得到类型检查(而路由路径字符串本身不能)?提示:想想提取器如何工作。


动手实验

  1. 为待办事项 API 添加分页功能(?page=1&limit=10)。
  2. 实现一个简单的 JWT 认证中间件。
  3. 添加请求日志,记录每个请求的方法、路径和耗时。
  4. 使用 sqlx 替换内存存储,连接 SQLite 数据库。
← 返回目录