// sa_token_plugin_salvo/middleware.rs
//
// Author: 金书记
//
// Salvo 认证中间件 | Salvo authentication middleware

6use salvo::prelude::*;
7use sa_token_core::{StpUtil, error::messages, SaTokenContext, token::TokenValue};
8use serde_json::json;
9use crate::state::SaTokenState;
10use std::sync::Arc;
11use crate::layer::extract_token_from_request;
12
13/// 中文 | English
14/// 认证中间件 - 验证用户登录状态 | Authentication middleware - verify user login status
15///
16/// # 示例 | Example
17/// ```rust,ignore
18/// use salvo::prelude::*;
19/// use sa_token_plugin_salvo::auth_middleware;
20///
21/// let router = Router::new()
22///     .hoop(auth_middleware())
23///     .push(Router::with_path("user").get(user_handler));
24/// ```
25pub fn auth_middleware() -> impl Handler {
26    auth_middleware_handler
27}
28
29#[handler]
30async fn auth_middleware_handler(req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
31    // 中文 | English
32    // 从请求头中获取 token | Get token from request headers
33    let token = req
34        .headers()
35        .get("Authorization")
36        .and_then(|v| v.to_str().ok())
37        .and_then(|s| s.strip_prefix("Bearer "))
38        .map(|s| s.to_string());
39    
40    if let Some(token_str) = token {
41        // 中文 | English
42        // 验证 token 是否有效 | Verify if token is valid
43        use sa_token_core::TokenValue;
44        let token_value = TokenValue::from(token_str.clone());
45        if StpUtil::is_login(&token_value).await {
46            // 中文 | English
47            // Token 有效,将 login_id 存入 depot | Token valid, store login_id in depot
48            if let Ok(login_id) = StpUtil::get_login_id(&token_value).await {
49                depot.insert("login_id", login_id);
50                ctrl.call_next(req, depot, res).await;
51                return;
52            }
53        }
54    }
55    
56    // 中文 | English
57    // Token 无效,返回 401 | Token invalid, return 401
58    res.status_code(StatusCode::UNAUTHORIZED);
59    res.render(Text::Json(r#"{"error":"Unauthorized"}"#));
60    ctrl.skip_rest();
61}
62
63/// 中文 | English
64/// 权限验证中间件 - 验证用户是否拥有指定权限 | Permission middleware - verify if user has specified permissions
65///
66/// # 参数 | Parameters
67/// - `permission`: 需要的权限 | Required permission
68///
69/// # 示例 | Example
70/// ```rust,ignore
71/// let router = Router::new()
72///     .hoop(permission_middleware("user:read"))
73///     .push(Router::with_path("user").get(user_handler));
74/// ```
75pub fn permission_middleware(permission: &'static str) -> impl Handler {
76    PermissionMiddleware { permission }
77}
78
79struct PermissionMiddleware {
80    permission: &'static str,
81}
82
83#[handler]
84impl PermissionMiddleware {
85    async fn handle(&self, req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
86        // 中文 | English
87        // 从 depot 获取 login_id | Get login_id from depot
88        if let Ok(login_id) = depot.get::<String>("login_id") {
89            // 中文 | English
90            // 验证权限 | Verify permission
91            if StpUtil::has_permission(login_id, self.permission).await {
92                ctrl.call_next(req, depot, res).await;
93                return;
94            }
95        }
96        
97        // 中文 | English
98        // 无权限,返回 403 | No permission, return 403
99        res.status_code(StatusCode::FORBIDDEN);
100        res.render(Text::Json(r#"{"error":"Forbidden"}"#));
101        ctrl.skip_rest();
102    }
103}
104
105/// 中文 | English
106/// Sa-Token 登录检查中间件 | Sa-Token login check middleware
107///
108/// 使用标准错误消息,检查当前请求是否已登录 | Uses standard error messages, checks if current request is logged in
109#[derive(Clone)]
110pub struct SaCheckLoginMiddleware {
111    pub state: SaTokenState,
112}
113
114impl SaCheckLoginMiddleware {
115    /// 中文 | English
116    /// 创建新的登录检查中间件 | Create new login check middleware
117    pub fn new(state: SaTokenState) -> Self {
118        Self { state }
119    }
120}
121
122#[salvo::async_trait]
123impl Handler for SaCheckLoginMiddleware {
124    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
125        let mut ctx = SaTokenContext::new();
126        
127        if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
128            tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
129            let token = TokenValue::new(token_str);
130            
131            if self.state.manager.is_valid(&token).await {
132                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
133                    let login_id = token_info.login_id.clone();
134                    depot.insert("sa_token", token.clone());
135                    depot.insert("sa_login_id", login_id.clone());
136                    
137                    ctx.token = Some(token.clone());
138                    ctx.token_info = Some(Arc::new(token_info));
139                    ctx.login_id = Some(login_id);
140                    
141                    SaTokenContext::set_current(ctx);
142                    ctrl.call_next(req, depot, res).await;
143                    SaTokenContext::clear();
144                    return;
145                }
146            }
147        }
148        
149        // 未登录,返回401错误
150        res.status_code(StatusCode::UNAUTHORIZED);
151        res.render(Text::Json(json!({
152            "code": 401,
153            "message": messages::AUTH_ERROR
154        }).to_string()));
155        ctrl.skip_rest();
156    }
157}
158
159/// 中文 | English
160/// Sa-Token 权限检查中间件 | Sa-Token permission check middleware
161///
162/// 检查当前请求用户是否拥有指定权限 | Checks if current request user has specified permission
163#[derive(Clone)]
164pub struct SaCheckPermissionMiddleware {
165    pub state: SaTokenState,
166    permission: String,
167}
168
169impl SaCheckPermissionMiddleware {
170    /// 中文 | English
171    /// 创建新的权限检查中间件 | Create new permission check middleware
172    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
173        Self { state, permission: permission.into() }
174    }
175}
176
177#[salvo::async_trait]
178impl Handler for SaCheckPermissionMiddleware {
179    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
180        let mut ctx = SaTokenContext::new();
181        
182        if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
183            tracing::debug!("Sa-Token(permission-check): extracted token from request: {}", token_str);
184            let token = TokenValue::new(token_str);
185            
186            if self.state.manager.is_valid(&token).await {
187                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
188                    let login_id = token_info.login_id.clone();
189                    
190                    // 检查权限
191                    if StpUtil::has_permission(&login_id, &self.permission).await {
192                        depot.insert("sa_token", token.clone());
193                        depot.insert("sa_login_id", login_id.clone());
194                        
195                        ctx.token = Some(token.clone());
196                        ctx.token_info = Some(Arc::new(token_info));
197                        ctx.login_id = Some(login_id);
198                        
199                        SaTokenContext::set_current(ctx);
200                        ctrl.call_next(req, depot, res).await;
201                        SaTokenContext::clear();
202                        return;
203                    }
204                }
205            }
206        }
207        
208        // 无权限,返回403错误
209        res.status_code(StatusCode::FORBIDDEN);
210        res.render(Text::Json(json!({
211            "code": 403,
212            "message": messages::PERMISSION_REQUIRED
213        }).to_string()));
214        ctrl.skip_rest();
215    }
216}
217
218/// 中文 | English
219/// Sa-Token 角色检查中间件 | Sa-Token role check middleware
220///
221/// 检查当前请求用户是否拥有指定角色 | Checks if current request user has specified role
222#[derive(Clone)]
223pub struct SaCheckRoleMiddleware {
224    pub state: SaTokenState,
225    role: String,
226}
227
228impl SaCheckRoleMiddleware {
229    /// 中文 | English
230    /// 创建新的角色检查中间件 | Create new role check middleware
231    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
232        Self { state, role: role.into() }
233    }
234}
235
236#[salvo::async_trait]
237impl Handler for SaCheckRoleMiddleware {
238    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
239        let mut ctx = SaTokenContext::new();
240        
241        if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
242            tracing::debug!("Sa-Token(role-check): extracted token from request: {}", token_str);
243            let token = TokenValue::new(token_str);
244            
245            if self.state.manager.is_valid(&token).await {
246                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
247                    let login_id = token_info.login_id.clone();
248                    
249                    // 检查角色
250                    if StpUtil::has_role(&login_id, &self.role).await {
251                        depot.insert("sa_token", token.clone());
252                        depot.insert("sa_login_id", login_id.clone());
253                        
254                        ctx.token = Some(token.clone());
255                        ctx.token_info = Some(Arc::new(token_info));
256                        ctx.login_id = Some(login_id);
257                        
258                        SaTokenContext::set_current(ctx);
259                        ctrl.call_next(req, depot, res).await;
260                        SaTokenContext::clear();
261                        return;
262                    }
263                }
264            }
265        }
266        
267        // 无角色权限,返回403错误
268        res.status_code(StatusCode::FORBIDDEN);
269        res.render(Text::Json(json!({
270            "code": 403,
271            "message": messages::ROLE_REQUIRED
272        }).to_string()));
273        ctrl.skip_rest();
274    }
275}