// sa_token_plugin_axum/middleware.rs

// Author: 金书记
//
//! 中间件实现
//!
//! 提供三种中间件:
//! - `SaTokenMiddleware`:基础中间件,从请求中提取token并设置上下文
//! - `SaCheckLoginMiddleware`:检查登录中间件,未登录时返回401错误
//! - `SaCheckPermissionMiddleware`:检查权限中间件,未登录或缺少指定权限时返回403错误

use std::task::{Context, Poll};

use http::header::HeaderValue;
use http::{Request, Response, StatusCode};
use http_body;
use sa_token_core::error::messages;
use serde_json::json;
use tower::{Layer, Service};

pub use crate::layer::SaTokenMiddleware;

/// Layer that wraps a service in `SaCheckLoginMiddleware` (the login check).
///
/// The layer carries no configuration, so `SaCheckLoginLayer::new()` and
/// `SaCheckLoginLayer::default()` are interchangeable.
#[derive(Clone, Copy, Debug, Default)]
pub struct SaCheckLoginLayer;

impl SaCheckLoginLayer {
    /// Creates the login-check layer; equivalent to `SaCheckLoginLayer::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

28impl<S> Layer<S> for SaCheckLoginLayer {
29    type Service = SaCheckLoginMiddleware<S>;
30    
31    fn layer(&self, inner: S) -> Self::Service {
32        SaCheckLoginMiddleware { inner }
33    }
34}
35
/// Login-check service produced by `SaCheckLoginLayer`.
///
/// A request whose extensions hold no `String` (treated as the login id) is answered directly
/// with `401 Unauthorized`; every other request is forwarded to the wrapped service.
#[derive(Clone, Debug)]
pub struct SaCheckLoginMiddleware<S> {
    /// Wrapped service that handles requests which passed the login check.
    inner: S,
}

/// Layer that wraps a service in `SaCheckPermissionMiddleware`, requiring one permission.
#[derive(Clone, Debug)]
pub struct SaCheckPermissionLayer {
    /// Permission the logged-in user must hold for a request to be forwarded.
    permission: String,
}

impl SaCheckPermissionLayer {
    /// Creates a layer that requires `permission` on every request it guards.
    #[must_use]
    pub fn new(permission: impl Into<String>) -> Self {
        Self {
            permission: permission.into(),
        }
    }
}

58impl<S> Layer<S> for SaCheckPermissionLayer {
59    type Service = SaCheckPermissionMiddleware<S>;
60    
61    fn layer(&self, inner: S) -> Self::Service {
62        SaCheckPermissionMiddleware { 
63            inner,
64            permission: self.permission.clone(),
65        }
66    }
67}
68
/// Permission-check service produced by `SaCheckPermissionLayer`.
///
/// A request is forwarded to the wrapped service only when its extensions hold a `String` login
/// id for which `sa_token_core::StpUtil::has_permission` returns `true`. Otherwise (not logged
/// in, or missing the permission) it is answered directly with `403 Forbidden`.
#[derive(Clone, Debug)]
pub struct SaCheckPermissionMiddleware<S> {
    /// Wrapped service that handles requests which passed the permission check.
    inner: S,
    /// Permission checked on every request.
    permission: String,
}

78impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaCheckLoginMiddleware<S>
79where
80    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
81    S::Future: Send + 'static,
82    ReqBody: Send + 'static,
83    ResBody: http_body::Body + Default + Send + 'static,
84{
85    type Response = S::Response;
86    type Error = S::Error;
87    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
88    
89    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.inner.poll_ready(cx)
91    }
92    
93    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
94        let mut inner = self.inner.clone();
95        
96        Box::pin(async move {
97            // 检查是否有登录ID
98            if request.extensions().get::<String>().is_none() {
99                // 未登录,返回401错误
100                // 由于我们无法直接返回AxumResponse,这里使用一个hack方法
101                // 创建一个错误响应
102                let mut response = Response::builder()
103                    .status(StatusCode::UNAUTHORIZED)
104                    .body(ResBody::default())
105                    .expect("Unable to create response");
106                
107                // 添加错误信息
108                let error_json = serde_json::to_string(&json!({
109                    "code": 401,
110                    "message": messages::AUTH_ERROR
111                })).unwrap_or_default();
112                
113                // 添加到响应头中,这样上层可以读取
114                if let Ok(header_value) = http::header::HeaderValue::from_str(&error_json) {
115                    response.headers_mut().insert("X-Sa-Token-Error", header_value);
116                }
117                
118                return Ok(response);
119            }
120            
121            // 已登录,继续处理
122            inner.call(request).await
123        })
124    }
125}
126
127impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaCheckPermissionMiddleware<S>
128where
129    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
130    S::Future: Send + 'static,
131    ReqBody: Send + 'static,
132    ResBody: http_body::Body + Default + Send + 'static,
133{
134    type Response = S::Response;
135    type Error = S::Error;
136    type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
137    
138    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139        self.inner.poll_ready(cx)
140    }
141    
142    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
143        let mut inner = self.inner.clone();
144        let permission = self.permission.clone();
145        
146        Box::pin(async move {
147            // 检查是否有登录ID
148            if let Some(login_id) = request.extensions().get::<String>() {
149                // 检查权限
150                if sa_token_core::StpUtil::has_permission(login_id, &permission).await {
151                    // 有权限,继续处理
152                    return inner.call(request).await;
153                }
154            }
155            
156            // 无权限或未登录,返回403错误
157            let mut response = Response::builder()
158                .status(StatusCode::FORBIDDEN)
159                .body(ResBody::default())
160                .expect("Unable to create response");
161            
162            // 添加错误信息
163            let error_json = serde_json::to_string(&json!({
164                "code": 403,
165                "message": messages::PERMISSION_REQUIRED
166            })).unwrap_or_default();
167            
168            // 添加到响应头中,这样上层可以读取
169            if let Ok(header_value) = http::header::HeaderValue::from_str(&error_json) {
170                response.headers_mut().insert("X-Sa-Token-Error", header_value);
171            }
172            
173            Ok(response)
174        })
175    }
176}