1use ntex::service::{Middleware, Service, ServiceCtx};
13use ntex::web::{Error, ErrorRenderer, WebRequest, WebResponse};
14use std::sync::Arc;
15use serde_json::json;
16use sa_token_core::{
17 error::messages,
18 token::TokenValue,
19 SaTokenContext,
20 StpUtil
21};
22use sa_token_adapter::utils::{parse_cookies, parse_query_string, extract_bearer_token};
23use crate::SaTokenState;
24use ntex::web::error::InternalError;
25use ntex::web::Error as WebError;
26
27
28pub struct SaTokenMiddleware {
32 pub state: SaTokenState,
33}
34
35impl SaTokenMiddleware {
36 pub fn new(state: SaTokenState) -> Self {
37 Self { state }
38 }
39}
40
41impl<S> Middleware<S> for SaTokenMiddleware {
42 type Service = SaTokenMiddlewareService<S>;
43
44 fn create(&self, service: S) -> Self::Service {
45 SaTokenMiddlewareService {
46 service,
47 state: self.state.clone(),
48 }
49 }
50}
51
52pub struct SaTokenMiddlewareService<S> {
53 service: S,
54 state: SaTokenState,
55}
56
57impl<S, Err> Service<WebRequest<Err>> for SaTokenMiddlewareService<S>
58where
59 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
60 Err: ErrorRenderer,
61{
62 type Response = WebResponse;
63 type Error = Error;
64
65 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
66 let mut sa_ctx = SaTokenContext::new();
67
68 if let Some(token_str) = extract_token_from_request(&req, &self.state) {
70 tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
71 let token = TokenValue::new(token_str);
72
73 if self.state.manager.is_valid(&token).await {
75 req.extensions_mut().insert(token.clone());
77
78 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
80 let login_id = token_info.login_id.clone();
81 req.extensions_mut().insert(login_id.clone());
82
83 sa_ctx.token = Some(token.clone());
85 sa_ctx.token_info = Some(Arc::new(token_info));
86 sa_ctx.login_id = Some(login_id);
87 }
88 }
89 }
90
91 SaTokenContext::set_current(sa_ctx);
93
94 let result = ctx.call(&self.service, req).await;
96
97 SaTokenContext::clear();
99
100 result
101 }
102}
103
104#[deprecated(note = "Use SaTokenMiddleware + SaCheckLoginMiddleware instead")]
119pub struct AuthMiddleware;
120
121impl<S> Middleware<S> for AuthMiddleware {
122 type Service = AuthMiddlewareService<S>;
123
124 fn create(&self, service: S) -> Self::Service {
125 AuthMiddlewareService { service }
126 }
127}
128
129pub struct AuthMiddlewareService<S> {
130 service: S,
131}
132
133impl<S, Err> Service<WebRequest<Err>> for AuthMiddlewareService<S>
134where
135 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
136 Err: ErrorRenderer,
137{
138 type Response = WebResponse;
139 type Error = Error;
140
141 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
142 let token = req
145 .headers()
146 .get("Authorization")
147 .and_then(|v| v.to_str().ok())
148 .and_then(|s| s.strip_prefix("Bearer "))
149 .map(|s| s.to_string());
150
151 if let Some(token_str) = token {
152 use sa_token_core::TokenValue;
155 let token_value = TokenValue::from(token_str.clone());
156 if StpUtil::is_login(&token_value).await {
157 if let Ok(login_id) = StpUtil::get_login_id(&token_value).await {
160 req.extensions_mut().insert(login_id);
161 return ctx.call(&self.service, req).await;
162 }
163 }
164 }
165
166 Err(WebError::from(InternalError::new(
169 "Unauthorized",
170 ntex::http::StatusCode::UNAUTHORIZED,
171 )))
172 }
173}
174
175pub struct SaCheckLoginMiddleware {
179 pub state: SaTokenState,
180}
181
182impl SaCheckLoginMiddleware {
183 pub fn new(state: SaTokenState) -> Self {
184 Self { state }
185 }
186}
187
188impl<S> Middleware<S> for SaCheckLoginMiddleware {
189 type Service = SaCheckLoginMiddlewareService<S>;
190
191 fn create(&self, service: S) -> Self::Service {
192 SaCheckLoginMiddlewareService {
193 service,
194 state: self.state.clone(),
195 }
196 }
197}
198
199pub struct SaCheckLoginMiddlewareService<S> {
200 service: S,
201 state: SaTokenState,
202}
203
204impl<S, Err> Service<WebRequest<Err>> for SaCheckLoginMiddlewareService<S>
205where
206 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
207 Err: ErrorRenderer,
208{
209 type Response = WebResponse;
210 type Error = Error;
211
212 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
213 let mut sa_ctx = SaTokenContext::new();
214
215 if let Some(token_str) = extract_token_from_request(&req, &self.state) {
217 tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
218 let token = TokenValue::new(token_str);
219
220 if self.state.manager.is_valid(&token).await {
222 req.extensions_mut().insert(token.clone());
224
225 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
226 let login_id = token_info.login_id.clone();
227 req.extensions_mut().insert(login_id.clone());
228
229 sa_ctx.token = Some(token.clone());
231 sa_ctx.token_info = Some(Arc::new(token_info));
232 sa_ctx.login_id = Some(login_id);
233
234 SaTokenContext::set_current(sa_ctx);
235 let result = ctx.call(&self.service, req).await;
236 SaTokenContext::clear();
237 return result;
238 }
239 }
240 }
241
242 Err(WebError::from(InternalError::new(
244 json!({
245 "code": 401,
246 "message": messages::AUTH_ERROR
247 }).to_string(),
248 ntex::http::StatusCode::UNAUTHORIZED,
249 )))
250 }
251}
252
253pub struct SaCheckPermissionMiddleware {
257 pub state: SaTokenState,
258 permission: String,
259}
260
261impl SaCheckPermissionMiddleware {
262 pub fn new(state: SaTokenState, permission: impl Into<String>) -> Self {
263 Self {
264 state,
265 permission: permission.into(),
266 }
267 }
268}
269
270impl<S> Middleware<S> for SaCheckPermissionMiddleware {
271 type Service = SaCheckPermissionMiddlewareService<S>;
272
273 fn create(&self, service: S) -> Self::Service {
274 SaCheckPermissionMiddlewareService {
275 service,
276 state: self.state.clone(),
277 permission: self.permission.clone(),
278 }
279 }
280}
281
282pub struct SaCheckPermissionMiddlewareService<S> {
283 service: S,
284 state: SaTokenState,
285 permission: String,
286}
287
288impl<S, Err> Service<WebRequest<Err>> for SaCheckPermissionMiddlewareService<S>
289where
290 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
291 Err: ErrorRenderer,
292{
293 type Response = WebResponse;
294 type Error = Error;
295
296 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
297 let mut sa_ctx = SaTokenContext::new();
298
299 if let Some(token_str) = extract_token_from_request(&req, &self.state) {
301 tracing::debug!("Sa-Token(permission-check): extracted token from request: {}", token_str);
302 let token = TokenValue::new(token_str);
303
304 if self.state.manager.is_valid(&token).await {
306 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
307 let login_id = token_info.login_id.clone();
308
309 if StpUtil::has_permission(&login_id, &self.permission).await {
311 req.extensions_mut().insert(token.clone());
313 req.extensions_mut().insert(login_id.clone());
314
315 sa_ctx.token = Some(token.clone());
317 sa_ctx.token_info = Some(Arc::new(token_info));
318 sa_ctx.login_id = Some(login_id);
319
320 SaTokenContext::set_current(sa_ctx);
321 let result = ctx.call(&self.service, req).await;
322 SaTokenContext::clear();
323 return result;
324 }
325 }
326 }
327 }
328
329 Err(WebError::from(InternalError::new(
331 json!({
332 "code": 403,
333 "message": messages::PERMISSION_REQUIRED
334 }).to_string(),
335 ntex::http::StatusCode::FORBIDDEN,
336 )))
337 }
338}
339
340pub struct SaCheckRoleMiddleware {
344 pub state: SaTokenState,
345 role: String,
346}
347
348impl SaCheckRoleMiddleware {
349 pub fn new(state: SaTokenState, role: impl Into<String>) -> Self {
350 Self {
351 state,
352 role: role.into(),
353 }
354 }
355}
356
357impl<S> Middleware<S> for SaCheckRoleMiddleware {
358 type Service = SaCheckRoleMiddlewareService<S>;
359
360 fn create(&self, service: S) -> Self::Service {
361 SaCheckRoleMiddlewareService {
362 service,
363 state: self.state.clone(),
364 role: self.role.clone(),
365 }
366 }
367}
368
369pub struct SaCheckRoleMiddlewareService<S> {
370 service: S,
371 state: SaTokenState,
372 role: String,
373}
374
375impl<S, Err> Service<WebRequest<Err>> for SaCheckRoleMiddlewareService<S>
376where
377 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
378 Err: ErrorRenderer,
379{
380 type Response = WebResponse;
381 type Error = Error;
382
383 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
384 let mut sa_ctx = SaTokenContext::new();
385
386 if let Some(token_str) = extract_token_from_request(&req, &self.state) {
388 tracing::debug!("Sa-Token(role-check): extracted token from request: {}", token_str);
389 let token = TokenValue::new(token_str);
390
391 if self.state.manager.is_valid(&token).await {
393 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
394 let login_id = token_info.login_id.clone();
395
396 if StpUtil::has_role(&login_id, &self.role).await {
398 req.extensions_mut().insert(token.clone());
400 req.extensions_mut().insert(login_id.clone());
401
402 sa_ctx.token = Some(token.clone());
404 sa_ctx.token_info = Some(Arc::new(token_info));
405 sa_ctx.login_id = Some(login_id);
406
407 SaTokenContext::set_current(sa_ctx);
408 let result = ctx.call(&self.service, req).await;
409 SaTokenContext::clear();
410 return result;
411 }
412 }
413 }
414 }
415
416 Err(WebError::from(InternalError::new(
419 json!({
420 "code": 403,
421 "message": messages::ROLE_REQUIRED
422 }).to_string(),
423 ntex::http::StatusCode::FORBIDDEN,
424 )))
425 }
426}
427
428#[deprecated(note = "Use SaCheckPermissionMiddleware instead")]
433pub struct PermissionMiddleware {
434 permission: String,
435}
436
437impl PermissionMiddleware {
438 pub fn new(permission: impl Into<String>) -> Self {
441 Self {
442 permission: permission.into(),
443 }
444 }
445}
446
447impl<S> Middleware<S> for PermissionMiddleware {
448 type Service = PermissionMiddlewareService<S>;
449
450 fn create(&self, service: S) -> Self::Service {
451 PermissionMiddlewareService {
452 service,
453 permission: self.permission.clone(),
454 }
455 }
456}
457
458pub struct PermissionMiddlewareService<S> {
459 service: S,
460 permission: String,
461}
462
463impl<S, Err> Service<WebRequest<Err>> for PermissionMiddlewareService<S>
464where
465 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
466 Err: ErrorRenderer,
467{
468 type Response = WebResponse;
469 type Error = Error;
470
471 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
472 let has_login_id = req.extensions().get::<String>().is_some();
479
480 if has_login_id {
481 let login_id = req.extensions().get::<String>().unwrap().clone();
482 if StpUtil::has_permission(&login_id, &self.permission).await {
484 return ctx.call(&self.service, req).await;
485 }
486 } else {
487 if let Some(token_str) = extract_token_from_request_simple(&req) {
490 let token = TokenValue::new(token_str);
491
492 if StpUtil::is_login(&token).await {
495 if let Ok(login_id) = StpUtil::get_login_id(&token).await {
496 if StpUtil::has_permission(&login_id, &self.permission).await {
498 req.extensions_mut().insert(login_id);
501 return ctx.call(&self.service, req).await;
502 }
503 }
504 }
505 }
506 }
507
508 Err(WebError::from(InternalError::new(
510 json!({
511 "code": 403,
512 "message": messages::PERMISSION_REQUIRED
513 }).to_string(),
514 ntex::http::StatusCode::FORBIDDEN,
515 )))
516 }
517}
518
519fn extract_token_from_request<Err>(req: &WebRequest<Err>, state: &SaTokenState) -> Option<String>
523where
524 Err: ErrorRenderer,
525{
526 let token_name = &state.manager.config.token_name;
527
528 if let Some(header_value) = req.headers().get(token_name) {
530 if let Ok(value_str) = header_value.to_str() {
531 if let Some(token) = extract_bearer_token(value_str) {
532 return Some(token);
533 }
534 }
535 }
536
537 if let Some(auth_header) = req.headers().get("authorization") {
539 if let Ok(auth_str) = auth_header.to_str() {
540 if let Some(token) = extract_bearer_token(auth_str) {
541 return Some(token);
542 }
543 }
544 }
545
546 if let Some(cookie_header) = req.headers().get("cookie") {
548 if let Ok(cookie_str) = cookie_header.to_str() {
549 let cookies = parse_cookies(cookie_str);
550 if let Some(token) = cookies.get(token_name) {
551 return Some(token.clone());
552 }
553 }
554 }
555
556 let query = req.query_string();
558 if !query.is_empty() {
559 let params = parse_query_string(query);
560 if let Some(token) = params.get(token_name) {
561 return Some(token.clone());
562 }
563 }
564
565 None
566}
567
568fn extract_token_from_request_simple<Err>(req: &WebRequest<Err>) -> Option<String>
572where
573 Err: ErrorRenderer,
574{
575 if let Some(auth_header) = req.headers().get("authorization") {
577 if let Ok(auth_str) = auth_header.to_str() {
578 if let Some(token) = extract_bearer_token(auth_str) {
579 return Some(token);
580 }
581 }
582 }
583
584 None
585}