sa_token_plugin_tide/
layer.rs

1use tide::{Middleware, Request, Result, Next};
2use sa_token_core::{token::TokenValue, SaTokenContext};
3use std::sync::Arc;
4use crate::state::SaTokenState;
5use sa_token_adapter::utils::{parse_cookies, parse_query_string, extract_bearer_token as utils_extract_bearer_token};
6
7use sa_token_core::router::PathAuthConfig;
8
9/// Sa-Token layer for Tide with optional path-based authentication
10/// 支持可选路径鉴权的 Tide Sa-Token 层
11#[derive(Clone)]
12pub struct SaTokenLayer {
13    state: SaTokenState,
14    /// Optional path authentication configuration
15    /// 可选的路径鉴权配置
16    path_config: Option<PathAuthConfig>,
17}
18
19impl SaTokenLayer {
20    /// Create layer without path authentication
21    /// 创建不带路径鉴权的层
22    pub fn new(state: SaTokenState) -> Self {
23        Self { state, path_config: None }
24    }
25    
26    /// Create layer with path-based authentication
27    /// 创建带路径鉴权的层
28    pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
29        Self { state, path_config: Some(config) }
30    }
31}
32
33#[tide::utils::async_trait]
34impl<State: Clone + Send + Sync + 'static> Middleware<State> for SaTokenLayer {
35    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
36        if let Some(config) = &self.path_config {
37            let path = req.url().path();
38            let token_str = extract_token_from_request(&req, &self.state.manager.config.token_name);
39            let result = sa_token_core::router::process_auth(path, token_str, config, &self.state.manager).await;
40            
41            if result.should_reject() {
42                return Ok(tide::Response::builder(tide::StatusCode::Unauthorized).build());
43            }
44            
45            let ctx = sa_token_core::router::create_context(&result);
46            SaTokenContext::set_current(ctx);
47            let response = next.run(req).await;
48            SaTokenContext::clear();
49            return Ok(response);
50        }
51        
52        // No path auth config, use default token extraction and validation
53        // 没有路径鉴权配置,使用默认的 token 提取和验证
54        let mut ctx = SaTokenContext::new();
55        if let Some(token_str) = extract_token_from_request(&req, &self.state.manager.config.token_name) {
56            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
57            let token = TokenValue::new(token_str);
58            
59            if self.state.manager.is_valid(&token).await {
60                req.set_ext(token.clone());
61                
62                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
63                    let login_id = token_info.login_id.clone();
64                    req.set_ext(login_id.clone());
65                    
66                    ctx.token = Some(token.clone());
67                    ctx.token_info = Some(Arc::new(token_info));
68                    ctx.login_id = Some(login_id);
69                }
70            }
71        }
72        
73        SaTokenContext::set_current(ctx);
74        let result = next.run(req).await;
75        SaTokenContext::clear();
76        Ok(result)
77    }
78}
79
80/// 中文 | English
81/// 从请求中提取 token | Extract token from request
82///
83/// 按以下顺序尝试提取 token: | Try to extract token in the following order:
84/// 1. 从指定名称的请求头 | From specified header name
85/// 2. 从 Authorization 请求头 | From Authorization header
86/// 3. 从 Cookie | From cookie
87/// 4. 从查询参数 | From query parameter
88pub fn extract_token_from_request<State>(req: &Request<State>, token_name: &str) -> Option<String> {
89    if let Some(header_value) = req.header(token_name) {
90        if let Some(value_str) = header_value.get(0) {
91            let value_str = value_str.as_str();
92            if !value_str.is_empty() {
93                if let Some(token) = utils_extract_bearer_token(value_str) {
94                    return Some(token);
95                }
96            }
97        }
98    }
99    
100    // 2. 从 Authorization 请求头提取 | Extract from Authorization header
101    if let Some(auth_header) = req.header("authorization") {
102        if let Some(auth_str) = auth_header.get(0) {
103            let auth_str = auth_str.as_str();
104            if !auth_str.is_empty() {
105                if let Some(token) = utils_extract_bearer_token(auth_str) {
106                    return Some(token);
107                }
108            }
109        }
110    }
111    
112    // 3. 从 Cookie 提取 | Extract from cookie
113    if let Some(cookie_header) = req.header("cookie") {
114        if let Some(cookie_str) = cookie_header.get(0) {
115            let cookies = parse_cookies(cookie_str.as_str());
116            if let Some(token) = cookies.get(token_name) {
117                if !token.is_empty() {
118                    return Some(token.to_string());
119                }
120            }
121        }
122    }
123    
124    // 4. 从查询参数提取 | Extract from query parameter
125    if let Some(query) = req.url().query() {
126        let params = parse_query_string(query);
127        if let Some(token) = params.get(token_name) {
128            if !token.is_empty() {
129                return Some(token.to_string());
130            }
131        }
132    }
133    
134    None
135}