// File: sa_token_plugin_actix_web/middleware.rs

// Author: 金书记
//
//! Actix-web middleware for Sa-Token (Actix-web中间件).
4
5use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10    Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::{token::TokenValue, SaTokenContext, error::messages};
16use std::sync::Arc;
17
// --- Sa-Token base middleware: extracts and validates the request token (基础中间件 - 提取并验证 token) ---
19use sa_token_core::router::PathAuthConfig;
20
21/// Sa-Token middleware with optional path-based authentication
22/// 支持可选路径鉴权的 Sa-Token 中间件
23pub struct SaTokenMiddleware {
24    pub state: SaTokenState,
25    /// Optional path authentication configuration
26    /// 可选的路径鉴权配置
27    pub path_config: Option<PathAuthConfig>,
28}
29
30impl SaTokenMiddleware {
31    /// Create middleware without path authentication
32    /// 创建不带路径鉴权的中间件
33    pub fn new(state: SaTokenState) -> Self {
34        Self { state, path_config: None }
35    }
36    
37    /// Create middleware with path-based authentication
38    /// 创建带路径鉴权的中间件
39    pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
40        Self { state, path_config: Some(config) }
41    }
42}
43
44impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
45where
46    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
47    S::Future: 'static,
48    B: 'static,
49{
50    type Response = ServiceResponse<B>;
51    type Error = Error;
52    type InitError = ();
53    type Transform = SaTokenMiddlewareService<S>;
54    type Future = Ready<Result<Self::Transform, Self::InitError>>;
55
56    fn new_transform(&self, service: S) -> Self::Future {
57        ready(Ok(SaTokenMiddlewareService {
58            service: Rc::new(service),
59            state: self.state.clone(),
60            path_config: self.path_config.clone(),
61        }))
62    }
63}
64
65/// Sa-Token middleware service for Actix-web
66/// Actix-web 的 Sa-Token 中间件服务
67pub struct SaTokenMiddlewareService<S> {
68    service: Rc<S>,
69    state: SaTokenState,
70    /// Optional path authentication configuration
71    /// 可选的路径鉴权配置
72    path_config: Option<PathAuthConfig>,
73}
74
75impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
76where
77    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
78    S::Future: 'static,
79    B: 'static,
80{
81    type Response = ServiceResponse<B>;
82    type Error = Error;
83    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
84
85    forward_ready!(service);
86
87    fn call(&self, req: ServiceRequest) -> Self::Future {
88        let service = Rc::clone(&self.service);
89        let state = self.state.clone();
90        let path_config = self.path_config.clone();
91        
92        Box::pin(async move {
93            if let Some(config) = path_config {
94                let path = req.path();
95                let token_str = extract_token_from_request(&req, &state);
96                let result = sa_token_core::router::process_auth(path, token_str, &config, &state.manager).await;
97                
98                if result.should_reject() {
99                    return Err(ErrorUnauthorized(serde_json::json!({"code": 401, "message": messages::AUTH_ERROR}).to_string()));
100                }
101                
102                if let Some(token) = &result.token {
103                    req.extensions_mut().insert(token.clone());
104                }
105                if let Some(login_id) = result.login_id() {
106                    req.extensions_mut().insert(login_id.to_string());
107                }
108                
109                let ctx = sa_token_core::router::create_context(&result);
110                SaTokenContext::set_current(ctx);
111                let response = service.call(req).await;
112                SaTokenContext::clear();
113                return response;
114            }
115            
116            let mut ctx = SaTokenContext::new();
117            if let Some(token_str) = extract_token_from_request(&req, &state) {
118                let token = TokenValue::new(token_str);
119                if state.manager.is_valid(&token).await {
120                    req.extensions_mut().insert(token.clone());
121                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
122                        let login_id = token_info.login_id.clone();
123                        req.extensions_mut().insert(login_id.clone());
124                        ctx.token = Some(token.clone());
125                        ctx.token_info = Some(Arc::new(token_info));
126                        ctx.login_id = Some(login_id);
127                    }
128                }
129            }
130            
131            SaTokenContext::set_current(ctx);
132            let result = service.call(req).await;
133            SaTokenContext::clear();
134            result
135        })
136    }
137}
138
139/// sa-token 登录检查中间件 - 强制要求登录
140pub struct SaCheckLoginMiddleware {
141    pub state: SaTokenState,
142}
143
144impl SaCheckLoginMiddleware {
145    pub fn new(state: SaTokenState) -> Self {
146        Self { state }
147    }
148}
149
150impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
151where
152    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
153    S::Future: 'static,
154    B: 'static,
155{
156    type Response = ServiceResponse<B>;
157    type Error = Error;
158    type InitError = ();
159    type Transform = SaCheckLoginMiddlewareService<S>;
160    type Future = Ready<Result<Self::Transform, Self::InitError>>;
161
162    fn new_transform(&self, service: S) -> Self::Future {
163        ready(Ok(SaCheckLoginMiddlewareService {
164            service: Rc::new(service),
165            state: self.state.clone(),
166        }))
167    }
168}
169
170pub struct SaCheckLoginMiddlewareService<S> {
171    service: Rc<S>,
172    state: SaTokenState,
173}
174
175impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
176where
177    S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
178    S::Future: 'static,
179    B: 'static,
180{
181    type Response = ServiceResponse<B>;
182    type Error = Error;
183    type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
184
185    forward_ready!(service);
186
187    fn call(&self, req: ServiceRequest) -> Self::Future {
188        let service = Rc::clone(&self.service);
189        let state = self.state.clone();
190
191        Box::pin(async move {
192            let mut ctx = SaTokenContext::new();
193            // 提取 token
194            if let Some(token_str) = extract_token_from_request(&req, &state) {
195                tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
196                let token = TokenValue::new(token_str);
197
198                // 验证 token
199                if state.manager.is_valid(&token).await {
200                    // 存储 token 和 login_id
201                    req.extensions_mut().insert(token.clone());
202
203                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
204                        let login_id = token_info.login_id.clone();
205                        req.extensions_mut().insert(login_id.clone());
206                        ctx.token = Some(token.clone());
207                        ctx.token_info = Some(Arc::new(token_info));
208                        ctx.login_id = Some(login_id);
209
210                        // 设置上下文
211                        SaTokenContext::set_current(ctx);
212                        let result = service.call(req).await;
213                        SaTokenContext::clear();
214                        return result;
215                    }
216                }
217            }
218
219            // 未登录,返回 401
220            Err(ErrorUnauthorized(serde_json::json!({
221                "code": 401,
222                "message": messages::AUTH_ERROR
223            }).to_string()))
224        })
225    }
226}
227
228/// 从请求中提取 token
229pub fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
230    let adapter = ActixRequestAdapter::new(req.request());
231    let token_name = &state.manager.config.token_name;
232    
233    tracing::debug!("Sa-Token: 尝试从请求提取 token,token_name: {}", token_name);
234    
235    // 1. 优先从 Header 中获取(检查 token_name 配置的头)
236    if let Some(token) = adapter.get_header(token_name) {
237        tracing::debug!("Sa-Token: 从 Header[{}] 获取到 token", token_name);
238        return Some(extract_bearer_token(&token));
239    }
240    
241    // 2. 如果 token_name 不是 "Authorization",也尝试从 "Authorization" 头获取
242    if token_name != "Authorization" {
243        if let Some(token) = adapter.get_header("Authorization") {
244            tracing::debug!("Sa-Token: 从 Header[Authorization] 获取到 token");
245            return Some(extract_bearer_token(&token));
246        }
247    }
248    
249    // 3. 从 Cookie 中获取
250    if let Some(token) = adapter.get_cookie(token_name) {
251        tracing::debug!("Sa-Token: 从 Cookie[{}] 获取到 token", token_name);
252        return Some(token);
253    }
254    
255    // 4. 从 Query 参数中获取
256    if let Some(query) = req.query_string().split('&').find_map(|pair| {
257        let mut parts = pair.split('=');
258        if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
259            if key == token_name {
260                return urlencoding::decode(value).ok().map(|s| s.to_string());
261            }
262        }
263        None
264    }) {
265        tracing::debug!("Sa-Token: 从 Query[{}] 获取到 token", token_name);
266        return Some(query);
267    }
268    
269    tracing::debug!("Sa-Token: 所有位置都未找到 token");
270    None
271}
272
/// Strips an HTTP `Bearer` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively, because HTTP auth-schemes are
/// case-insensitive (RFC 7235 §2.1). Any extra whitespace between the scheme and
/// the credential is skipped. Values that do not start with the scheme are
/// returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";
    // `get` returns None instead of panicking if the cut is not on a char boundary.
    match token.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim_start().to_string()
        }
        _ => token.to_string(),
    }
}