sa_token_plugin_actix_web/
middleware.rs

1// Author: 金书记
2//
3//! Actix-web中间件
4
5use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10    Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::token::TokenValue;
16
17/// sa-token 基础中间件 - 提取并验证 token
18pub struct SaTokenMiddleware {
19    pub state: SaTokenState,
20}
21
22impl SaTokenMiddleware {
23    pub fn new(state: SaTokenState) -> Self {
24        Self { state }
25    }
26}
27
28impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
29where
30    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
31    S::Future: 'static,
32    B: 'static,
33{
34    type Response = ServiceResponse<B>;
35    type Error = Error;
36    type InitError = ();
37    type Transform = SaTokenMiddlewareService<S>;
38    type Future = Ready<Result<Self::Transform, Self::InitError>>;
39    
40    fn new_transform(&self, service: S) -> Self::Future {
41        ready(Ok(SaTokenMiddlewareService {
42            service: Rc::new(service),
43            state: self.state.clone(),
44        }))
45    }
46}
47
48pub struct SaTokenMiddlewareService<S> {
49    service: Rc<S>,
50    state: SaTokenState,
51}
52
53impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
54where
55    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
56    S::Future: 'static,
57    B: 'static,
58{
59    type Response = ServiceResponse<B>;
60    type Error = Error;
61    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
62    
63    forward_ready!(service);
64    
65    fn call(&self, req: ServiceRequest) -> Self::Future {
66        let service = Rc::clone(&self.service);
67        let state = self.state.clone();
68        
69        Box::pin(async move {
70            // 提取 token
71            if let Some(token_str) = extract_token_from_request(&req, &state) {
72                let token = TokenValue::new(token_str);
73                
74                // 验证 token
75                if state.manager.is_valid(&token).await {
76                    // 存储 token
77                    req.extensions_mut().insert(token.clone());
78                    
79                    // 获取并存储 login_id
80                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
81                        req.extensions_mut().insert(token_info.login_id.clone());
82                    }
83                }
84            }
85            
86            service.call(req).await
87        })
88    }
89}
90
91/// sa-token 登录检查中间件 - 强制要求登录
92pub struct SaCheckLoginMiddleware {
93    pub state: SaTokenState,
94}
95
96impl SaCheckLoginMiddleware {
97    pub fn new(state: SaTokenState) -> Self {
98        Self { state }
99    }
100}
101
102impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
103where
104    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
105    S::Future: 'static,
106    B: 'static,
107{
108    type Response = ServiceResponse<B>;
109    type Error = Error;
110    type InitError = ();
111    type Transform = SaCheckLoginMiddlewareService<S>;
112    type Future = Ready<Result<Self::Transform, Self::InitError>>;
113    
114    fn new_transform(&self, service: S) -> Self::Future {
115        ready(Ok(SaCheckLoginMiddlewareService {
116            service: Rc::new(service),
117            state: self.state.clone(),
118        }))
119    }
120}
121
122pub struct SaCheckLoginMiddlewareService<S> {
123    service: Rc<S>,
124    state: SaTokenState,
125}
126
127impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
128where
129    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
130    S::Future: 'static,
131    B: 'static,
132{
133    type Response = ServiceResponse<B>;
134    type Error = Error;
135    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
136    
137    forward_ready!(service);
138    
139    fn call(&self, req: ServiceRequest) -> Self::Future {
140        let service = Rc::clone(&self.service);
141        let state = self.state.clone();
142        
143        Box::pin(async move {
144            // 提取 token
145            if let Some(token_str) = extract_token_from_request(&req, &state) {
146                let token = TokenValue::new(token_str);
147                
148                // 验证 token
149                if state.manager.is_valid(&token).await {
150                    // 存储 token 和 login_id
151                    req.extensions_mut().insert(token.clone());
152                    
153                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
154                        req.extensions_mut().insert(token_info.login_id.clone());
155                    }
156                    
157                    return service.call(req).await;
158                }
159            }
160            
161            // 未登录,返回 401
162            Err(ErrorUnauthorized(serde_json::json!({
163                "code": 401,
164                "message": "未登录"
165            }).to_string()))
166        })
167    }
168}
169
170/// 从请求中提取 token
171fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
172    let adapter = ActixRequestAdapter::new(req.request());
173    let token_name = &state.manager.config.token_name;
174    
175    // 1. 优先从 Header 中获取
176    if let Some(token) = adapter.get_header(token_name) {
177        return Some(extract_bearer_token(&token));
178    }
179    
180    // 2. 从 Cookie 中获取
181    if let Some(token) = adapter.get_cookie(token_name) {
182        return Some(token);
183    }
184    
185    // 3. 从 Query 参数中获取
186    if let Some(query) = req.query_string().split('&').find_map(|pair| {
187        let mut parts = pair.split('=');
188        if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
189            if key == token_name {
190                return urlencoding::decode(value).ok().map(|s| s.to_string());
191            }
192        }
193        None
194    }) {
195        return Some(query);
196    }
197    
198    None
199}
200
/// Strips an optional `Bearer` authentication scheme from a header value.
///
/// HTTP authentication scheme names are case-insensitive (RFC 9110 §11.1), so
/// `Bearer`, `bearer` and `BEARER` are all recognized. The scheme must be
/// followed by at least one space (RFC 6750 §2.1). Surrounding whitespace is
/// trimmed from the returned credentials. Values without the scheme are
/// returned trimmed but otherwise unchanged.
///
/// Never panics: `str::get` is used, so multi-byte input shorter or longer
/// than the scheme name is handled safely.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer";

    let value = token.trim_start();
    let credentials = match (value.get(..SCHEME.len()), value.get(SCHEME.len()..)) {
        (Some(scheme), Some(rest)) if scheme.eq_ignore_ascii_case(SCHEME) && rest.starts_with(' ') => {
            rest
        }
        _ => value,
    };
    credentials.trim().to_string()
}