sa_token_plugin_salvo/
middleware.rs1use salvo::prelude::*;
7use sa_token_core::{StpUtil, error::messages, SaTokenContext, token::TokenValue};
8use serde_json::json;
9use crate::state::SaTokenState;
10use std::sync::Arc;
11use crate::layer::extract_token_from_request;
12
13pub fn auth_middleware() -> impl Handler {
26 auth_middleware_handler
27}
28
29#[handler]
30async fn auth_middleware_handler(req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
31 let token = req
34 .headers()
35 .get("Authorization")
36 .and_then(|v| v.to_str().ok())
37 .and_then(|s| s.strip_prefix("Bearer "))
38 .map(|s| s.to_string());
39
40 if let Some(token_str) = token {
41 use sa_token_core::TokenValue;
44 let token_value = TokenValue::from(token_str.clone());
45 if StpUtil::is_login(&token_value).await {
46 if let Ok(login_id) = StpUtil::get_login_id(&token_value).await {
49 depot.insert("login_id", login_id);
50 ctrl.call_next(req, depot, res).await;
51 return;
52 }
53 }
54 }
55
56 res.status_code(StatusCode::UNAUTHORIZED);
59 res.render(Text::Json(r#"{"error":"Unauthorized"}"#));
60 ctrl.skip_rest();
61}
62
63pub fn permission_middleware(permission: &'static str) -> impl Handler {
76 PermissionMiddleware { permission }
77}
78
79struct PermissionMiddleware {
80 permission: &'static str,
81}
82
83#[handler]
84impl PermissionMiddleware {
85 async fn handle(&self, req: &mut Request, res: &mut Response, depot: &mut Depot, ctrl: &mut FlowCtrl) {
86 if let Ok(login_id) = depot.get::<String>("login_id") {
89 if StpUtil::has_permission(login_id, self.permission).await {
92 ctrl.call_next(req, depot, res).await;
93 return;
94 }
95 }
96
97 res.status_code(StatusCode::FORBIDDEN);
100 res.render(Text::Json(r#"{"error":"Forbidden"}"#));
101 ctrl.skip_rest();
102 }
103}
104
105#[derive(Clone)]
110pub struct SaCheckLoginMiddleware {
111 pub state: SaTokenState,
112}
113
114impl SaCheckLoginMiddleware {
115 pub fn new(state: SaTokenState) -> Self {
118 Self { state }
119 }
120}
121
122#[salvo::async_trait]
123impl Handler for SaCheckLoginMiddleware {
124 async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
125 let mut ctx = SaTokenContext::new();
126
127 if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
128 tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
129 let token = TokenValue::new(token_str);
130
131 if self.state.manager.is_valid(&token).await {
132 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
133 let login_id = token_info.login_id.clone();
134 depot.insert("sa_token", token.clone());
135 depot.insert("sa_login_id", login_id.clone());
136
137 ctx.token = Some(token.clone());
138 ctx.token_info = Some(Arc::new(token_info));
139 ctx.login_id = Some(login_id);
140
141 SaTokenContext::set_current(ctx);
142 ctrl.call_next(req, depot, res).await;
143 SaTokenContext::clear();
144 return;
145 }
146 }
147 }
148
149 res.status_code(StatusCode::UNAUTHORIZED);
151 res.render(Text::Json(json!({
152 "code": 401,
153 "message": messages::AUTH_ERROR
154 }).to_string()));
155 ctrl.skip_rest();
156 }
157}
158
159#[derive(Clone)]
164pub struct SaCheckPermissionMiddleware {
165 pub state: SaTokenState,
166 permission: String,
167}
168
169impl SaCheckPermissionMiddleware {
170 pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
173 Self { state, permission: permission.into() }
174 }
175}
176
177#[salvo::async_trait]
178impl Handler for SaCheckPermissionMiddleware {
179 async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
180 let mut ctx = SaTokenContext::new();
181
182 if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
183 tracing::debug!("Sa-Token(permission-check): extracted token from request: {}", token_str);
184 let token = TokenValue::new(token_str);
185
186 if self.state.manager.is_valid(&token).await {
187 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
188 let login_id = token_info.login_id.clone();
189
190 if StpUtil::has_permission(&login_id, &self.permission).await {
192 depot.insert("sa_token", token.clone());
193 depot.insert("sa_login_id", login_id.clone());
194
195 ctx.token = Some(token.clone());
196 ctx.token_info = Some(Arc::new(token_info));
197 ctx.login_id = Some(login_id);
198
199 SaTokenContext::set_current(ctx);
200 ctrl.call_next(req, depot, res).await;
201 SaTokenContext::clear();
202 return;
203 }
204 }
205 }
206 }
207
208 res.status_code(StatusCode::FORBIDDEN);
210 res.render(Text::Json(json!({
211 "code": 403,
212 "message": messages::PERMISSION_REQUIRED
213 }).to_string()));
214 ctrl.skip_rest();
215 }
216}
217
218#[derive(Clone)]
223pub struct SaCheckRoleMiddleware {
224 pub state: SaTokenState,
225 role: String,
226}
227
228impl SaCheckRoleMiddleware {
229 pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
232 Self { state, role: role.into() }
233 }
234}
235
236#[salvo::async_trait]
237impl Handler for SaCheckRoleMiddleware {
238 async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
239 let mut ctx = SaTokenContext::new();
240
241 if let Some(token_str) = extract_token_from_request(req, &self.state.manager.config.token_name) {
242 tracing::debug!("Sa-Token(role-check): extracted token from request: {}", token_str);
243 let token = TokenValue::new(token_str);
244
245 if self.state.manager.is_valid(&token).await {
246 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
247 let login_id = token_info.login_id.clone();
248
249 if StpUtil::has_role(&login_id, &self.role).await {
251 depot.insert("sa_token", token.clone());
252 depot.insert("sa_login_id", login_id.clone());
253
254 ctx.token = Some(token.clone());
255 ctx.token_info = Some(Arc::new(token_info));
256 ctx.login_id = Some(login_id);
257
258 SaTokenContext::set_current(ctx);
259 ctrl.call_next(req, depot, res).await;
260 SaTokenContext::clear();
261 return;
262 }
263 }
264 }
265 }
266
267 res.status_code(StatusCode::FORBIDDEN);
269 res.render(Text::Json(json!({
270 "code": 403,
271 "message": messages::ROLE_REQUIRED
272 }).to_string()));
273 ctrl.skip_rest();
274 }
275}