sa_token_plugin_axum/
layer.rs1use std::task::{Context, Poll};
6use tower::{Layer, Service};
7use http::{Request, Response};
8use sa_token_adapter::context::SaRequest;
9use crate::{SaTokenState, adapter::AxumRequestAdapter};
10use sa_token_core::{SaTokenContext, router::PathAuthConfig};
11use std::sync::Arc;
12
13#[derive(Clone)]
16pub struct SaTokenLayer {
17 state: SaTokenState,
18 path_config: Option<PathAuthConfig>,
21}
22
23impl SaTokenLayer {
24 pub fn new(state: SaTokenState) -> Self {
25 Self { state, path_config: None }
26 }
27
28 pub fn with_path_auth(state: SaTokenState, config: PathAuthConfig) -> Self {
29 Self { state, path_config: Some(config) }
30 }
31}
32
33impl<S> Layer<S> for SaTokenLayer {
34 type Service = SaTokenMiddleware<S>;
35
36 fn layer(&self, inner: S) -> Self::Service {
37 SaTokenMiddleware {
38 inner,
39 state: self.state.clone(),
40 path_config: self.path_config.clone(),
41 }
42 }
43}
44
45#[derive(Clone)]
46pub struct SaTokenMiddleware<S> {
47 pub(crate) inner: S,
48 pub(crate) state: SaTokenState,
49 pub(crate) path_config: Option<PathAuthConfig>,
52}
53
54impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for SaTokenMiddleware<S>
55where
56 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
57 S::Future: Send + 'static,
58 ReqBody: Send + 'static,
59 ResBody: Default + Send + 'static,
60{
61 type Response = S::Response;
62 type Error = S::Error;
63 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
64
65 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66 self.inner.poll_ready(cx)
67 }
68
69 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
70 let mut inner = self.inner.clone();
71 let state = self.state.clone();
72 let path_config = self.path_config.clone();
73
74 Box::pin(async move {
75 if let Some(config) = path_config {
76 let path = request.uri().path();
77 let token_str = extract_token_from_request(&request, &state);
78 let result = sa_token_core::router::process_auth(path, token_str, &config, &state.manager).await;
79
80 if result.should_reject() {
81 let mut response = Response::new(ResBody::default());
82 *response.status_mut() = http::StatusCode::UNAUTHORIZED;
83 return Ok(response);
84 }
85
86 if let Some(token) = &result.token {
87 request.extensions_mut().insert(token.clone());
88 }
89 if let Some(login_id) = result.login_id() {
90 request.extensions_mut().insert(login_id.to_string());
91 }
92
93 let ctx = sa_token_core::router::create_context(&result);
94 SaTokenContext::set_current(ctx);
95 let response = inner.call(request).await;
96 SaTokenContext::clear();
97 return response;
98 }
99
100 let mut ctx = SaTokenContext::new();
103 if let Some(token_str) = extract_token_from_request(&request, &state) {
104 let token = sa_token_core::token::TokenValue::new(token_str);
105 if state.manager.is_valid(&token).await {
106 request.extensions_mut().insert(token.clone());
107 if let Ok(token_info) = state.manager.get_token_info(&token).await {
108 let login_id = token_info.login_id.clone();
109 request.extensions_mut().insert(login_id.clone());
110 ctx.token = Some(token.clone());
111 ctx.token_info = Some(Arc::new(token_info));
112 ctx.login_id = Some(login_id);
113 }
114 }
115 }
116
117 SaTokenContext::set_current(ctx);
118 let response = inner.call(request).await;
119 SaTokenContext::clear();
120 response
121 })
122 }
123}
124
125pub fn extract_token_from_request<T>(request: &Request<T>, state: &SaTokenState) -> Option<String> {
141 let adapter = AxumRequestAdapter::new(request);
142 let token_name = &state.manager.config.token_name;
144
145 if let Some(token) = adapter.get_header(token_name) {
147 return Some(extract_bearer_token(&token));
148 }
149
150 if token_name != "Authorization" {
152 if let Some(token) = adapter.get_header("Authorization") {
153 return Some(extract_bearer_token(&token));
154 }
155 }
156
157 if let Some(token) = adapter.get_cookie(token_name) {
159 return Some(token);
160 }
161
162 if let Some(query) = request.uri().query() {
164 if let Some(token) = parse_query_param(query, token_name) {
165 return Some(token);
166 }
167 }
168
169 None
170}
171
/// Strips an optional `Bearer ` scheme from a header value and trims the rest.
///
/// The scheme is matched case-insensitively (`Bearer`, `bearer`, `BEARER`,
/// ...), because HTTP authentication scheme names are case-insensitive.
/// Values without the scheme are returned trimmed but otherwise unchanged.
fn extract_bearer_token(header_value: &str) -> String {
    const BEARER_PREFIX: &str = "Bearer ";

    // `str::get` yields `None` when the value is shorter than the prefix or
    // the cut would fall inside a multi-byte character, so the slice below
    // can never panic.
    let token = match header_value.get(..BEARER_PREFIX.len()) {
        Some(scheme) if scheme.eq_ignore_ascii_case(BEARER_PREFIX) => {
            &header_value[BEARER_PREFIX.len()..]
        }
        _ => header_value,
    };

    token.trim().to_string()
}
188
189fn parse_query_param(query: &str, param_name: &str) -> Option<String> {
190 for pair in query.split('&') {
191 let parts: Vec<&str> = pair.splitn(2, '=').collect();
192 if parts.len() == 2 && parts[0] == param_name {
193 return urlencoding::decode(parts[1])
194 .ok()
195 .map(|s| s.into_owned());
196 }
197 }
198 None
199}