sa_token_plugin_ntex/
layer.rs

1use ntex::service::{Service, ServiceCtx, Middleware};
2use ntex::web::{Error, ErrorRenderer, WebRequest, WebResponse};
3use crate::state::SaTokenState;
4use sa_token_core::{token::TokenValue, SaTokenContext};
5use std::sync::Arc;
6
7#[derive(Clone)]
8pub struct SaTokenLayer {
9    state: SaTokenState,
10}
11
12impl SaTokenLayer {
13    pub fn new(state: SaTokenState) -> Self {
14        Self { state }
15    }
16}
17
18impl<S> Middleware<S> for SaTokenLayer {
19    type Service = SaTokenMiddleware<S>;
20
21    fn create(&self, service: S) -> Self::Service {
22        SaTokenMiddleware {
23            service,
24            state: self.state.clone(),
25        }
26    }
27}
28
29pub struct SaTokenMiddleware<S> {
30    service: S,
31    state: SaTokenState,
32}
33
34impl<S, Err> Service<WebRequest<Err>> for SaTokenMiddleware<S>
35where
36    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
37    Err: ErrorRenderer,
38{
39    type Response = WebResponse;
40    type Error = Error;
41
42    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
43        let mut sa_ctx = SaTokenContext::new();
44        
45        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
46            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
47            let token = TokenValue::new(token_str);
48            
49            if self.state.manager.is_valid(&token).await {
50                req.extensions_mut().insert(token.clone());
51                
52                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
53                    let login_id = token_info.login_id.clone();
54                    req.extensions_mut().insert(login_id.clone());
55                    
56                    sa_ctx.token = Some(token.clone());
57                    sa_ctx.token_info = Some(Arc::new(token_info));
58                    sa_ctx.login_id = Some(login_id);
59                }
60            }
61        }
62        
63        SaTokenContext::set_current(sa_ctx);
64        let result = ctx.call(&self.service, req).await;
65        SaTokenContext::clear();
66        result
67    }
68}
69
70fn extract_token_from_request<Err>(req: &WebRequest<Err>, state: &SaTokenState) -> Option<String> 
71where
72    Err: ErrorRenderer,
73{
74    let token_name = &state.manager.config.token_name;
75    let headers = req.headers();
76    
77    // 1. 从 token_name 指定的 header 获取
78    if let Some(header_value) = headers.get(token_name) {
79        if let Ok(value_str) = header_value.to_str() {
80            return Some(extract_bearer_token(value_str));
81        }
82    }
83    
84    // 2. 从标准 Authorization 头获取
85    if let Some(auth_header) = headers.get("Authorization") {
86        if let Ok(auth_str) = auth_header.to_str() {
87            return Some(extract_bearer_token(auth_str));
88        }
89    }
90    
91    // 3. 从 Cookie 获取
92    if let Some(cookie_header) = headers.get("cookie") {
93        if let Ok(cookie_str) = cookie_header.to_str() {
94            if let Some(token) = parse_cookie(cookie_str, token_name) {
95                return Some(token);
96            }
97        }
98    }
99    
100    // 4. 从查询参数获取
101    if let Some(query) = req.uri().query() {
102        if let Some(token) = parse_query_param(query, token_name) {
103            return Some(token);
104        }
105    }
106    
107    None
108}
109
/// Strips an optional `Bearer` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively, as HTTP auth-schemes are
/// case-insensitive. It must be followed by at least one whitespace character.
/// Leading and trailing whitespace is ignored. Values without the scheme are
/// returned trimmed, so raw tokens sent without a scheme still work.
///
/// `"Bearer abc"`, `"bearer  abc"` and `"  abc "` all yield `"abc"`.
/// `"Bearer "` yields `""`.
fn extract_bearer_token(header_value: &str) -> String {
    const SCHEME: &str = "Bearer";

    let value = header_value.trim_start();
    // `get` returns None instead of panicking when the value is shorter than the
    // scheme or when byte 6 is not a char boundary. After it succeeds, slicing at
    // the same index is safe.
    let credentials = value
        .get(..SCHEME.len())
        .filter(|scheme| scheme.eq_ignore_ascii_case(SCHEME))
        .map(|_| &value[SCHEME.len()..])
        .filter(|rest| rest.starts_with(|c: char| c.is_ascii_whitespace()));

    credentials.unwrap_or(value).trim().to_string()
}
117
/// Finds the value of the cookie named `token_name` in a `Cookie` header value.
///
/// The header is split on `;` into `name=value` pairs. Only the first `=` in
/// each pair separates name from value, so values may contain `=`. Pairs
/// without `=` are ignored.
///
/// Names are compared exactly (case-sensitive) after trimming whitespace.
/// Values are trimmed, and one pair of surrounding double quotes is removed,
/// since RFC 6265 allows DQUOTE-wrapped cookie values. Empty values are
/// skipped, so a later cookie with the same name can still be used.
///
/// Returns `None` if no matching cookie has a non-empty value.
fn parse_cookie(cookie_str: &str, token_name: &str) -> Option<String> {
    cookie_str
        .split(';')
        .filter_map(|pair| pair.split_once('='))
        .filter(|(name, _)| name.trim() == token_name)
        .map(|(_, value)| {
            let value = value.trim();
            value
                .strip_prefix('"')
                .and_then(|inner| inner.strip_suffix('"'))
                .unwrap_or(value)
        })
        .find(|value| !value.is_empty())
        .map(str::to_string)
}
130
131fn parse_query_param(query: &str, param_name: &str) -> Option<String> {
132    for pair in query.split('&') {
133        let parts: Vec<&str> = pair.splitn(2, '=').collect();
134        if parts.len() == 2 && parts[0] == param_name {
135            return urlencoding::decode(parts[1])
136                .ok()
137                .map(|s| s.into_owned());
138        }
139    }
140    None
141}