sa_token_plugin_rocket/
layer.rs

1use rocket::{Request, Data, Response};
2use rocket::fairing::{Fairing, Info, Kind};
3use sa_token_core::{token::TokenValue, SaTokenContext};
4use crate::SaTokenState;
5use std::sync::Arc;
6
7pub struct SaTokenLayer {
8    state: SaTokenState,
9}
10
11impl SaTokenLayer {
12    pub fn new(state: SaTokenState) -> Self {
13        Self { state }
14    }
15}
16
17#[rocket::async_trait]
18impl Fairing for SaTokenLayer {
19    fn info(&self) -> Info {
20        Info {
21            name: "Sa-Token Authentication",
22            kind: Kind::Request | Kind::Response,
23        }
24    }
25    
26    async fn on_request(&self, req: &mut Request<'_>, _: &mut Data<'_>) {
27        let mut ctx = SaTokenContext::new();
28        
29        if let Some(token_str) = extract_token_from_request(req, &self.state) {
30            tracing::debug!("Sa-Token: extracted token from request: {}", token_str);
31            let token = TokenValue::new(token_str);
32            
33            if self.state.manager.is_valid(&token).await {
34                req.local_cache(|| Some(token.clone()));
35                
36                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
37                    let login_id = token_info.login_id.clone();
38                    req.local_cache(|| Some(login_id.clone()));
39                    
40                    ctx.token = Some(token.clone());
41                    ctx.token_info = Some(Arc::new(token_info));
42                    ctx.login_id = Some(login_id);
43                }
44            }
45        }
46        
47        SaTokenContext::set_current(ctx);
48    }
49    
50    async fn on_response<'r>(&self, _req: &'r Request<'_>, _res: &mut Response<'r>) {
51        SaTokenContext::clear();
52    }
53}
54
55fn extract_token_from_request(req: &Request, state: &SaTokenState) -> Option<String> {
56    use sa_token_adapter::utils::extract_bearer_token as utils_extract_bearer_token;
57    let token_name = &state.manager.config.token_name;
58    
59    // 1. 优先从 Header 中获取
60    if let Some(header_value) = req.headers().get_one(token_name) {
61        if let Some(token) = utils_extract_bearer_token(header_value) {
62            return Some(token);
63        }
64    }
65    
66    // 检查 Authorization header
67    if let Some(auth_header) = req.headers().get_one("authorization") {
68        if let Some(token) = utils_extract_bearer_token(auth_header) {
69            return Some(token);
70        }
71    }
72    
73    // 2. 从 Cookie 中获取
74    if let Some(cookie_value) = req.cookies().get(token_name) {
75        return Some(cookie_value.value().to_string());
76    }
77    
78    // 3. 从 Query 参数中获取
79    if let Some(query) = req.uri().query() {
80        let params = parse_query_string(query.as_str());
81        if let Some(token) = params.get(token_name) {
82            return Some(token.clone());
83        }
84    }
85    
86    None
87}
88
89// 解析查询字符串
90fn parse_query_string(query: &str) -> std::collections::HashMap<String, String> {
91    let mut params = std::collections::HashMap::new();
92    for pair in query.split('&') {
93        if let Some((key, value)) = pair.split_once('=') {
94            if let Ok(decoded_value) = urlencoding::decode(value) {
95                params.insert(key.to_string(), decoded_value.to_string());
96            }
97        }
98    }
99    params
100}