sa_token_plugin_ntex/
layer.rs1use ntex::service::{Service, ServiceCtx, Middleware};
2use ntex::web::{Error, ErrorRenderer, WebRequest, WebResponse};
3use crate::state::SaTokenState;
4use sa_token_core::{token::TokenValue, SaTokenContext};
5use std::sync::Arc;
6
7#[derive(Clone)]
8pub struct SaTokenLayer {
9 state: SaTokenState,
10}
11
12impl SaTokenLayer {
13 pub fn new(state: SaTokenState) -> Self {
14 Self { state }
15 }
16}
17
18impl<S> Middleware<S> for SaTokenLayer {
19 type Service = SaTokenMiddleware<S>;
20
21 fn create(&self, service: S) -> Self::Service {
22 SaTokenMiddleware {
23 service,
24 state: self.state.clone(),
25 }
26 }
27}
28
29pub struct SaTokenMiddleware<S> {
30 service: S,
31 state: SaTokenState,
32}
33
34impl<S, Err> Service<WebRequest<Err>> for SaTokenMiddleware<S>
35where
36 S: Service<WebRequest<Err>, Response = WebResponse, Error = Error>,
37 Err: ErrorRenderer,
38{
39 type Response = WebResponse;
40 type Error = Error;
41
42 async fn call(&self, req: WebRequest<Err>, ctx: ServiceCtx<'_, Self>) -> Result<Self::Response, Self::Error> {
43 let mut sa_ctx = SaTokenContext::new();
44
45 if let Some(token_str) = extract_token_from_request(&req, &self.state) {
46 tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
47 let token = TokenValue::new(token_str);
48
49 if self.state.manager.is_valid(&token).await {
50 req.extensions_mut().insert(token.clone());
51
52 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
53 let login_id = token_info.login_id.clone();
54 req.extensions_mut().insert(login_id.clone());
55
56 sa_ctx.token = Some(token.clone());
57 sa_ctx.token_info = Some(Arc::new(token_info));
58 sa_ctx.login_id = Some(login_id);
59 }
60 }
61 }
62
63 SaTokenContext::set_current(sa_ctx);
64 let result = ctx.call(&self.service, req).await;
65 SaTokenContext::clear();
66 result
67 }
68}
69
70fn extract_token_from_request<Err>(req: &WebRequest<Err>, state: &SaTokenState) -> Option<String>
71where
72 Err: ErrorRenderer,
73{
74 let token_name = &state.manager.config.token_name;
75 let headers = req.headers();
76
77 if let Some(header_value) = headers.get(token_name) {
79 if let Ok(value_str) = header_value.to_str() {
80 return Some(extract_bearer_token(value_str));
81 }
82 }
83
84 if let Some(auth_header) = headers.get("Authorization") {
86 if let Ok(auth_str) = auth_header.to_str() {
87 return Some(extract_bearer_token(auth_str));
88 }
89 }
90
91 if let Some(cookie_header) = headers.get("cookie") {
93 if let Ok(cookie_str) = cookie_header.to_str() {
94 if let Some(token) = parse_cookie(cookie_str, token_name) {
95 return Some(token);
96 }
97 }
98 }
99
100 if let Some(query) = req.uri().query() {
102 if let Some(token) = parse_query_param(query, token_name) {
103 return Some(token);
104 }
105 }
106
107 None
108}
109
/// Strips an optional `Bearer ` authentication scheme from a header value.
///
/// The scheme is matched case-insensitively, because HTTP authentication
/// schemes are case-insensitive (so `bearer xyz` and `BEARER xyz` both yield
/// `xyz`). Leading whitespace before the scheme is ignored. Values without the
/// scheme are returned as-is, trimmed. The result may be empty, e.g. for
/// `"Bearer "`.
fn extract_bearer_token(header_value: &str) -> String {
    const SCHEME: &str = "Bearer ";

    let value = header_value.trim_start();
    // `get` (rather than slicing) returns `None` instead of panicking when the
    // value is shorter than the scheme or byte 7 is not a char boundary.
    let credentials = match value.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => &value[SCHEME.len()..],
        _ => value,
    };
    credentials.trim().to_string()
}
117
/// Finds the value of the cookie named `token_name` in a raw `Cookie` header.
///
/// The header is split on `;`. Each segment is split at its first `=`, and
/// name and value are whitespace-trimmed. Segments without `=` are ignored.
/// The first matching cookie's value is returned, even if it is empty.
fn parse_cookie(cookie_str: &str, token_name: &str) -> Option<String> {
    cookie_str.split(';').find_map(|segment| {
        let (key, raw_value) = segment.trim().split_once('=')?;
        (key.trim() == token_name).then(|| raw_value.trim().to_owned())
    })
}
130
131fn parse_query_param(query: &str, param_name: &str) -> Option<String> {
132 for pair in query.split('&') {
133 let parts: Vec<&str> = pair.splitn(2, '=').collect();
134 if parts.len() == 2 && parts[0] == param_name {
135 return urlencoding::decode(parts[1])
136 .ok()
137 .map(|s| s.into_owned());
138 }
139 }
140 None
141}