sa_token_plugin_rocket/
middleware.rs

1// Author: 金书记
2//
3//! Rocket Fairing (中间件)
4
5use rocket::{Request, Data, Response};
6use rocket::fairing::{Fairing, Info, Kind};
7use rocket::http::Status;
8use crate::SaTokenState;
9use sa_token_core::token::TokenValue;
10
11/// sa-token Fairing - 提取并验证 token
12pub struct SaTokenFairing {
13    state: SaTokenState,
14}
15
16impl SaTokenFairing {
17    pub fn new(state: SaTokenState) -> Self {
18        Self { state }
19    }
20}
21
22#[rocket::async_trait]
23impl Fairing for SaTokenFairing {
24    fn info(&self) -> Info {
25        Info {
26            name: "SaToken Authentication",
27            kind: Kind::Request,
28        }
29    }
30    
31    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
32        // 提取 token
33        let token_str = {
34            let token_name = &self.state.manager.config.token_name;
35            
36            // 1. 从 Header 获取
37            if let Some(header_val) = request.headers().get_one(token_name) {
38                Some(extract_bearer_token(header_val))
39            }
40            // 2. 从 Cookie 获取
41            else if let Some(cookie) = request.cookies().get(token_name) {
42                Some(cookie.value().to_string())
43            }
44            // 3. 从 Query 参数获取
45            else if let Some(query) = request.uri().query() {
46                parse_query_param(query.as_str(), token_name)
47            } else {
48                None
49            }
50        };
51        
52        if let Some(token_str) = token_str {
53            let token = TokenValue::new(token_str);
54            
55            // 验证 token
56            if self.state.manager.is_valid(&token).await {
57                // 存储 token 到本地缓存
58                request.local_cache(|| Some(token.clone()));
59                
60                // 获取并存储 login_id
61                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
62                    request.local_cache(|| Some(token_info.login_id.clone()));
63                }
64            }
65        }
66    }
67}
68
69/// sa-token 登录检查 Fairing - 强制要求登录
70pub struct SaCheckLoginFairing {
71    state: SaTokenState,
72}
73
74impl SaCheckLoginFairing {
75    pub fn new(state: SaTokenState) -> Self {
76        Self { state }
77    }
78}
79
80#[rocket::async_trait]
81impl Fairing for SaCheckLoginFairing {
82    fn info(&self) -> Info {
83        Info {
84            name: "SaToken Check Login",
85            kind: Kind::Request | Kind::Response,
86        }
87    }
88    
89    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
90        // 提取 token
91        let token_str = {
92            let token_name = &self.state.manager.config.token_name;
93            
94            // 1. 从 Header 获取
95            if let Some(header_val) = request.headers().get_one(token_name) {
96                Some(extract_bearer_token(header_val))
97            }
98            // 2. 从 Cookie 获取
99            else if let Some(cookie) = request.cookies().get(token_name) {
100                Some(cookie.value().to_string())
101            }
102            // 3. 从 Query 参数获取
103            else if let Some(query) = request.uri().query() {
104                parse_query_param(query.as_str(), token_name)
105            } else {
106                None
107            }
108        };
109        
110        if let Some(token_str) = token_str {
111            let token = TokenValue::new(token_str);
112            
113            // 验证 token
114            if self.state.manager.is_valid(&token).await {
115                // 存储 token
116                request.local_cache(|| Some(token.clone()));
117                
118                // 获取并存储 login_id
119                if let Ok(token_info) = self.state.manager.get_token_info(&token).await {
120                    request.local_cache(|| Some(token_info.login_id.clone()));
121                }
122                return;
123            }
124        }
125        
126        // 未登录,标记为未授权
127        request.local_cache(|| Some("unauthorized"));
128    }
129    
130    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
131        // 检查是否标记为未授权
132        if let Some(_) = request.local_cache(|| None::<&str>) {
133            if *request.local_cache(|| None::<&str>) == Some("unauthorized") {
134                response.set_status(Status::Unauthorized);
135                response.set_sized_body(None, std::io::Cursor::new(
136                    serde_json::json!({
137                        "code": 401,
138                        "message": "未登录"
139                    }).to_string()
140                ));
141            }
142        }
143    }
144}
145
/// Extracts the credential from a header value that may carry a `Bearer` scheme.
///
/// If `token` starts with `Bearer ` (scheme matched ASCII-case-insensitively,
/// because HTTP authentication scheme names are case-insensitive), the prefix is
/// removed and the remaining credential is trimmed of surrounding whitespace.
/// Any other value is returned unchanged.
fn extract_bearer_token(token: &str) -> String {
    const SCHEME: &str = "Bearer ";

    // `get` returns `None` instead of panicking when the split point is not a
    // char boundary (e.g. multi-byte input), so no manual slicing is needed.
    match (token.get(..SCHEME.len()), token.get(SCHEME.len()..)) {
        (Some(prefix), Some(credential)) if prefix.eq_ignore_ascii_case(SCHEME) => {
            credential.trim().to_string()
        }
        _ => token.to_string(),
    }
}
154
155/// 解析查询参数
156fn parse_query_param(query: &str, name: &str) -> Option<String> {
157    for pair in query.split('&') {
158        if let Some((key, value)) = pair.split_once('=') {
159            if key == name {
160                return urlencoding::decode(value).ok().map(|s| s.to_string());
161            }
162        }
163    }
164    None
165}