sa_token_plugin_actix_web/
middleware.rs1use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10 Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::{token::TokenValue, SaTokenContext, error::messages};
16use std::sync::Arc;
17
18pub struct SaTokenMiddleware {
20 pub state: SaTokenState,
21}
22
23impl SaTokenMiddleware {
24 pub fn new(state: SaTokenState) -> Self {
25 Self { state }
26 }
27}
28
29impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
30where
31 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
32 S::Future: 'static,
33 B: 'static,
34{
35 type Response = ServiceResponse<B>;
36 type Error = Error;
37 type InitError = ();
38 type Transform = SaTokenMiddlewareService<S>;
39 type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41 fn new_transform(&self, service: S) -> Self::Future {
42 ready(Ok(SaTokenMiddlewareService {
43 service: Rc::new(service),
44 state: self.state.clone(),
45 }))
46 }
47}
48
49pub struct SaTokenMiddlewareService<S> {
50 service: Rc<S>,
51 state: SaTokenState,
52}
53
54impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
55where
56 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
57 S::Future: 'static,
58 B: 'static,
59{
60 type Response = ServiceResponse<B>;
61 type Error = Error;
62 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
63
64 forward_ready!(service);
65
66 fn call(&self, req: ServiceRequest) -> Self::Future {
67 let service = Rc::clone(&self.service);
68 let state = self.state.clone();
69
70 Box::pin(async move {
71 let mut ctx = SaTokenContext::new();
72 if let Some(token_str) = extract_token_from_request(&req, &state) {
74 tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
75 let token = TokenValue::new(token_str);
76
77 if state.manager.is_valid(&token).await {
79 req.extensions_mut().insert(token.clone());
81
82 if let Ok(token_info) = state.manager.get_token_info(&token).await {
84 let login_id = token_info.login_id.clone();
85 req.extensions_mut().insert(login_id.clone());
86 ctx.token = Some(token.clone());
87 ctx.token_info = Some(Arc::new(token_info));
88 ctx.login_id = Some(login_id);
89 }
90 }
91 }
92
93 SaTokenContext::set_current(ctx);
94 let result = service.call(req).await;
95 SaTokenContext::clear();
96 result
97 })
98 }
99}
100
101pub struct SaCheckLoginMiddleware {
103 pub state: SaTokenState,
104}
105
106impl SaCheckLoginMiddleware {
107 pub fn new(state: SaTokenState) -> Self {
108 Self { state }
109 }
110}
111
112impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
113where
114 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
115 S::Future: 'static,
116 B: 'static,
117{
118 type Response = ServiceResponse<B>;
119 type Error = Error;
120 type InitError = ();
121 type Transform = SaCheckLoginMiddlewareService<S>;
122 type Future = Ready<Result<Self::Transform, Self::InitError>>;
123
124 fn new_transform(&self, service: S) -> Self::Future {
125 ready(Ok(SaCheckLoginMiddlewareService {
126 service: Rc::new(service),
127 state: self.state.clone(),
128 }))
129 }
130}
131
132pub struct SaCheckLoginMiddlewareService<S> {
133 service: Rc<S>,
134 state: SaTokenState,
135}
136
137impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
138where
139 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
140 S::Future: 'static,
141 B: 'static,
142{
143 type Response = ServiceResponse<B>;
144 type Error = Error;
145 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
146
147 forward_ready!(service);
148
149 fn call(&self, req: ServiceRequest) -> Self::Future {
150 let service = Rc::clone(&self.service);
151 let state = self.state.clone();
152
153 Box::pin(async move {
154 let mut ctx = SaTokenContext::new();
155 if let Some(token_str) = extract_token_from_request(&req, &state) {
157 tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
158 let token = TokenValue::new(token_str);
159
160 if state.manager.is_valid(&token).await {
162 req.extensions_mut().insert(token.clone());
164
165 if let Ok(token_info) = state.manager.get_token_info(&token).await {
166 let login_id = token_info.login_id.clone();
167 req.extensions_mut().insert(login_id.clone());
168 ctx.token = Some(token.clone());
169 ctx.token_info = Some(Arc::new(token_info));
170 ctx.login_id = Some(login_id);
171
172 SaTokenContext::set_current(ctx);
174 let result = service.call(req).await;
175 SaTokenContext::clear();
176 return result;
177 }
178 }
179 }
180
181 Err(ErrorUnauthorized(serde_json::json!({
183 "code": 401,
184 "message": messages::AUTH_ERROR
185 }).to_string()))
186 })
187 }
188}
189
190fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
192 let adapter = ActixRequestAdapter::new(req.request());
193 let token_name = &state.manager.config.token_name;
194
195 if let Some(token) = adapter.get_header(token_name) {
197 return Some(extract_bearer_token(&token));
198 }
199
200 if let Some(token) = adapter.get_cookie(token_name) {
202 return Some(token);
203 }
204
205 if let Some(query) = req.query_string().split('&').find_map(|pair| {
207 let mut parts = pair.split('=');
208 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
209 if key == token_name {
210 return urlencoding::decode(value).ok().map(|s| s.to_string());
211 }
212 }
213 None
214 }) {
215 return Some(query);
216 }
217
218 None
219}
220
/// Strips a leading `Bearer` authentication scheme from a header value.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1), and any
/// whitespace between the scheme and the token is removed (RFC 6750 allows
/// one or more spaces). Values that do not start with `Bearer ` are returned
/// unchanged.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `get` (rather than slicing) avoids panicking when byte 7 falls inside a
    // multi-byte UTF-8 character.
    match (token.get(..SCHEME.len()), token.get(SCHEME.len()..)) {
        (Some(prefix), Some(rest)) if prefix.eq_ignore_ascii_case(SCHEME) => {
            rest.trim_start().to_string()
        }
        _ => token.to_string(),
    }
}