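// sa_token_plugin_ntex/middleware.rs
//
// Author: 金书记
//
//! 中间件实现 | Middleware implementations
//!
//! 提供多种中间件 | Provided middleware:
//! - `SaTokenMiddleware`:基础 token 提取和验证中间件 | extracts and validates the request token
//! - `SaCheckLoginMiddleware`:检查登录中间件,未登录时返回401错误 | requires login; returns 401 when not logged in
//! - `SaCheckPermissionMiddleware`:检查权限中间件,无权限时返回403错误 | requires a permission; returns 403 when the permission is missing or the caller is not logged in
//! - `SaCheckRoleMiddleware`:检查角色中间件,无角色时返回403错误 | requires a role; returns 403 when the role is missing or the caller is not logged in
//! - `AuthMiddleware`、`PermissionMiddleware`:已废弃,建议使用上述中间件 | deprecated; use the middleware above instead
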
12use ntex::service::{Middleware, Service, ServiceCtx};
13use ntex::web::{Error, ErrorRenderer, WebRequest, WebResponse};
14use std::sync::Arc;
15use serde_json::json;
16use sa_token_core::{
17    error::messages, 
18    token::TokenValue, 
19    SaTokenContext,
20    StpUtil
21};
22use sa_token_adapter::utils::{parse_cookies, parse_query_string, extract_bearer_token};
23use crate::SaTokenState;
24use ntex::web::error::InternalError;
25use ntex::web::Error as WebError;
26
27
28/// sa-token 基础中间件 - 提取并验证 token
29/// 
30/// 此中间件会从请求中提取 token,验证其有效性,并将相关信息存储到请求扩展中
31pub struct SaTokenMiddleware {
32    pub state: SaTokenState,
33}
34
35impl SaTokenMiddleware {
36    pub fn new(state: SaTokenState) -> Self {
37        Self { state }
38    }
39}
40
41impl<S> Middleware<S> for SaTokenMiddleware {
42    type Service = SaTokenMiddlewareService<S>;
43
44    fn create(&self, service: S) -> Self::Service {
45        SaTokenMiddlewareService {
46            service,
47            state: self.state.clone(),
48        }
49    }
50}
51
52pub struct SaTokenMiddlewareService<S> {
53    service: S,
54    state: SaTokenState,
55}
56
57impl<S, Err> Service<WebRequest<Err>> for SaTokenMiddlewareService<S>
58where
59    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
60    Err: ErrorRenderer,
61{
62    type Response = WebResponse;
63    type Error = Error;
64
65    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
66        let mut sa_ctx = SaTokenContext::new();
67        
68        // 提取 token
69        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
70            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
71            let token = TokenValue::new(token_str);
72            
73            // 验证 token
74            if self.state.manager.is_valid(&token).await {
75                // 存储 token 到请求扩展
76                req.extensions_mut().insert(token.clone());
77                
78                // 获取并存储 login_id
79                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
80                    let login_id = token_info.login_id.clone();
81                    req.extensions_mut().insert(login_id.clone());
82                    
83                    // 设置上下文
84                    sa_ctx.token = Some(token.clone());
85                    sa_ctx.token_info = Some(Arc::new(token_info));
86                    sa_ctx.login_id = Some(login_id);
87                }
88            }
89        }
90        
91        // 设置当前上下文
92        SaTokenContext::set_current(sa_ctx);
93        
94        // 继续处理请求
95        let result = ctx.call(&self.service, req).await;
96        
97        // 清除上下文
98        SaTokenContext::clear();
99        
100        result
101    }
102}
103
104/// 中文 | English
105/// 认证中间件 - 验证用户登录状态 | Authentication middleware - verify user login status
106/// 
107/// 注意:此中间件已废弃,建议使用 SaTokenMiddleware + SaCheckLoginMiddleware
108/// 
109/// # 示例 | Example
110/// ```rust,ignore
111/// use ntex::web;
112/// use sa_token_plugin_ntex::AuthMiddleware;
113///
114/// let app = web::App::new()
115///     .wrap(AuthMiddleware)
116///     .route("/user", web::get().to(user_handler));
117/// ```
118#[deprecated(note = "Use SaTokenMiddleware + SaCheckLoginMiddleware instead")]
119pub struct AuthMiddleware;
120
121impl<S> Middleware<S> for AuthMiddleware {
122    type Service = AuthMiddlewareService<S>;
123
124    fn create(&self, service: S) -> Self::Service {
125        AuthMiddlewareService { service }
126    }
127}
128
129pub struct AuthMiddlewareService<S> {
130    service: S,
131}
132
133impl<S, Err> Service<WebRequest<Err>> for AuthMiddlewareService<S>
134where
135    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
136    Err: ErrorRenderer,
137{
138    type Response = WebResponse;
139    type Error = Error;
140
141    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
142        // 中文 | English
143        // 从请求头中获取 token | Get token from request headers
144        let token = req
145            .headers()
146            .get("Authorization")
147            .and_then(|v| v.to_str().ok())
148            .and_then(|s| s.strip_prefix("Bearer "))
149            .map(|s| s.to_string());
150        
151        if let Some(token_str) = token {
152            // 中文 | English
153            // 验证 token 是否有效 | Verify if token is valid
154            use sa_token_core::TokenValue;
155            let token_value = TokenValue::from(token_str.clone());
156            if StpUtil::is_login(&token_value).await {
157                // 中文 | English
158                // Token 有效,继续处理请求 | Token valid, continue processing
159                if let Ok(login_id) = StpUtil::get_login_id(&token_value).await {
160                    req.extensions_mut().insert(login_id);
161                    return ctx.call(&self.service, req).await;
162                }
163            }
164        }
165        
166        // 中文 | English
167        // Token 无效,返回 401 | Token invalid, return 401
168        Err(WebError::from(InternalError::new(
169            "Unauthorized",
170            ntex::http::StatusCode::UNAUTHORIZED,
171        )))
172    }
173}
174
175/// sa-token 登录检查中间件 - 强制要求登录
176/// 
177/// 此中间件会检查用户是否已登录,如果未登录则返回401错误
178pub struct SaCheckLoginMiddleware {
179    pub state: SaTokenState,
180}
181
182impl SaCheckLoginMiddleware {
183    pub fn new(state: SaTokenState) -> Self {
184        Self { state }
185    }
186}
187
188impl<S> Middleware<S> for SaCheckLoginMiddleware {
189    type Service = SaCheckLoginMiddlewareService<S>;
190
191    fn create(&self, service: S) -> Self::Service {
192        SaCheckLoginMiddlewareService {
193            service,
194            state: self.state.clone(),
195        }
196    }
197}
198
199pub struct SaCheckLoginMiddlewareService<S> {
200    service: S,
201    state: SaTokenState,
202}
203
204impl<S, Err> Service<WebRequest<Err>> for SaCheckLoginMiddlewareService<S>
205where
206    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
207    Err: ErrorRenderer,
208{
209    type Response = WebResponse;
210    type Error = Error;
211
212    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
213        let mut sa_ctx = SaTokenContext::new();
214        
215        // 提取 token
216        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
217            tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
218            let token = TokenValue::new(token_str);
219            
220            // 验证 token
221            if self.state.manager.is_valid(&token).await {
222                // 存储 token 和 login_id
223                req.extensions_mut().insert(token.clone());
224                
225                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
226                    let login_id = token_info.login_id.clone();
227                    req.extensions_mut().insert(login_id.clone());
228                    
229                    // 设置上下文
230                    sa_ctx.token = Some(token.clone());
231                    sa_ctx.token_info = Some(Arc::new(token_info));
232                    sa_ctx.login_id = Some(login_id);
233                    
234                    SaTokenContext::set_current(sa_ctx);
235                    let result = ctx.call(&self.service, req).await;
236                    SaTokenContext::clear();
237                    return result;
238                }
239            }
240        }
241        
242        // 未登录,返回401错误
243        Err(WebError::from(InternalError::new(
244            json!({
245                "code": 401,
246                "message": messages::AUTH_ERROR
247            }).to_string(),
248            ntex::http::StatusCode::UNAUTHORIZED,
249        )))
250    }
251}
252
253/// sa-token 权限检查中间件 - 强制要求特定权限
254/// 
255/// 此中间件会检查用户是否拥有指定权限,如果没有则返回403错误
256pub struct SaCheckPermissionMiddleware {
257    pub state: SaTokenState,
258    permission: String,
259}
260
261impl SaCheckPermissionMiddleware {
262    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
263        Self {
264            state,
265            permission: permission.into(),
266        }
267    }
268}
269
270impl<S> Middleware<S> for SaCheckPermissionMiddleware {
271    type Service = SaCheckPermissionMiddlewareService<S>;
272
273    fn create(&self, service: S) -> Self::Service {
274        SaCheckPermissionMiddlewareService {
275            service,
276            state: self.state.clone(),
277            permission: self.permission.clone(),
278        }
279    }
280}
281
282pub struct SaCheckPermissionMiddlewareService<S> {
283    service: S,
284    state: SaTokenState,
285    permission: String,
286}
287
288impl<S, Err> Service<WebRequest<Err>> for SaCheckPermissionMiddlewareService<S>
289where
290    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
291    Err: ErrorRenderer,
292{
293    type Response = WebResponse;
294    type Error = Error;
295
296    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
297        let mut sa_ctx = SaTokenContext::new();
298        
299        // 提取 token
300        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
301            tracing::debug!("Sa-Token(permission-check): extracted token from request: {}", token_str);
302            let token = TokenValue::new(token_str);
303            
304            // 验证 token
305            if self.state.manager.is_valid(&token).await {
306                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
307                    let login_id = token_info.login_id.clone();
308                    
309                    // 检查权限
310                    if StpUtil::has_permission(&login_id, &self.permission).await {
311                        // 存储信息到请求扩展
312                        req.extensions_mut().insert(token.clone());
313                        req.extensions_mut().insert(login_id.clone());
314                        
315                        // 设置上下文
316                        sa_ctx.token = Some(token.clone());
317                        sa_ctx.token_info = Some(Arc::new(token_info));
318                        sa_ctx.login_id = Some(login_id);
319                        
320                        SaTokenContext::set_current(sa_ctx);
321                        let result = ctx.call(&self.service, req).await;
322                        SaTokenContext::clear();
323                        return result;
324                    }
325                }
326            }
327        }
328        
329        // 无权限或未登录,返回403错误
330        Err(WebError::from(InternalError::new(
331            json!({
332                "code": 403,
333                "message": messages::PERMISSION_REQUIRED
334            }).to_string(),
335            ntex::http::StatusCode::FORBIDDEN,
336        )))
337    }
338}
339
340/// sa-token 角色检查中间件 - 强制要求特定角色
341/// 
342/// 此中间件会检查用户是否拥有指定角色,如果没有则返回403错误
343pub struct SaCheckRoleMiddleware {
344    pub state: SaTokenState,
345    role: String,
346}
347
348impl SaCheckRoleMiddleware {
349    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
350        Self {
351            state,
352            role: role.into(),
353        }
354    }
355}
356
357impl<S> Middleware<S> for SaCheckRoleMiddleware {
358    type Service = SaCheckRoleMiddlewareService<S>;
359
360    fn create(&self, service: S) -> Self::Service {
361        SaCheckRoleMiddlewareService {
362            service,
363            state: self.state.clone(),
364            role: self.role.clone(),
365        }
366    }
367}
368
369pub struct SaCheckRoleMiddlewareService<S> {
370    service: S,
371    state: SaTokenState,
372    role: String,
373}
374
375impl<S, Err> Service<WebRequest<Err>> for SaCheckRoleMiddlewareService<S>
376where
377    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
378    Err: ErrorRenderer,
379{
380    type Response = WebResponse;
381    type Error = Error;
382
383    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
384        let mut sa_ctx = SaTokenContext::new();
385        
386        // 提取 token
387        if let Some(token_str) = extract_token_from_request(&req, &self.state) {
388            tracing::debug!("Sa-Token(role-check): extracted token from request: {}", token_str);
389            let token = TokenValue::new(token_str);
390            
391            // 验证 token
392            if self.state.manager.is_valid(&token).await {
393                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
394                    let login_id = token_info.login_id.clone();
395                    
396                    // 检查角色
397                    if StpUtil::has_role(&login_id, &self.role).await {
398                        // 存储信息到请求扩展
399                        req.extensions_mut().insert(token.clone());
400                        req.extensions_mut().insert(login_id.clone());
401                        
402                        // 设置上下文
403                        sa_ctx.token = Some(token.clone());
404                        sa_ctx.token_info = Some(Arc::new(token_info));
405                        sa_ctx.login_id = Some(login_id);
406                        
407                        SaTokenContext::set_current(sa_ctx);
408                        let result = ctx.call(&self.service, req).await;
409                        SaTokenContext::clear();
410                        return result;
411                    }
412                }
413            }
414        }
415        
416        // 无角色或未登录,返回403错误
417
418        Err(WebError::from(InternalError::new(
419            json!({
420                "code": 403,
421                "message": messages::ROLE_REQUIRED
422            }).to_string(),
423            ntex::http::StatusCode::FORBIDDEN,
424        )))
425    }
426}
427
428/// 中文 | English
429/// 权限验证中间件 - 验证用户是否拥有指定权限 | Permission middleware - verify if user has specified permissions
430/// 
431/// 注意:此中间件已废弃,建议使用 SaCheckPermissionMiddleware
432#[deprecated(note = "Use SaCheckPermissionMiddleware instead")]
433pub struct PermissionMiddleware {
434    permission: String,
435}
436
437impl PermissionMiddleware {
438    /// 中文 | English
439    /// 创建权限验证中间件 | Create permission middleware
440    pub fn new(permission: impl Into<String>) -> Self {
441        Self {
442            permission: permission.into(),
443        }
444    }
445}
446
447impl<S> Middleware<S> for PermissionMiddleware {
448    type Service = PermissionMiddlewareService<S>;
449
450    fn create(&self, service: S) -> Self::Service {
451        PermissionMiddlewareService {
452            service,
453            permission: self.permission.clone(),
454        }
455    }
456}
457
458pub struct PermissionMiddlewareService<S> {
459    service: S,
460    permission: String,
461}
462
463impl<S, Err> Service<WebRequest<Err>> for PermissionMiddlewareService<S>
464where
465    S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
466    Err: ErrorRenderer,
467{
468    type Response = WebResponse;
469    type Error = Error;
470
471    async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
472        // 中文 | English
473        // 注意:此方法已废弃,建议使用 SaCheckPermissionMiddleware
474        // Note: This method is deprecated, use SaCheckPermissionMiddleware instead
475        
476        // 首先尝试从扩展数据获取 login_id(可能由其他中间件设置)
477        // First try to get login_id from extensions (may be set by other middleware)
478        let has_login_id = req.extensions().get::<String>().is_some();
479        
480        if has_login_id {
481            let login_id = req.extensions().get::<String>().unwrap().clone();
482            // 验证权限 | Verify permission
483            if StpUtil::has_permission(&login_id, &self.permission).await {
484                return ctx.call(&self.service, req).await;
485            }
486        } else {
487            // 如果扩展中没有 login_id,尝试从请求中提取 token 并验证
488            // If no login_id in extensions, try to extract token from request and verify
489            if let Some(token_str) = extract_token_from_request_simple(&req) {
490                let token = TokenValue::new(token_str);
491                
492                // 简单验证 token 是否有效
493                // Simple token validation
494                if StpUtil::is_login(&token).await {
495                    if let Ok(login_id) = StpUtil::get_login_id(&token).await {
496                        // 验证权限 | Verify permission
497                        if StpUtil::has_permission(&login_id, &self.permission).await {
498                            // 将 login_id 存储到扩展中供后续使用
499                            // Store login_id in extensions for later use
500                            req.extensions_mut().insert(login_id);
501                            return ctx.call(&self.service, req).await;
502                        }
503                    }
504                }
505            }
506        }
507        
508        // 无权限或未登录,返回 403 | No permission or not logged in, return 403
509        Err(WebError::from(InternalError::new(
510            json!({
511                "code": 403,
512                "message": messages::PERMISSION_REQUIRED
513            }).to_string(),
514            ntex::http::StatusCode::FORBIDDEN,
515        )))
516    }
517}
518
519/// 从请求中提取 token
520/// 
521/// 参考 Actix-web 实现,支持从 Header、Cookie、Query 参数中提取
522fn extract_token_from_request<Err>(req: &WebRequest<Err>, state: &SaTokenState) -> Option<String>
523where
524    Err: ErrorRenderer,
525{
526    let token_name = &state.manager.config.token_name;
527    
528    // 1. 优先从 Header 中获取
529    if let Some(header_value) = req.headers().get(token_name) {
530        if let Ok(value_str) = header_value.to_str() {
531            if let Some(token) = extract_bearer_token(value_str) {
532                return Some(token);
533            }
534        }
535    }
536    
537    // 检查 Authorization header
538    if let Some(auth_header) = req.headers().get("authorization") {
539        if let Ok(auth_str) = auth_header.to_str() {
540            if let Some(token) = extract_bearer_token(auth_str) {
541                return Some(token);
542            }
543        }
544    }
545    
546    // 2. 从 Cookie 中获取
547    if let Some(cookie_header) = req.headers().get("cookie") {
548        if let Ok(cookie_str) = cookie_header.to_str() {
549            let cookies = parse_cookies(cookie_str);
550            if let Some(token) = cookies.get(token_name) {
551                return Some(token.clone());
552            }
553        }
554    }
555    
556    // 3. 从 Query 参数中获取
557    let query = req.query_string();
558    if !query.is_empty() {
559        let params = parse_query_string(query);
560        if let Some(token) = params.get(token_name) {
561            return Some(token.clone());
562        }
563    }
564    
565    None
566}
567
568/// 简化的 token 提取函数(用于废弃的中间件)
569/// 
570/// 仅从 Authorization header 中提取 Bearer token
571fn extract_token_from_request_simple<Err>(req: &WebRequest<Err>) -> Option<String>
572where
573    Err: ErrorRenderer,
574{
575    // 只从 Authorization header 中获取 Bearer token
576    if let Some(auth_header) = req.headers().get("authorization") {
577        if let Ok(auth_str) = auth_header.to_str() {
578            if let Some(token) = extract_bearer_token(auth_str) {
579                return Some(token);
580            }
581        }
582    }
583    
584    None
585}