sa_token_plugin_actix_web/
middleware.rs1use std::future::{ready, Ready, Future};
6use std::pin::Pin;
7use std::rc::Rc;
8use actix_web::{
9 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10 Error, HttpMessage, error::ErrorUnauthorized,
11};
12use crate::SaTokenState;
13use crate::adapter::ActixRequestAdapter;
14use sa_token_adapter::context::SaRequest;
15use sa_token_core::token::TokenValue;
16
17pub struct SaTokenMiddleware {
19 pub state: SaTokenState,
20}
21
22impl SaTokenMiddleware {
23 pub fn new(state: SaTokenState) -> Self {
24 Self { state }
25 }
26}
27
28impl<S, B> Transform<S, ServiceRequest> for SaTokenMiddleware
29where
30 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
31 S::Future: 'static,
32 B: 'static,
33{
34 type Response = ServiceResponse<B>;
35 type Error = Error;
36 type InitError = ();
37 type Transform = SaTokenMiddlewareService<S>;
38 type Future = Ready<Result<Self::Transform, Self::InitError>>;
39
40 fn new_transform(&self, service: S) -> Self::Future {
41 ready(Ok(SaTokenMiddlewareService {
42 service: Rc::new(service),
43 state: self.state.clone(),
44 }))
45 }
46}
47
48pub struct SaTokenMiddlewareService<S> {
49 service: Rc<S>,
50 state: SaTokenState,
51}
52
53impl<S, B> Service<ServiceRequest> for SaTokenMiddlewareService<S>
54where
55 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
56 S::Future: 'static,
57 B: 'static,
58{
59 type Response = ServiceResponse<B>;
60 type Error = Error;
61 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
62
63 forward_ready!(service);
64
65 fn call(&self, req: ServiceRequest) -> Self::Future {
66 let service = Rc::clone(&self.service);
67 let state = self.state.clone();
68
69 Box::pin(async move {
70 if let Some(token_str) = extract_token_from_request(&req, &state) {
72 let token = TokenValue::new(token_str);
73
74 if state.manager.is_valid(&token).await {
76 req.extensions_mut().insert(token.clone());
78
79 if let Ok(token_info) = state.manager.get_token_info(&token).await {
81 req.extensions_mut().insert(token_info.login_id.clone());
82 }
83 }
84 }
85
86 service.call(req).await
87 })
88 }
89}
90
91pub struct SaCheckLoginMiddleware {
93 pub state: SaTokenState,
94}
95
96impl SaCheckLoginMiddleware {
97 pub fn new(state: SaTokenState) -> Self {
98 Self { state }
99 }
100}
101
102impl<S, B> Transform<S, ServiceRequest> for SaCheckLoginMiddleware
103where
104 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
105 S::Future: 'static,
106 B: 'static,
107{
108 type Response = ServiceResponse<B>;
109 type Error = Error;
110 type InitError = ();
111 type Transform = SaCheckLoginMiddlewareService<S>;
112 type Future = Ready<Result<Self::Transform, Self::InitError>>;
113
114 fn new_transform(&self, service: S) -> Self::Future {
115 ready(Ok(SaCheckLoginMiddlewareService {
116 service: Rc::new(service),
117 state: self.state.clone(),
118 }))
119 }
120}
121
122pub struct SaCheckLoginMiddlewareService<S> {
123 service: Rc<S>,
124 state: SaTokenState,
125}
126
127impl<S, B> Service<ServiceRequest> for SaCheckLoginMiddlewareService<S>
128where
129 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
130 S::Future: 'static,
131 B: 'static,
132{
133 type Response = ServiceResponse<B>;
134 type Error = Error;
135 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
136
137 forward_ready!(service);
138
139 fn call(&self, req: ServiceRequest) -> Self::Future {
140 let service = Rc::clone(&self.service);
141 let state = self.state.clone();
142
143 Box::pin(async move {
144 if let Some(token_str) = extract_token_from_request(&req, &state) {
146 let token = TokenValue::new(token_str);
147
148 if state.manager.is_valid(&token).await {
150 req.extensions_mut().insert(token.clone());
152
153 if let Ok(token_info) = state.manager.get_token_info(&token).await {
154 req.extensions_mut().insert(token_info.login_id.clone());
155 }
156
157 return service.call(req).await;
158 }
159 }
160
161 Err(ErrorUnauthorized(serde_json::json!({
163 "code": 401,
164 "message": "未登录"
165 }).to_string()))
166 })
167 }
168}
169
170fn extract_token_from_request(req: &ServiceRequest, state: &SaTokenState) -> Option<String> {
172 let adapter = ActixRequestAdapter::new(req.request());
173 let token_name = &state.manager.config.token_name;
174
175 if let Some(token) = adapter.get_header(token_name) {
177 return Some(extract_bearer_token(&token));
178 }
179
180 if let Some(token) = adapter.get_cookie(token_name) {
182 return Some(token);
183 }
184
185 if let Some(query) = req.query_string().split('&').find_map(|pair| {
187 let mut parts = pair.split('=');
188 if let (Some(key), Some(value)) = (parts.next(), parts.next()) {
189 if key == token_name {
190 return urlencoding::decode(value).ok().map(|s| s.to_string());
191 }
192 }
193 None
194 }) {
195 return Some(query);
196 }
197
198 None
199}
200
/// Strips an HTTP `Bearer` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively (auth-schemes are case-insensitive per RFC 7235),
/// and whitespace around the credentials that follow it is trimmed. Values that do not start
/// with the scheme are returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `get` yields `None` instead of panicking when the value is shorter than the scheme
    // or the cut would land inside a multi-byte character.
    match token.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => {
            token[SCHEME.len()..].trim().to_string()
        }
        _ => token.to_string(),
    }
}