sa_token_plugin_actix_web/
middleware.rs1use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10 Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::{token::TokenValue, SaTokenContext, error::messages};
16use std::sync::Arc;
17
18use sa_token_core::router::PathAuthConfig;
20
21pub struct SaTokenMiddleware {
24 pub state: SaTokenState,
25 pub path_config: Option<PathAuthConfig>,
28}
29
30impl SaTokenMiddleware {
31 pub fn new(state: SaTokenState) -> Self {
34 Self { state, path_config: None }
35 }
36
37 pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
40 Self { state, path_config: Some(config) }
41 }
42}
43
44impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
45where
46 S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
47 S::Future: 'static,
48 B: 'static,
49{
50 type Response = ServiceResponse<B>;
51 type Error = Error;
52 type InitError = ();
53 type Transform = SaTokenMiddlewareService<S>;
54 type Future = Ready<Result<Self::Transform, Self::InitError>>;
55
56 fn new_transform(&self, service: S) -> Self::Future {
57 ready(Ok(SaTokenMiddlewareService {
58 service: Rc::new(service),
59 state: self.state.clone(),
60 path_config: self.path_config.clone(),
61 }))
62 }
63}
64
65pub struct SaTokenMiddlewareService<S> {
68 service: Rc<S>,
69 state: SaTokenState,
70 path_config: Option<PathAuthConfig>,
73}
74
75impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
76where
77 S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
78 S::Future: 'static,
79 B: 'static,
80{
81 type Response = ServiceResponse<B>;
82 type Error = Error;
83 type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
84
85 forward_ready!(service);
86
87 fn call(&self, req: ServiceRequest) -> Self::Future {
88 let service = Rc::clone(&self.service);
89 let state = self.state.clone();
90 let path_config = self.path_config.clone();
91
92 Box::pin(async move {
93 if let Some(config) = path_config {
94 let path = req.path();
95 let token_str = extract_token_from_request(&req, &state);
96 let result = sa_token_core::router::process_auth(path, token_str, &config, &state.manager).await;
97
98 if result.should_reject() {
99 return Err(ErrorUnauthorized(serde_json::json!({"code": 401, "message": messages::AUTH_ERROR}).to_string()));
100 }
101
102 if let Some(token) = &result.token {
103 req.extensions_mut().insert(token.clone());
104 }
105 if let Some(login_id) = result.login_id() {
106 req.extensions_mut().insert(login_id.to_string());
107 }
108
109 let ctx = sa_token_core::router::create_context(&result);
110 SaTokenContext::set_current(ctx);
111 let response = service.call(req).await;
112 SaTokenContext::clear();
113 return response;
114 }
115
116 let mut ctx = SaTokenContext::new();
117 if let Some(token_str) = extract_token_from_request(&req, &state) {
118 let token = TokenValue::new(token_str);
119 if state.manager.is_valid(&token).await {
120 req.extensions_mut().insert(token.clone());
121 if let Ok(token_info) = state.manager.get_token_info(&token).await {
122 let login_id = token_info.login_id.clone();
123 req.extensions_mut().insert(login_id.clone());
124 ctx.token = Some(token.clone());
125 ctx.token_info = Some(Arc::new(token_info));
126 ctx.login_id = Some(login_id);
127 }
128 }
129 }
130
131 SaTokenContext::set_current(ctx);
132 let result = service.call(req).await;
133 SaTokenContext::clear();
134 result
135 })
136 }
137}
138
139pub struct SaCheckLoginMiddleware {
141 pub state: SaTokenState,
142}
143
144impl SaCheckLoginMiddleware {
145 pub fn new(state: SaTokenState) -> Self {
146 Self { state }
147 }
148}
149
150impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
151where
152 S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
153 S::Future: 'static,
154 B: 'static,
155{
156 type Response = ServiceResponse<B>;
157 type Error = Error;
158 type InitError = ();
159 type Transform = SaCheckLoginMiddlewareService<S>;
160 type Future = Ready<Result<Self::Transform, Self::InitError>>;
161
162 fn new_transform(&self, service: S) -> Self::Future {
163 ready(Ok(SaCheckLoginMiddlewareService {
164 service: Rc::new(service),
165 state: self.state.clone(),
166 }))
167 }
168}
169
170pub struct SaCheckLoginMiddlewareService<S> {
171 service: Rc<S>,
172 state: SaTokenState,
173}
174
175impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
176where
177 S: Service<ServiceRequest, Response=ServiceResponse<B>, Error=Error> + 'static,
178 S::Future: 'static,
179 B: 'static,
180{
181 type Response = ServiceResponse<B>;
182 type Error = Error;
183 type Future = Pin<Box<dyn Future<Output=Result<Self::Response, Self::Error>>>>;
184
185 forward_ready!(service);
186
187 fn call(&self, req: ServiceRequest) -> Self::Future {
188 let service = Rc::clone(&self.service);
189 let state = self.state.clone();
190
191 Box::pin(async move {
192 let mut ctx = SaTokenContext::new();
193 if let Some(token_str) = extract_token_from_request(&req, &state) {
195 tracing::debug!("Sa-Token(login-check): extracted token from request: {}", token_str);
196 let token = TokenValue::new(token_str);
197
198 if state.manager.is_valid(&token).await {
200 req.extensions_mut().insert(token.clone());
202
203 if let Ok(token_info) = state.manager.get_token_info(&token).await {
204 let login_id = token_info.login_id.clone();
205 req.extensions_mut().insert(login_id.clone());
206 ctx.token = Some(token.clone());
207 ctx.token_info = Some(Arc::new(token_info));
208 ctx.login_id = Some(login_id);
209
210 SaTokenContext::set_current(ctx);
212 let result = service.call(req).await;
213 SaTokenContext::clear();
214 return result;
215 }
216 }
217 }
218
219 Err(ErrorUnauthorized(serde_json::json!({
221 "code": 401,
222 "message": messages::AUTH_ERROR
223 }).to_string()))
224 })
225 }
226}
227
228pub fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
230 let adapter = ActixRequestAdapter::new(req.request());
231 let token_name = &state.manager.config.token_name;
232
233 tracing::debug!("Sa-Token: 尝试从请求提取 token,token_name: {}", token_name);
234
235 if let Some(token) = adapter.get_header(token_name) {
237 tracing::debug!("Sa-Token: 从 Header[{}] 获取到 token", token_name);
238 return Some(extract_bearer_token(&token));
239 }
240
241 if token_name != "Authorization" {
243 if let Some(token) = adapter.get_header("Authorization") {
244 tracing::debug!("Sa-Token: 从 Header[Authorization] 获取到 token");
245 return Some(extract_bearer_token(&token));
246 }
247 }
248
249 if let Some(token) = adapter.get_cookie(token_name) {
251 tracing::debug!("Sa-Token: 从 Cookie[{}] 获取到 token", token_name);
252 return Some(token);
253 }
254
255 if let Some(query) = req.query_string().split('&').find_map(|pair| {
257 let mut parts = pair.split('=');
258 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
259 if key == token_name {
260 return urlencoding::decode(value).ok().map(|s| s.to_string());
261 }
262 }
263 None
264 }) {
265 tracing::debug!("Sa-Token: 从 Query[{}] 获取到 token", token_name);
266 return Some(query);
267 }
268
269 tracing::debug!("Sa-Token: 所有位置都未找到 token");
270 None
271}
272
/// Strips a `Bearer` authentication-scheme prefix from a header value and
/// returns the bare token.
///
/// The scheme name is matched ASCII-case-insensitively, because HTTP
/// authentication scheme names are case-insensitive (RFC 7235 §2.1). Any
/// additional spaces between the scheme and the token are dropped (RFC 6750
/// allows `1*SP`). Values without the prefix are returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const PREFIX: &str = "Bearer ";

    // `get` returns `None` instead of panicking when the value is shorter than
    // the prefix or byte 7 falls inside a multi-byte character.
    match token.get(..PREFIX.len()) {
        Some(scheme) if scheme.eq_ignore_ascii_case(PREFIX) => {
            token[PREFIX.len()..].trim_start().to_string()
        }
        _ => token.to_string(),
    }
}