Skip to main content

sa_token_plugin_rocket_v05/
adapter.rs

1// Author: 金书记
2//
3//! Rocket request/response adapters for `SaRequest` / `SaResponse`.
4//! Rocket 请求/响应适配器,实现 `SaRequest` / `SaResponse`。
5
6use rocket::{Request, Response};
7use rocket::http::{Header, Cookie, Status, ContentType};
8use sa_token_adapter::context::{SaRequest, SaResponse, CookieOptions};
9use serde::Serialize;
10use std::collections::HashMap;
11
12/// Borrows Rocket [`Request`] for synchronous `SaRequest` use (caller must not hold across `.await` with incompatible lifetimes).
13/// 借用 Rocket [`Request`] 实现同步 `SaRequest`(注意不要在不兼容的生命周期下跨 `.await` 持有)。
14pub struct RocketRequestAdapter<'a> {
15    request: &'a Request<'a>,
16}
17
18impl<'a> RocketRequestAdapter<'a> {
19    pub fn new(request: &'a Request<'a>) -> Self {
20        Self { request }
21    }
22}
23
24impl<'a> SaRequest for RocketRequestAdapter<'a> {
25    fn get_header(&self, name: &str) -> Option<String> {
26        self.request.headers().get_one(name)
27            .map(|s| s.to_string())
28    }
29    
30    fn get_cookie(&self, name: &str) -> Option<String> {
31        self.request.cookies().get(name)
32            .map(|c| c.value().to_string())
33    }
34    
35    fn get_param(&self, name: &str) -> Option<String> {
36        // Rocket 的查询参数需要从 URI 中提取
37        if let Some(query) = self.request.uri().query() {
38            return parse_query_string(query.as_str())
39                .get(name)
40                .cloned();
41        }
42        None
43    }
44    
45    fn get_path(&self) -> String {
46        self.request.uri().path().to_string()
47    }
48    
49    fn get_method(&self) -> String {
50        self.request.method().to_string()
51    }
52    
53    fn get_client_ip(&self) -> Option<String> {
54        self.request.client_ip()
55            .map(|ip| ip.to_string())
56    }
57}
58
/// Owned snapshot of exactly the request data that `run_auth_flow` (in `rocket-core`) reads.
///
/// The snapshot contains the token carriers (named header, `Authorization`, cookie, query
/// parameter) plus path, method and client IP. Every field is an owned `String` or
/// `Option<String>`, so the snapshot holds no borrow of the Rocket [`Request`].
///
/// In a fairing, build it with `RocketCapturedRequest::capture` before the first `.await`. Then
/// carry the snapshot across await points instead of the `Request` borrow itself.
pub struct RocketCapturedRequest {
    /// Name used to look the token up in headers, cookies and the query string.
    token_name: String,
    /// Request path (`req.uri().path()` rendered as a string).
    path: String,
    /// HTTP method (`req.method()` rendered as a string).
    method: String,
    /// Client IP (`req.client_ip()` rendered as a string), if known.
    client_ip: Option<String>,
    /// Value of the header named `token_name`.
    token_name_header: Option<String>,
    /// Value of the `Authorization` header; left `None` when `token_name` is itself
    /// `Authorization`.
    authorization: Option<String>,
    /// Value of the cookie named `token_name`.
    cookie_token: Option<String>,
    /// Value of the query parameter named `token_name`.
    query_token: Option<String>,
}
74
75impl RocketCapturedRequest {
76    /// Build snapshot from live request (sync only). | 从当前请求构建快照(仅同步调用)。
77    pub fn capture(req: &Request<'_>, token_name: &str) -> Self {
78        let path = req.uri().path().to_string();
79        let method = req.method().to_string();
80        let client_ip = req.client_ip().map(|ip| ip.to_string());
81        let token_name_header = req.headers().get_one(token_name).map(|s| s.to_string());
82        let authorization = if !token_name.eq_ignore_ascii_case("authorization") {
83            req.headers().get_one("Authorization").map(|s| s.to_string())
84        } else {
85            None
86        };
87        let cookie_token = req.cookies().get(token_name).map(|c| c.value().to_string());
88        let query_token = req.uri().query().and_then(|q| {
89            parse_query_string(q.as_str()).get(token_name).cloned()
90        });
91        Self {
92            token_name: token_name.to_string(),
93            token_name_header,
94            authorization,
95            cookie_token,
96            query_token,
97            path,
98            method,
99            client_ip,
100        }
101    }
102}
103
104impl SaRequest for RocketCapturedRequest {
105    fn get_header(&self, name: &str) -> Option<String> {
106        if name.eq_ignore_ascii_case(&self.token_name) {
107            return self.token_name_header.clone();
108        }
109        if !self.token_name.eq_ignore_ascii_case("authorization")
110            && name.eq_ignore_ascii_case("authorization")
111        {
112            return self.authorization.clone();
113        }
114        None
115    }
116
117    fn get_cookie(&self, name: &str) -> Option<String> {
118        if name.eq_ignore_ascii_case(&self.token_name) {
119            self.cookie_token.clone()
120        } else {
121            None
122        }
123    }
124
125    fn get_param(&self, name: &str) -> Option<String> {
126        if name.eq_ignore_ascii_case(&self.token_name) {
127            self.query_token.clone()
128        } else {
129            None
130        }
131    }
132
133    fn get_path(&self) -> String {
134        self.path.clone()
135    }
136
137    fn get_method(&self) -> String {
138        self.method.clone()
139    }
140
141    fn get_client_ip(&self) -> Option<String> {
142        self.client_ip.clone()
143    }
144}
145
146/// Rocket 响应适配器
147pub struct RocketResponseAdapter<'a> {
148    response: &'a mut Response<'a>,
149}
150
151impl<'a> RocketResponseAdapter<'a> {
152    pub fn new(response: &'a mut Response<'a>) -> Self {
153        Self { response }
154    }
155}
156
157impl<'a> SaResponse for RocketResponseAdapter<'a> {
158    fn set_header(&mut self, name: &str, value: &str) {
159        self.response.set_header(Header::new(name.to_string(), value.to_string()));
160    }
161    
162    fn set_cookie(&mut self, name: &str, value: &str, options: CookieOptions) {
163        let mut cookie = Cookie::new(name.to_string(), value.to_string());
164        
165        if let Some(domain) = options.domain {
166            cookie.set_domain(domain);
167        }
168        if let Some(path) = options.path {
169            cookie.set_path(path);
170        }
171        if let Some(max_age) = options.max_age {
172            cookie.set_max_age(rocket::time::Duration::seconds(max_age));
173        }
174        cookie.set_http_only(options.http_only);
175        cookie.set_secure(options.secure);
176        
177        if let Some(same_site) = options.same_site {
178            use sa_token_adapter::context::SameSite as SaSameSite;
179            use rocket::http::SameSite;
180            
181            let ss = match same_site {
182                SaSameSite::Strict => SameSite::Strict,
183                SaSameSite::Lax => SameSite::Lax,
184                SaSameSite::None => SameSite::None,
185            };
186            cookie.set_same_site(ss);
187        }
188        
189        self.response.adjoin_header(cookie);
190    }
191    
192    fn set_status(&mut self, status: u16) {
193        if let Some(status_code) = Status::from_code(status) {
194            self.response.set_status(status_code);
195        }
196    }
197    
198    fn set_json_body<T: Serialize>(&mut self, body: T) -> Result<(), serde_json::Error> {
199        let json = serde_json::to_string(&body)?;
200        self.response.set_header(ContentType::JSON);
201        self.response.set_sized_body(Some(json.len()), std::io::Cursor::new(json));
202        Ok(())
203    }
204}
205
206/// 解析查询字符串
207fn parse_query_string(query: &str) -> HashMap<String, String> {
208    let mut params = HashMap::new();
209    for pair in query.split('&') {
210        if let Some((key, value)) = pair.split_once('=')
211            && let Ok(decoded_value) = urlencoding::decode(value) {
212                params.insert(key.to_string(), decoded_value.to_string());
213            }
214    }
215    params
216}