sa_token_plugin_gotham/
layer.rs

1use gotham::state::State;
2use gotham::middleware::Middleware;
3use gotham::handler::HandlerFuture;
4use std::pin::Pin;
5use sa_token_core::{token::TokenValue, SaTokenContext};
6use crate::state::SaTokenState;
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct SaTokenLayer {
11    state: SaTokenState,
12}
13
14impl SaTokenLayer {
15    pub fn new(state: SaTokenState) -> Self {
16        Self { state }
17    }
18}
19
20impl Middleware for SaTokenLayer {
21    fn call<Chain>(self, mut state: State, chain: Chain) -> Pin<Box<HandlerFuture>>
22    where
23        Chain: FnOnce(State) -> Pin<Box<HandlerFuture>> + Send + 'static,
24    {
25        Box::pin(async move {
26            let mut ctx = SaTokenContext::new();
27            
28            if let Some(token_str) = extract_token_from_state(&state, &self.state) {
29                tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
30                let token = TokenValue::new(token_str);
31                
32                if self.state.manager.is_valid(&token).await {
33                    if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
34                        let login_id = token_info.login_id.clone();
35                        
36                        ctx.token = Some(token.clone());
37                        ctx.token_info = Some(Arc::new(token_info));
38                        ctx.login_id = Some(login_id.clone());
39                        
40                        state.put(crate::wrapper::TokenValueWrapper(token));
41                        state.put(crate::wrapper::LoginIdWrapper(login_id));
42                    }
43                }
44            }
45            
46            SaTokenContext::set_current(ctx);
47            let result = chain(state).await;
48            SaTokenContext::clear();
49            result
50        })
51    }
52}
53
54/// 从 Gotham State 中提取 Token
55/// 
56/// 按优先级顺序查找 Token:
57/// 1. HTTP Header - `<token_name>: <token>` 或 `<token_name>: Bearer <token>`
58/// 2. Authorization Header - `Authorization: Bearer <token>`
59/// 3. Cookie - `<token_name>=<token>`
60/// 4. Query Parameter - `?<token_name>=<token>`
61fn extract_token_from_state(state: &State, token_state: &SaTokenState) -> Option<String> {
62    use gotham::hyper::{HeaderMap, Uri};
63    use sa_token_adapter::utils::{parse_cookies, parse_query_string};
64    
65    // 从配置中获取 token_name
66    let token_name = &token_state.manager.config.token_name;
67    
68    // 1. 从 Header 中获取
69    if let Some(headers) = state.try_borrow::<HeaderMap>() {
70        // 1.1 尝试从指定名称的 header 获取
71        if let Some(header_value) = headers.get(token_name) {
72            if let Ok(value_str) = header_value.to_str() {
73                return Some(extract_bearer_token(value_str));
74            }
75        }
76        
77        // 1.2 尝试从 Authorization header 获取
78        if let Some(auth_header) = headers.get("authorization") {
79            if let Ok(auth_str) = auth_header.to_str() {
80                return Some(extract_bearer_token(auth_str));
81            }
82        }
83        
84        // 2. 从 Cookie 中获取
85        if let Some(cookie_header) = headers.get("cookie") {
86            if let Ok(cookie_str) = cookie_header.to_str() {
87                let cookies = parse_cookies(cookie_str);
88                if let Some(token) = cookies.get(token_name) {
89                    return Some(token.clone());
90                }
91            }
92        }
93    }
94    
95    // 3. 从 Query 参数中获取
96    if let Some(uri) = state.try_borrow::<Uri>() {
97        if let Some(query) = uri.query() {
98            let params = parse_query_string(query);
99            if let Some(token) = params.get(token_name) {
100                return Some(token.clone());
101            }
102        }
103    }
104    
105    None
106}
107
/// Strips an optional `Bearer ` scheme prefix from a header value.
///
/// Two formats are supported:
/// - `Bearer <token>`: the standard bearer format. The scheme is matched
///   case-insensitively (RFC 7235 §2.1), so `bearer <token>` also works.
/// - `<token>`: the bare token string, returned as-is.
///
/// Surrounding whitespace is trimmed from the result. `"Bearer "` with nothing after
/// it yields an empty string.
fn extract_bearer_token(header_value: &str) -> String {
    const BEARER_PREFIX: &str = "Bearer ";

    let value = header_value.trim_start();
    // `str::get` returns `None` for short input or a non-char-boundary cut, so
    // non-ASCII values can never cause a slicing panic here.
    let token = match value.get(..BEARER_PREFIX.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(BEARER_PREFIX) => {
            &value[BEARER_PREFIX.len()..]
        }
        _ => value,
    };
    token.trim().to_string()
}