sa_token_plugin_actix_web/
middleware.rs

1// Author: 金书记
2//
3//! Actix-web中间件
4
5use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10    Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::{token::TokenValue, SaTokenContext, error::messages};
16use std::sync::Arc;
17
18/// sa-token 基础中间件 - 提取并验证 token
19pub struct SaTokenMiddleware {
20    pub state: SaTokenState,
21}
22
23impl SaTokenMiddleware {
24    pub fn new(state: SaTokenState) -> Self {
25        Self { state }
26    }
27}
28
29impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
30where
31    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
32    S::Future: 'static,
33    B: 'static,
34{
35    type Response = ServiceResponse<B>;
36    type Error = Error;
37    type InitError = ();
38    type Transform = SaTokenMiddlewareService<S>;
39    type Future = Ready<Result<Self::Transform, Self::InitError>>;
40    
41    fn new_transform(&self, service: S) -> Self::Future {
42        ready(Ok(SaTokenMiddlewareService {
43            service: Rc::new(service),
44            state: self.state.clone(),
45        }))
46    }
47}
48
49pub struct SaTokenMiddlewareService<S> {
50    service: Rc<S>,
51    state: SaTokenState,
52}
53
54impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
55where
56    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57    S::Future: 'static,
58    B: 'static,
59{
60    type Response = ServiceResponse<B>;
61    type Error = Error;
62    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
63    
64    forward_ready!(service);
65    
66    fn call(&self, req: ServiceRequest) -> Self::Future {
67        let service = Rc::clone(&self.service);
68        let state = self.state.clone();
69        
70        Box::pin(async move {
71            let mut ctx = SaTokenContext::new();
72            // 提取 token
73            if let Some(token_str) = extract_token_from_request(&req, &state) {
74                tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
75                let token = TokenValue::new(token_str);
76                
77                // 验证 token
78                if state.manager.is_valid(&token).await {
79                    // 存储 token
80                    req.extensions_mut().insert(token.clone());
81                    
82                    // 获取并存储 login_id
83                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
84                        let login_id = token_info.login_id.clone();
85                        req.extensions_mut().insert(login_id.clone());
86                        ctx.token = Some(token.clone());
87                        ctx.token_info = Some(Arc::new(token_info));
88                        ctx.login_id = Some(login_id);
89                    }
90                }
91            }
92            
93            SaTokenContext::set_current(ctx);
94            let result = service.call(req).await;
95            SaTokenContext::clear();
96            result
97        })
98    }
99}
100
101/// sa-token 登录检查中间件 - 强制要求登录
102pub struct SaCheckLoginMiddleware {
103    pub state: SaTokenState,
104}
105
106impl SaCheckLoginMiddleware {
107    pub fn new(state: SaTokenState) -> Self {
108        Self { state }
109    }
110}
111
112impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
113where
114    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
115    S::Future: 'static,
116    B: 'static,
117{
118    type Response = ServiceResponse<B>;
119    type Error = Error;
120    type InitError = ();
121    type Transform = SaCheckLoginMiddlewareService<S>;
122    type Future = Ready<Result<Self::Transform, Self::InitError>>;
123    
124    fn new_transform(&self, service: S) -> Self::Future {
125        ready(Ok(SaCheckLoginMiddlewareService {
126            service: Rc::new(service),
127            state: self.state.clone(),
128        }))
129    }
130}
131
132pub struct SaCheckLoginMiddlewareService<S> {
133    service: Rc<S>,
134    state: SaTokenState,
135}
136
137impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
138where
139    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
140    S::Future: 'static,
141    B: 'static,
142{
143    type Response = ServiceResponse<B>;
144    type Error = Error;
145    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
146    
147    forward_ready!(service);
148    
149    fn call(&self, req: ServiceRequest) -> Self::Future {
150        let service = Rc::clone(&self.service);
151        let state = self.state.clone();
152        
153        Box::pin(async move {
154            let mut ctx = SaTokenContext::new();
155            // 提取 token
156            if let Some(token_str) = extract_token_from_request(&req, &state) {
157                tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
158                let token = TokenValue::new(token_str);
159                
160                // 验证 token
161                if state.manager.is_valid(&token).await {
162                    // 存储 token 和 login_id
163                    req.extensions_mut().insert(token.clone());
164                    
165                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
166                        let login_id = token_info.login_id.clone();
167                        req.extensions_mut().insert(login_id.clone());
168                        ctx.token = Some(token.clone());
169                        ctx.token_info = Some(Arc::new(token_info));
170                        ctx.login_id = Some(login_id);
171                        
172                        // 设置上下文
173                        SaTokenContext::set_current(ctx);
174                        let result = service.call(req).await;
175                        SaTokenContext::clear();
176                        return result;
177                    }
178                }
179            }
180            
181            // 未登录,返回 401
182            Err(ErrorUnauthorized(serde_json::json!({
183                "code": 401,
184                "message": messages::AUTH_ERROR
185            }).to_string()))
186        })
187    }
188}
189
190/// 从请求中提取 token
191fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
192    let adapter = ActixRequestAdapter::new(req.request());
193    let token_name = &state.manager.config.token_name;
194    
195    // 1. 优先从 Header 中获取
196    if let Some(token) = adapter.get_header(token_name) {
197        return Some(extract_bearer_token(&token));
198    }
199    
200    // 2. 从 Cookie 中获取
201    if let Some(token) = adapter.get_cookie(token_name) {
202        return Some(token);
203    }
204    
205    // 3. 从 Query 参数中获取
206    if let Some(query) = req.query_string().split('&').find_map(|pair| {
207        let mut parts = pair.split('=');
208        if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
209            if key == token_name {
210                return urlencoding::decode(value).ok().map(|s| s.to_string());
211            }
212        }
213        None
214    }) {
215        return Some(query);
216    }
217    
218    None
219}
220
/// Strips an optional `Bearer ` authentication-scheme prefix from a header value.
///
/// The scheme is matched case-insensitively, because HTTP auth schemes are
/// case-insensitive. So `Bearer abc`, `bearer abc` and `BEARER abc` all yield `abc`.
/// Values without the prefix, including the bare word `Bearer`, are returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const PREFIX: &str = "Bearer ";

    // `get` returns `None` instead of panicking when the value is shorter than the
    // prefix or when byte 7 falls inside a multi-byte character.
    match token.get(..PREFIX.len()) {
        Some(scheme) if scheme.eq_ignore_ascii_case(PREFIX) => {
            token[PREFIX.len()..].to_string()
        }
        _ => token.to_string(),
    }
}