sa_token_plugin_actix_web/
layer.rs

1use std::future::{ready, Ready, Future};
2use std::pin::Pin;
3use std::rc::Rc;
4use actix_web::{
5    dev::{Service, ServiceRequest, ServiceResponse, Transform},
6    Error, HttpMessage,
7};
8use crate::SaTokenState;
9use crate::adapter::ActixRequestAdapter;
10use sa_token_adapter::context::SaRequest;
11use sa_token_core::{token::TokenValue, SaTokenContext};
12use std::sync::Arc;
13
14#[derive(Clone)]
15pub struct SaTokenLayer {
16    state: SaTokenState,
17}
18
19impl SaTokenLayer {
20    pub fn new(state: SaTokenState) -> Self {
21        Self { state }
22    }
23}
24
25impl<S, B> Transform<S, ServiceRequest> for SaTokenLayer
26where
27    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
28    S::Future: 'static,
29    B: 'static,
30{
31    type Response = ServiceResponse<B>;
32    type Error = Error;
33    type InitError = ();
34    type Transform = SaTokenLayerService<S>;
35    type Future = Ready<Result<Self::Transform, Self::InitError>>;
36    
37    fn new_transform(&self, service: S) -> Self::Future {
38        ready(Ok(SaTokenLayerService {
39            service: Rc::new(service),
40            state: self.state.clone(),
41        }))
42    }
43}
44
45pub struct SaTokenLayerService<S> {
46    service: Rc<S>,
47    state: SaTokenState,
48}
49
50impl<S, B> Service<ServiceRequest> for SaTokenLayerService<S>
51where
52    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
53    S::Future: 'static,
54    B: 'static,
55{
56    type Response = ServiceResponse<B>;
57    type Error = Error;
58    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
59    
60    fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
61        self.service.poll_ready(cx)
62    }
63    
64    fn call(&self, req: ServiceRequest) -> Self::Future {
65        let service = Rc::clone(&self.service);
66        let state = self.state.clone();
67        
68        Box::pin(async move {
69            let mut ctx = SaTokenContext::new();
70            
71            if let Some(token_str) = extract_token_from_request(&req, &state) {
72                tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
73                let token = TokenValue::new(token_str);
74                
75                if state.manager.is_valid(&token).await {
76                    req.extensions_mut().insert(token.clone());
77                    
78                    if let Ok(token_info) = state.manager.get_token_info(&token).await {
79                        let login_id = token_info.login_id.clone();
80                        req.extensions_mut().insert(login_id.clone());
81                        
82                        ctx.token = Some(token.clone());
83                        ctx.token_info = Some(Arc::new(token_info));
84                        ctx.login_id = Some(login_id);
85                    }
86                }
87            }
88            
89            SaTokenContext::set_current(ctx);
90            let result = service.call(req).await;
91            SaTokenContext::clear();
92            result
93        })
94    }
95}
96
97fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
98    let adapter = ActixRequestAdapter::new(req.request());
99    let token_name = &state.manager.config.token_name;
100    
101    if let Some(token) = adapter.get_header(token_name) {
102        return Some(extract_bearer_token(&token));
103    }
104    
105    if let Some(token) = adapter.get_cookie(token_name) {
106        return Some(token);
107    }
108    
109    if let Some(query) = req.query_string().split('&').find_map(|pair| {
110        let mut parts = pair.split('=');
111        if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
112            if key == token_name {
113                return urlencoding::decode(value).ok().map(|s| s.to_string());
114            }
115        }
116        None
117    }) {
118        return Some(query);
119    }
120    
121    None
122}
123
/// Strips an HTTP `Bearer` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively, as HTTP auth-schemes are case-insensitive
/// (RFC 7235 §2.1). Any extra spaces between the scheme and the credential are skipped.
/// Values that do not start with the scheme are returned unchanged, so a raw token sent in
/// a custom header is accepted as-is.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `str::get` returns `None` when the value is shorter than the scheme or when the cut
    // would fall inside a multi-byte character. This avoids a slicing panic.
    match token.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim_start_matches(' ').to_string()
        }
        _ => token.to_string(),
    }
}