sa_token_plugin_rocket/
layer.rs1use rocket::{Request, Data, Response};
2use rocket::fairing::{Fairing, Info, Kind};
3use sa_token_core::{token::TokenValue, SaTokenContext};
4use crate::SaTokenState;
5use std::sync::Arc;
6
7pub struct SaTokenLayer {
8 state: SaTokenState,
9}
10
11impl SaTokenLayer {
12 pub fn new(state: SaTokenState) -> Self {
13 Self { state }
14 }
15}
16
17#[rocket::async_trait]
18impl Fairing for SaTokenLayer {
19 fn info(&self) -> Info {
20 Info {
21 name: "Sa-Token Authentication",
22 kind: Kind::Request | Kind::Response,
23 }
24 }
25
26 async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
27 let mut ctx = SaTokenContext::new();
28
29 if let Some(token_str) = extract_token_from_request(req, &self.state) {
30 tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
31 let token = TokenValue::new(token_str);
32
33 if self.state.manager.is_valid(&token).await {
34 req.local_cache(|| Some(token.clone()));
35
36 if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
37 let login_id = token_info.login_id.clone();
38 req.local_cache(|| Some(login_id.clone()));
39
40 ctx.token = Some(token.clone());
41 ctx.token_info = Some(Arc::new(token_info));
42 ctx.login_id = Some(login_id);
43 }
44 }
45 }
46
47 SaTokenContext::set_current(ctx);
48 }
49
50 async fn on_response<'r>(&self, _req: &'r Request<'_>, _res: &mut Response<'r>) {
51 SaTokenContext::clear();
52 }
53}
54
55fn extract_token_from_request(req: &Request, state: &SaTokenState) -> Option<String> {
56 use sa_token_adapter::utils::extract_bearer_token as utils_extract_bearer_token;
57 let token_name = &state.manager.config.token_name;
58
59 if let Some(header_value) = req.headers().get_one(token_name) {
61 if let Some(token) = utils_extract_bearer_token(header_value) {
62 return Some(token);
63 }
64 }
65
66 if let Some(auth_header) = req.headers().get_one("authorization") {
68 if let Some(token) = utils_extract_bearer_token(auth_header) {
69 return Some(token);
70 }
71 }
72
73 if let Some(cookie_value) = req.cookies().get(token_name) {
75 return Some(cookie_value.value().to_string());
76 }
77
78 if let Some(query) = req.uri().query() {
80 let params = parse_query_string(query.as_str());
81 if let Some(token) = params.get(token_name) {
82 return Some(token.clone());
83 }
84 }
85
86 None
87}
88
89fn parse_query_string(query: &str) -> std::collections::HashMap<String, String> {
91 let mut params = std::collections::HashMap::new();
92 for pair in query.split('&') {
93 if let Some((key, value)) = pair.split_once('=') {
94 if let Ok(decoded_value) = urlencoding::decode(value) {
95 params.insert(key.to_string(), decoded_value.to_string());
96 }
97 }
98 }
99 params
100}