Skip to main content

server_middleware/middleware/
auth.rs

1use std::sync::Arc;
2
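// Authentication middleware: uses the request path to decide whether a request
// may pass without authentication, otherwise it must carry a valid token.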
9use axum::{extract::Request, http::header, middleware::Next, response::Response};
10use server_common::{error::auth::AuthError, jwt::JwtService};
11
12pub async fn auth_middleware(mut req: Request, next: Next) -> Result<Response, AuthError> {
13    // 跳过公开路径(可选)
14    let path = req.uri().path();
15    // 请求方法
16    let method = req.method().clone();
17
18    // 获取jwt服务
19    let service = req
20        .extensions()
21        .get::<Arc<JwtService>>()
22        .ok_or(AuthError::MissingToken)?;
23
24    // 判断是否为跟目录
25    if path == "/" {
26        return Ok(next.run(req).await);
27    }
28
29    // 判断是否需要认证
30    if service.is_ignore_uri(path, method.as_str()) {
31        return Ok(next.run(req).await);
32    }
33
34    // 获取认证token
35    let token = req
36        .headers()
37        .get(header::AUTHORIZATION)
38        .and_then(|value| value.to_str().ok())
39        .and_then(|value| value.strip_prefix("Bearer "))
40        .ok_or(AuthError::MissingToken)?;
41
42    // 验证 JWT
43    let claims = service.verify_token(token)?;
44
45    // 为了在提取器中减少开销
46    req.extensions_mut().insert(Arc::new(claims));
47
48    Ok(next.run(req).await)
49}