// sa_token_plugin_tide/layer.rs

1use tide::{Middleware, Request, Result, Next};
2use sa_token_core::{token::TokenValue, SaTokenContext};
3use std::sync::Arc;
4use crate::state::SaTokenState;
5use sa_token_adapter::utils::{parse_cookies, parse_query_string, extract_bearer_token as utils_extract_bearer_token};
6
7#[derive(Clone)]
8pub struct SaTokenLayer {
9    state: SaTokenState,
10}
11
12impl SaTokenLayer {
13    pub fn new(state: SaTokenState) -> Self {
14        Self { state }
15    }
16}
17
18#[tide::utils::async_trait]
19impl<State: Clone + Send + Sync + 'static> Middleware<State> for SaTokenLayer {
20    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
21        let mut ctx = SaTokenContext::new();
22        
23        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
24            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
25            let token = TokenValue::new(token_str);
26            
27            if self.state.manager.is_valid(&token).await {
28                req.set_ext(token.clone());
29                
30                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
31                    let login_id = token_info.login_id.clone();
32                    req.set_ext(login_id.clone());
33                    
34                    ctx.token = Some(token.clone());
35                    ctx.token_info = Some(Arc::new(token_info));
36                    ctx.login_id = Some(login_id);
37                }
38            }
39        }
40        
41        SaTokenContext::set_current(ctx);
42        let result = next.run(req).await;
43        SaTokenContext::clear();
44        Ok(result)
45    }
46}
47
48/// 中文 | English
49/// 从请求中提取 token | Extract token from request
50///
51/// 按以下顺序尝试提取 token: | Try to extract token in the following order:
52/// 1. 从指定名称的请求头 | From specified header name
53/// 2. 从 Authorization 请求头 | From Authorization header
54/// 3. 从 Cookie | From cookie
55/// 4. 从查询参数 | From query parameter
56pub fn extract_token_from_request<State>(req: &Request<State>, token_state: &SaTokenState) -> Option<String> {
57    let token_name = &token_state.manager.config.token_name;
58    
59    // 1. 从指定名称的请求头提取 | Extract from specified header name
60    if let Some(header_value) = req.header(token_name.as_str()) {
61        if let Some(value_str) = header_value.get(0) {
62            let value_str = value_str.as_str();
63            if !value_str.is_empty() {
64                if let Some(token) = utils_extract_bearer_token(value_str) {
65                    return Some(token);
66                }
67            }
68        }
69    }
70    
71    // 2. 从 Authorization 请求头提取 | Extract from Authorization header
72    if let Some(auth_header) = req.header("authorization") {
73        if let Some(auth_str) = auth_header.get(0) {
74            let auth_str = auth_str.as_str();
75            if !auth_str.is_empty() {
76                if let Some(token) = utils_extract_bearer_token(auth_str) {
77                    return Some(token);
78                }
79            }
80        }
81    }
82    
83    // 3. 从 Cookie 提取 | Extract from cookie
84    if let Some(cookie_header) = req.header("cookie") {
85        if let Some(cookie_str) = cookie_header.get(0) {
86            let cookies = parse_cookies(cookie_str.as_str());
87            if let Some(token) = cookies.get(token_name) {
88                if !token.is_empty() {
89                    return Some(token.to_string());
90                }
91            }
92        }
93    }
94    
95    // 4. 从查询参数提取 | Extract from query parameter
96    if let Some(query) = req.url().query() {
97        let params = parse_query_string(query);
98        if let Some(token) = params.get(token_name) {
99            if !token.is_empty() {
100                return Some(token.to_string());
101            }
102        }
103    }
104    
105    None
106}