sa_token_plugin_actix_web/
layer.rs1use std::future::{ready, Ready, Future};
2use std::pin::Pin;
3use std::rc::Rc;
4use actix_web::{
5 dev::{Service, ServiceRequest, ServiceResponse, Transform},
6 Error, HttpMessage,
7};
8use crate::SaTokenState;
9use crate::adapter::ActixRequestAdapter;
10use sa_token_adapter::context::SaRequest;
11use sa_token_core::{token::TokenValue, SaTokenContext};
12use std::sync::Arc;
13
14#[derive(Clone)]
15pub struct SaTokenLayer {
16 state: SaTokenState,
17}
18
19impl SaTokenLayer {
20 pub fn new(state: SaTokenState) -> Self {
21 Self { state }
22 }
23}
24
25impl<S, B> Transform<S, ServiceRequest> for SaTokenLayer
26where
27 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
28 S::Future: 'static,
29 B: 'static,
30{
31 type Response = ServiceResponse<B>;
32 type Error = Error;
33 type InitError = ();
34 type Transform = SaTokenLayerService<S>;
35 type Future = Ready<Result<Self::Transform, Self::InitError>>;
36
37 fn new_transform(&self, service: S) -> Self::Future {
38 ready(Ok(SaTokenLayerService {
39 service: Rc::new(service),
40 state: self.state.clone(),
41 }))
42 }
43}
44
45pub struct SaTokenLayerService<S> {
46 service: Rc<S>,
47 state: SaTokenState,
48}
49
50impl<S, B> Service<ServiceRequest> for SaTokenLayerService<S>
51where
52 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
53 S::Future: 'static,
54 B: 'static,
55{
56 type Response = ServiceResponse<B>;
57 type Error = Error;
58 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
59
60 fn poll_ready(&self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
61 self.service.poll_ready(cx)
62 }
63
64 fn call(&self, req: ServiceRequest) -> Self::Future {
65 let service = Rc::clone(&self.service);
66 let state = self.state.clone();
67
68 Box::pin(async move {
69 let mut ctx = SaTokenContext::new();
70
71 if let Some(token_str) = extract_token_from_request(&req, &state) {
72 tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
73 let token = TokenValue::new(token_str);
74
75 if state.manager.is_valid(&token).await {
76 req.extensions_mut().insert(token.clone());
77
78 if let Ok(token_info) = state.manager.get_token_info(&token).await {
79 let login_id = token_info.login_id.clone();
80 req.extensions_mut().insert(login_id.clone());
81
82 ctx.token = Some(token.clone());
83 ctx.token_info = Some(Arc::new(token_info));
84 ctx.login_id = Some(login_id);
85 }
86 }
87 }
88
89 SaTokenContext::set_current(ctx);
90 let result = service.call(req).await;
91 SaTokenContext::clear();
92 result
93 })
94 }
95}
96
97fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
98 let adapter = ActixRequestAdapter::new(req.request());
99 let token_name = &state.manager.config.token_name;
100
101 if let Some(token) = adapter.get_header(token_name) {
102 return Some(extract_bearer_token(&token));
103 }
104
105 if let Some(token) = adapter.get_cookie(token_name) {
106 return Some(token);
107 }
108
109 if let Some(query) = req.query_string().split('&').find_map(|pair| {
110 let mut parts = pair.split('=');
111 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
112 if key == token_name {
113 return urlencoding::decode(value).ok().map(|s| s.to_string());
114 }
115 }
116 None
117 }) {
118 return Some(query);
119 }
120
121 None
122}
123
/// Returns the credential carried by a header value, removing an optional `Bearer ` prefix.
///
/// The scheme name is matched ASCII-case-insensitively, because HTTP authentication scheme
/// names are case-insensitive (RFC 7235 §2.1). Any additional spaces after the scheme are also
/// removed (RFC 6750 allows one or more). Values without the prefix are returned unchanged, so a
/// raw token sent directly in the header still works.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    match token.get(..SCHEME.len()) {
        // `get` returned `Some`, so `SCHEME.len()` is a char boundary and the slice below
        // cannot panic, even for non-ASCII input.
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim_start_matches(' ').to_string()
        }
        _ => token.to_string(),
    }
}