Skip to main content

sa_token_plugin_salvo_v079/
middleware.rs

1// Author: 金书记
2//
3// 中文 | English
4// Salvo 认证中间件 | Salvo authentication middleware
5
6use salvo::prelude::*;
7use sa_token_core::{StpUtil, error::messages};
8use serde_json::json;
9use sa_token_plugin_salvo_core::{run_auth_flow, SaTokenState};
10
11use crate::adapter::SalvoCapturedRequest;
12
13/// 中文 | English
14/// 认证中间件 - 验证用户登录状态 | Authentication middleware - verify user login status
15///
16/// # 示例 | Example
17/// ```rust,ignore
18/// use salvo::prelude::*;
19/// use sa_token_plugin_salvo::auth_middleware;
20///
21/// let router = Router::new()
22///     .hoop(auth_middleware())
23///     .push(Router::with_path("user").get(user_handler));
24/// ```
25pub fn auth_middleware() -> impl Handler {
26    auth_middleware_handler
27}
28
29#[handler]
30async fn auth_middleware_handler(req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
31    // 中文 | English
32    // 从请求头中获取 token | Get token from request headers
33    let token = req
34        .headers()
35        .get("Authorization")
36        .and_then(|v| v.to_str().ok())
37        .and_then(|s| s.strip_prefix("Bearer "))
38        .map(|s| s.to_string());
39    
40    if let Some(token_str) = token {
41        // 中文 | English
42        // 验证 token 是否有效 | Verify if token is valid
43        use sa_token_core::TokenValue;
44        let token_value = TokenValue::from(token_str.clone());
45        if StpUtil::is_login(&token_value).await {
46            // 中文 | English
47            // Token 有效,将 login_id 存入 depot | Token valid, store login_id in depot
48            if let Ok(login_id) = StpUtil::get_login_id(&token_value).await {
49                depot.insert("login_id", login_id);
50                ctrl.call_next(req, depot, res).await;
51                return;
52            }
53        }
54    }
55    
56    // 中文 | English
57    // Token 无效,返回 401 | Token invalid, return 401
58    res.status_code(StatusCode::UNAUTHORIZED);
59    res.render(Text::Json(r#"{"error":"Unauthorized"}"#));
60    ctrl.skip_rest();
61}
62
63/// 中文 | English
64/// 权限验证中间件 - 验证用户是否拥有指定权限 | Permission middleware - verify if user has specified permissions
65///
66/// # 参数 | Parameters
67/// - `permission`: 需要的权限 | Required permission
68///
69/// # 示例 | Example
70/// ```rust,ignore
71/// let router = Router::new()
72///     .hoop(permission_middleware("user:read"))
73///     .push(Router::with_path("user").get(user_handler));
74/// ```
75pub fn permission_middleware(permission: &'static str) -> impl Handler {
76    PermissionMiddleware { permission }
77}
78
79struct PermissionMiddleware {
80    permission: &'static str,
81}
82
83#[handler]
84impl PermissionMiddleware {
85    async fn handle(&self, req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
86        // 中文 | English
87        // 从 depot 获取 login_id | Get login_id from depot
88        if let Ok(login_id) = depot.get::<String>("login_id") {
89            // 中文 | English
90            // 验证权限 | Verify permission
91            if StpUtil::has_permission(login_id, self.permission).await {
92                ctrl.call_next(req, depot, res).await;
93                return;
94            }
95        }
96        
97        // 中文 | English
98        // 无权限,返回 403 | No permission, return 403
99        res.status_code(StatusCode::FORBIDDEN);
100        res.render(Text::Json(r#"{"error":"Forbidden"}"#));
101        ctrl.skip_rest();
102    }
103}
104
105/// 中文 | English
106/// Sa-Token 登录检查中间件 | Sa-Token login check middleware
107///
108/// 使用标准错误消息,检查当前请求是否已登录 | Uses standard error messages, checks if current request is logged in
109#[derive(Clone)]
110pub struct SaCheckLoginMiddleware {
111    pub state: SaTokenState,
112}
113
114impl SaCheckLoginMiddleware {
115    /// 中文 | English
116    /// 创建新的登录检查中间件 | Create new login check middleware
117    pub fn new(state: SaTokenState) -> Self {
118        Self { state }
119    }
120}
121
122#[salvo::async_trait]
123impl Handler for SaCheckLoginMiddleware {
124    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
125        let adapter =
126            SalvoCapturedRequest::capture(req, self.state.manager.config.token_name.as_str());
127        let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
128
129        if flow.token.is_none() || flow.login_id.is_none() {
130            res.status_code(StatusCode::UNAUTHORIZED);
131            res.render(Text::Json(json!({
132                "code": 401,
133                "message": messages::AUTH_ERROR
134            }).to_string()));
135            ctrl.skip_rest();
136            return;
137        }
138
139        if let Some(ref t) = flow.token {
140            depot.insert("sa_token", t.clone());
141        }
142        if let Some(ref id) = flow.login_id {
143            depot.insert("sa_login_id", id.clone());
144        }
145
146        flow.run(ctrl.call_next(req, depot, res)).await;
147    }
148}
149
150/// 中文 | English
151/// Sa-Token 权限检查中间件 | Sa-Token permission check middleware
152///
153/// 检查当前请求用户是否拥有指定权限 | Checks if current request user has specified permission
154#[derive(Clone)]
155pub struct SaCheckPermissionMiddleware {
156    pub state: SaTokenState,
157    permission: String,
158}
159
160impl SaCheckPermissionMiddleware {
161    /// 中文 | English
162    /// 创建新的权限检查中间件 | Create new permission check middleware
163    pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
164        Self { state, permission: permission.into() }
165    }
166}
167
168#[salvo::async_trait]
169impl Handler for SaCheckPermissionMiddleware {
170    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
171        let adapter =
172            SalvoCapturedRequest::capture(req, self.state.manager.config.token_name.as_str());
173        let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
174
175        let Some(login_id) = flow.login_id.clone() else {
176            res.status_code(StatusCode::FORBIDDEN);
177            res.render(Text::Json(json!({
178                "code": 403,
179                "message": messages::PERMISSION_REQUIRED
180            }).to_string()));
181            ctrl.skip_rest();
182            return;
183        };
184
185        if !StpUtil::has_permission(&login_id, &self.permission).await {
186            res.status_code(StatusCode::FORBIDDEN);
187            res.render(Text::Json(json!({
188                "code": 403,
189                "message": messages::PERMISSION_REQUIRED
190            }).to_string()));
191            ctrl.skip_rest();
192            return;
193        }
194
195        if let Some(ref t) = flow.token {
196            depot.insert("sa_token", t.clone());
197        }
198        depot.insert("sa_login_id", login_id);
199
200        flow.run(ctrl.call_next(req, depot, res)).await;
201    }
202}
203
204/// 中文 | English
205/// Sa-Token 角色检查中间件 | Sa-Token role check middleware
206///
207/// 检查当前请求用户是否拥有指定角色 | Checks if current request user has specified role
208#[derive(Clone)]
209pub struct SaCheckRoleMiddleware {
210    pub state: SaTokenState,
211    role: String,
212}
213
214impl SaCheckRoleMiddleware {
215    /// 中文 | English
216    /// 创建新的角色检查中间件 | Create new role check middleware
217    pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
218        Self { state, role: role.into() }
219    }
220}
221
222#[salvo::async_trait]
223impl Handler for SaCheckRoleMiddleware {
224    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
225        let adapter =
226            SalvoCapturedRequest::capture(req, self.state.manager.config.token_name.as_str());
227        let flow = run_auth_flow(&adapter, &self.state.manager, None).await;
228
229        let Some(login_id) = flow.login_id.clone() else {
230            res.status_code(StatusCode::FORBIDDEN);
231            res.render(Text::Json(json!({
232                "code": 403,
233                "message": messages::ROLE_REQUIRED
234            }).to_string()));
235            ctrl.skip_rest();
236            return;
237        };
238
239        if !StpUtil::has_role(&login_id, &self.role).await {
240            res.status_code(StatusCode::FORBIDDEN);
241            res.render(Text::Json(json!({
242                "code": 403,
243                "message": messages::ROLE_REQUIRED
244            }).to_string()));
245            ctrl.skip_rest();
246            return;
247        }
248
249        if let Some(ref t) = flow.token {
250            depot.insert("sa_token", t.clone());
251        }
252        depot.insert("sa_login_id", login_id);
253
254        flow.run(ctrl.call_next(req, depot, res)).await;
255    }
256}