sa_token_plugin_salvo/
layer.rs

1use salvo::{Depot, Request, Response, Handler, FlowCtrl};
2use salvo::http::StatusCode;
3use sa_token_core::{token::TokenValue, SaTokenContext, router::PathAuthConfig};
4use crate::state::SaTokenState;
5use std::sync::Arc;
6use sa_token_adapter::utils::{parse_cookies, parse_query_string, extract_bearer_token as utils_extract_bearer_token};
7
8/// Sa-Token layer for Salvo with optional path-based authentication
9/// 支持可选路径鉴权的 Salvo Sa-Token 层
10#[derive(Clone)]
11pub struct SaTokenLayer {
12    state: SaTokenState,
13    /// Optional path authentication configuration
14    /// 可选的路径鉴权配置
15    path_config: Option<PathAuthConfig>,
16}
17
18impl SaTokenLayer {
19    /// Create layer without path authentication
20    /// 创建不带路径鉴权的层
21    pub fn new(state: SaTokenState) -> Self {
22        Self { state, path_config: None }
23    }
24    
25    /// Create layer with path-based authentication
26    /// 创建带路径鉴权的层
27    pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
28        Self { state, path_config: Some(config) }
29    }
30}
31
32#[salvo::async_trait]
33impl Handler for SaTokenLayer {
34    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
35        if let Some(config) = &self.path_config {
36            let path = req.uri().path();
37            let token_str = extract_token_from_request(req, &self.state.manager.config.token_name);
38            let result = sa_token_core::router::process_auth(path, token_str, config, &self.state.manager).await;
39            
40            if result.should_reject() {
41                res.status_code(StatusCode::UNAUTHORIZED);
42                return;
43            }
44            
45            let ctx = sa_token_core::router::create_context(&result);
46            SaTokenContext::set_current(ctx);
47            ctrl.call_next(req, depot, res).await;
48            SaTokenContext::clear();
49            return;
50        }
51        
52        // No path auth config, use default token extraction and validation
53        // 没有路径鉴权配置,使用默认的 token 提取和验证
54        let mut ctx = SaTokenContext::new();
55        if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
56            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
57            let token = TokenValue::new(token_str);
58            
59            if self.state.manager.is_valid(&token).await {
60                depot.insert("sa_token", token.clone());
61                
62                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
63                    let login_id = token_info.login_id.clone();
64                    depot.insert("sa_login_id", login_id.clone());
65                    
66                    ctx.token = Some(token.clone());
67                    ctx.token_info = Some(Arc::new(token_info));
68                    ctx.login_id = Some(login_id);
69                }
70            }
71        }
72        
73        SaTokenContext::set_current(ctx);
74        ctrl.call_next(req, depot, res).await;
75        SaTokenContext::clear();
76    }
77}
78
79/// 中文 | English
80/// 从请求中提取 token | Extract token from request
81///
82/// 按以下顺序尝试提取 token: | Try to extract token in the following order:
83/// 1. 从指定名称的请求头 | From specified header name
84/// 2. 从 Authorization 请求头 | From Authorization header
85/// 3. 从 Cookie | From cookie
86/// 4. 从查询参数 | From query parameter
87pub fn extract_token_from_request(req: &Request, token_name: &str) -> Option<String> {
88    
89    // 1. 从指定名称的请求头提取 | Extract from specified header name
90    if let Some(header_value) = req.headers().get(token_name) {
91        if let Ok(value_str) = header_value.to_str() {
92            if !value_str.is_empty() {
93                if let Some(token) = utils_extract_bearer_token(value_str) {
94                    return Some(token);
95                }
96            }
97        }
98    }
99    
100    // 2. 从 Authorization 请求头提取 | Extract from Authorization header
101    if let Some(auth_header) = req.headers().get("authorization") {
102        if let Ok(auth_str) = auth_header.to_str() {
103            if !auth_str.is_empty() {
104                if let Some(token) = utils_extract_bearer_token(auth_str) {
105                    return Some(token);
106                }
107            }
108        }
109    }
110    
111    // 3. 从 Cookie 提取 | Extract from cookie
112    if let Some(cookie_header) = req.headers().get("cookie") {
113        if let Ok(cookie_str) = cookie_header.to_str() {
114            let cookies = parse_cookies(cookie_str);
115            if let Some(token) = cookies.get(token_name) {
116                if !token.is_empty() {
117                    return Some(token.to_string());
118                }
119            }
120        }
121    }
122    
123    // 4. 从查询参数提取 | Extract from query parameter
124    if let Some(query) = req.uri().query() {
125        let params = parse_query_string(query);
126        if let Some(token) = params.get(token_name) {
127            if !token.is_empty() {
128                return Some(token.to_string());
129            }
130        }
131    }
132    
133    None
134}