sa_token_plugin_rocket/
adapter.rs

// Author: 金书记
//
//! Rocket 请求/响应适配器 — adapters exposing Rocket requests/responses through sa-token's `SaRequest`/`SaResponse` traits.
4
5use rocket::{Request, Response};
6use rocket::http::{Header, Cookie, Status, ContentType};
7use sa_token_adapter::context::{SaRequest, SaResponse, CookieOptions};
8use serde::Serialize;
9use std::collections::HashMap;
10
11/// Rocket 请求适配器
12pub struct RocketRequestAdapter<'a> {
13    request: &'a Request<'a>,
14}
15
16impl<'a> RocketRequestAdapter<'a> {
17    pub fn new(request: &'a Request<'a>) -> Self {
18        Self { request }
19    }
20}
21
22impl<'a> SaRequest for RocketRequestAdapter<'a> {
23    fn get_header(&self, name: &str) -> Option<String> {
24        self.request.headers().get_one(name)
25            .map(|s| s.to_string())
26    }
27    
28    fn get_cookie(&self, name: &str) -> Option<String> {
29        self.request.cookies().get(name)
30            .map(|c| c.value().to_string())
31    }
32    
33    fn get_param(&self, name: &str) -> Option<String> {
34        // Rocket 的查询参数需要从 URI 中提取
35        if let Some(query) = self.request.uri().query() {
36            return parse_query_string(query.as_str())
37                .get(name)
38                .cloned();
39        }
40        None
41    }
42    
43    fn get_path(&self) -> String {
44        self.request.uri().path().to_string()
45    }
46    
47    fn get_method(&self) -> String {
48        self.request.method().to_string()
49    }
50    
51    fn get_client_ip(&self) -> Option<String> {
52        self.request.client_ip()
53            .map(|ip| ip.to_string())
54    }
55}
56
57/// Rocket 响应适配器
58pub struct RocketResponseAdapter<'a> {
59    response: &'a mut Response<'a>,
60}
61
62impl<'a> RocketResponseAdapter<'a> {
63    pub fn new(response: &'a mut Response<'a>) -> Self {
64        Self { response }
65    }
66}
67
68impl<'a> SaResponse for RocketResponseAdapter<'a> {
69    fn set_header(&mut self, name: &str, value: &str) {
70        self.response.set_header(Header::new(name.to_string(), value.to_string()));
71    }
72    
73    fn set_cookie(&mut self, name: &str, value: &str, options: CookieOptions) {
74        let mut cookie = Cookie::new(name.to_string(), value.to_string());
75        
76        if let Some(domain) = options.domain {
77            cookie.set_domain(domain);
78        }
79        if let Some(path) = options.path {
80            cookie.set_path(path);
81        }
82        if let Some(max_age) = options.max_age {
83            cookie.set_max_age(rocket::time::Duration::seconds(max_age));
84        }
85        cookie.set_http_only(options.http_only);
86        cookie.set_secure(options.secure);
87        
88        if let Some(same_site) = options.same_site {
89            use sa_token_adapter::context::SameSite as SaSameSite;
90            use rocket::http::SameSite;
91            
92            let ss = match same_site {
93                SaSameSite::Strict => SameSite::Strict,
94                SaSameSite::Lax => SameSite::Lax,
95                SaSameSite::None => SameSite::None,
96            };
97            cookie.set_same_site(ss);
98        }
99        
100        self.response.adjoin_header(cookie);
101    }
102    
103    fn set_status(&mut self, status: u16) {
104        if let Some(status_code) = Status::from_code(status) {
105            self.response.set_status(status_code);
106        }
107    }
108    
109    fn set_json_body<T: Serialize>(&mut self, body: T) -> Result<(), serde_json::Error> {
110        let json = serde_json::to_string(&body)?;
111        self.response.set_header(ContentType::JSON);
112        self.response.set_sized_body(Some(json.len()), std::io::Cursor::new(json));
113        Ok(())
114    }
115}
116
117/// 解析查询字符串
118fn parse_query_string(query: &str) -> HashMap<String, String> {
119    let mut params = HashMap::new();
120    for pair in query.split('&') {
121        if let Some((key, value)) = pair.split_once('=') {
122            if let Ok(decoded_value) = urlencoding::decode(value) {
123                params.insert(key.to_string(), decoded_value.to_string());
124            }
125        }
126    }
127    params
128}