sa_token_plugin_axum/
layer.rs

1// Author: 金书记
2//
3//! Axum中间件层
4
5use std::task::{Context, Poll};
6use tower::{Layer, Service};
7use http::{Request, Response};
8use sa_token_adapter::context::SaRequest;
9use crate::{SaTokenState, adapter::AxumRequestAdapter};
10use sa_token_core::{SaTokenContext, router::PathAuthConfig};
11use std::sync::Arc;
12
13/// Sa-Token layer for Axum with optional path-based authentication
14/// 支持可选路径鉴权的 Axum Sa-Token 层
15#[derive(Clone)]
16pub struct SaTokenLayer {
17    state: SaTokenState,
18    /// Optional path authentication configuration
19    /// 可选的路径鉴权配置
20    path_config: Option<PathAuthConfig>,
21}
22
23impl SaTokenLayer {
24    pub fn new(state: SaTokenState) -> Self {
25        Self { state, path_config: None }
26    }
27    
28    pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
29        Self { state, path_config: Some(config) }
30    }
31}
32
33impl<S> Layer<S> for SaTokenLayer {
34    type Service = SaTokenMiddleware<S>;
35    
36    fn layer(&self, inner: S) -> Self::Service {
37        SaTokenMiddleware {
38            inner,
39            state: self.state.clone(),
40            path_config: self.path_config.clone(),
41        }
42    }
43}
44
45#[derive(Clone)]
46pub struct SaTokenMiddleware<S> {
47    pub(crate) inner: S,
48    pub(crate) state: SaTokenState,
49    /// Optional path authentication configuration
50    /// 可选的路径鉴权配置
51    pub(crate) path_config: Option<PathAuthConfig>,
52}
53
54impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaTokenMiddleware<S>
55where
56    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
57    S::Future: Send + 'static,
58    ReqBody: Send + 'static,
59    ResBody: Default + Send + 'static,
60{
61    type Response = S::Response;
62    type Error = S::Error;
63    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
64    
65    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66        self.inner.poll_ready(cx)
67    }
68    
69    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
70        let mut inner = self.inner.clone();
71        let state = self.state.clone();
72        let path_config = self.path_config.clone();
73        
74        Box::pin(async move {
75            if let Some(config) = path_config {
76                let path = request.uri().path();
77                let token_str = extract_token_from_request(&request, &state);
78                let result = sa_token_core::router::process_auth(path, token_str, &config, &state.manager).await;
79                
80                if result.should_reject() {
81                    let mut response = Response::new(ResBody::default());
82                    *response.status_mut() = http::StatusCode::UNAUTHORIZED;
83                    return Ok(response);
84                }
85                
86                if let Some(token) = &result.token {
87                    request.extensions_mut().insert(token.clone());
88                }
89                if let Some(login_id) = result.login_id() {
90                    request.extensions_mut().insert(login_id.to_string());
91                }
92                
93                let ctx = sa_token_core::router::create_context(&result);
94                SaTokenContext::set_current(ctx);
95                let response = inner.call(request).await;
96                SaTokenContext::clear();
97                return response;
98            }
99            
100            // No path auth config, use default token extraction and validation
101            // 没有路径鉴权配置,使用默认的 token 提取和验证
102            let mut ctx = SaTokenContext::new();
103            if let Some(token_str) = extract_token_from_request(&request, &state) {
104                let token = sa_token_core::token::TokenValue::new(token_str);
105                if state.manager.is_valid(&token).await {
106                    request.extensions_mut().insert(token.clone());
107                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
108                        let login_id = token_info.login_id.clone();
109                        request.extensions_mut().insert(login_id.clone());
110                        ctx.token = Some(token.clone());
111                        ctx.token_info = Some(Arc::new(token_info));
112                        ctx.login_id = Some(login_id);
113                    }
114                }
115            }
116            
117            SaTokenContext::set_current(ctx);
118            let response = inner.call(request).await;
119            SaTokenContext::clear();
120            response
121        })
122    }
123}
124
125/// 从请求中提取 Token
126/// 
127/// 按优先级顺序查找 Token:
128/// 1. HTTP Header - `<token_name>: <token>` 或 `<token_name>: Bearer <token>`
129/// 2. HTTP Header - `Authorization: <token>` 或 `Authorization: Bearer <token>`(标准头)
130/// 3. Cookie - `<token_name>=<token>`
131/// 4. Query Parameter - `?<token_name>=<token>`
132/// 
133/// # 参数
134/// - `request` - HTTP 请求
135/// - `state` - SaToken 状态(从配置中获取 token_name)
136/// 
137/// # 返回
138/// - `Some(token)` - 找到有效的 token
139/// - `None` - 未找到 token
140pub fn extract_token_from_request<T>(request: &Request<T>, state: &SaTokenState) -> Option<String> {
141    let adapter = AxumRequestAdapter::new(request);
142    // 从配置中获取 token_name
143    let token_name = &state.manager.config.token_name;
144    
145    // 1. 优先从 Header 中获取(检查 token_name 配置的头)
146    if let Some(token) = adapter.get_header(token_name) {
147        return Some(extract_bearer_token(&token));
148    }
149    
150    // 2. 如果 token_name 不是 "Authorization",也尝试从 "Authorization" 头获取
151    if token_name != "Authorization" {
152        if let Some(token) = adapter.get_header("Authorization") {
153            return Some(extract_bearer_token(&token));
154        }
155    }
156    
157    // 3. 从 Cookie 中获取
158    if let Some(token) = adapter.get_cookie(token_name) {
159        return Some(token);
160    }
161    
162    // 4. 从 Query 参数中获取
163    if let Some(query) = request.uri().query() {
164        if let Some(token) = parse_query_param(query, token_name) {
165            return Some(token);
166        }
167    }
168    
169    None
170}
171
/// Strips an optional `Bearer` authentication scheme from a header value.
///
/// Accepted formats:
/// - `Bearer <token>` - the standard Bearer scheme. The scheme name is matched
///   case-insensitively (auth-scheme names are case-insensitive per RFC 7235), so
///   `bearer <token>` is accepted as well.
/// - `<token>` - a bare token string.
///
/// Surrounding whitespace is removed from the returned token in both cases.
fn extract_bearer_token(header_value: &str) -> String {
    const BEARER_SCHEME: &str = "Bearer";

    let value = header_value.trim_start();
    let scheme_len = BEARER_SCHEME.len();
    // `get` yields `None` instead of panicking when the value is shorter than the scheme or
    // byte `scheme_len` is not a char boundary (non-ASCII input).
    if let (Some(scheme), Some(rest)) = (value.get(..scheme_len), value.get(scheme_len..)) {
        // The scheme must be followed by whitespace, so e.g. `Bearerabc` stays a bare token.
        if scheme.eq_ignore_ascii_case(BEARER_SCHEME)
            && rest.starts_with(|c: char| c.is_ascii_whitespace())
        {
            return rest.trim().to_string();
        }
    }
    value.trim().to_string()
}
188
189fn parse_query_param(query: &str, param_name: &str) -> Option<String> {
190    for pair in query.split('&') {
191        let parts: Vec<&str> = pair.splitn(2, '=').collect();
192        if parts.len() == 2 && parts[0] == param_name {
193            return urlencoding::decode(parts[1])
194                .ok()
195                .map(|s| s.into_owned());
196        }
197    }
198    None
199}