sa_token_plugin_rocket_v05/
adapter.rs1use rocket::{Request, Response};
7use rocket::http::{Header, Cookie, Status, ContentType};
8use sa_token_adapter::context::{SaRequest, SaResponse, CookieOptions};
9use serde::Serialize;
10use std::collections::HashMap;
11
12pub struct RocketRequestAdapter<'a> {
15 request: &'a Request<'a>,
16}
17
18impl<'a> RocketRequestAdapter<'a> {
19 pub fn new(request: &'a Request<'a>) -> Self {
20 Self { request }
21 }
22}
23
24impl<'a> SaRequest for RocketRequestAdapter<'a> {
25 fn get_header(&self, name: &str) -> Option<String> {
26 self.request.headers().get_one(name)
27 .map(|s| s.to_string())
28 }
29
30 fn get_cookie(&self, name: &str) -> Option<String> {
31 self.request.cookies().get(name)
32 .map(|c| c.value().to_string())
33 }
34
35 fn get_param(&self, name: &str) -> Option<String> {
36 if let Some(query) = self.request.uri().query() {
38 return parse_query_string(query.as_str())
39 .get(name)
40 .cloned();
41 }
42 None
43 }
44
45 fn get_path(&self) -> String {
46 self.request.uri().path().to_string()
47 }
48
49 fn get_method(&self) -> String {
50 self.request.method().to_string()
51 }
52
53 fn get_client_ip(&self) -> Option<String> {
54 self.request.client_ip()
55 .map(|ip| ip.to_string())
56 }
57}
58
/// Owned snapshot of the request data that token lookup needs.
///
/// Produced by `RocketCapturedRequest::capture`. Because every field is owned, the snapshot
/// has no lifetime parameter and stays usable after the borrow of the original request ends.
pub struct RocketCapturedRequest {
    /// Name the token is looked up under (as header, cookie and query parameter).
    token_name: String,
    /// Value of the header named `token_name`, if present.
    token_name_header: Option<String>,
    /// Value of the `Authorization` header; left `None` when `token_name` itself names
    /// `Authorization` (ASCII case-insensitively).
    authorization: Option<String>,
    /// Value of the cookie named exactly `token_name`, if present.
    cookie_token: Option<String>,
    /// Percent-decoded value of the query parameter named exactly `token_name`, if present.
    query_token: Option<String>,
    /// Path component of the request URI.
    path: String,
    /// Request method, rendered as text.
    method: String,
    /// Client IP address, rendered as text, if Rocket reported one.
    client_ip: Option<String>,
}
74
75impl RocketCapturedRequest {
76 pub fn capture(req: &Request<'_>, token_name: &str) -> Self {
78 let path = req.uri().path().to_string();
79 let method = req.method().to_string();
80 let client_ip = req.client_ip().map(|ip| ip.to_string());
81 let token_name_header = req.headers().get_one(token_name).map(|s| s.to_string());
82 let authorization = if !token_name.eq_ignore_ascii_case("authorization") {
83 req.headers().get_one("Authorization").map(|s| s.to_string())
84 } else {
85 None
86 };
87 let cookie_token = req.cookies().get(token_name).map(|c| c.value().to_string());
88 let query_token = req.uri().query().and_then(|q| {
89 parse_query_string(q.as_str()).get(token_name).cloned()
90 });
91 Self {
92 token_name: token_name.to_string(),
93 token_name_header,
94 authorization,
95 cookie_token,
96 query_token,
97 path,
98 method,
99 client_ip,
100 }
101 }
102}
103
104impl SaRequest for RocketCapturedRequest {
105 fn get_header(&self, name: &str) -> Option<String> {
106 if name.eq_ignore_ascii_case(&self.token_name) {
107 return self.token_name_header.clone();
108 }
109 if !self.token_name.eq_ignore_ascii_case("authorization")
110 && name.eq_ignore_ascii_case("authorization")
111 {
112 return self.authorization.clone();
113 }
114 None
115 }
116
117 fn get_cookie(&self, name: &str) -> Option<String> {
118 if name.eq_ignore_ascii_case(&self.token_name) {
119 self.cookie_token.clone()
120 } else {
121 None
122 }
123 }
124
125 fn get_param(&self, name: &str) -> Option<String> {
126 if name.eq_ignore_ascii_case(&self.token_name) {
127 self.query_token.clone()
128 } else {
129 None
130 }
131 }
132
133 fn get_path(&self) -> String {
134 self.path.clone()
135 }
136
137 fn get_method(&self) -> String {
138 self.method.clone()
139 }
140
141 fn get_client_ip(&self) -> Option<String> {
142 self.client_ip.clone()
143 }
144}
145
146pub struct RocketResponseAdapter<'a> {
148 response: &'a mut Response<'a>,
149}
150
151impl<'a> RocketResponseAdapter<'a> {
152 pub fn new(response: &'a mut Response<'a>) -> Self {
153 Self { response }
154 }
155}
156
157impl<'a> SaResponse for RocketResponseAdapter<'a> {
158 fn set_header(&mut self, name: &str, value: &str) {
159 self.response.set_header(Header::new(name.to_string(), value.to_string()));
160 }
161
162 fn set_cookie(&mut self, name: &str, value: &str, options: CookieOptions) {
163 let mut cookie = Cookie::new(name.to_string(), value.to_string());
164
165 if let Some(domain) = options.domain {
166 cookie.set_domain(domain);
167 }
168 if let Some(path) = options.path {
169 cookie.set_path(path);
170 }
171 if let Some(max_age) = options.max_age {
172 cookie.set_max_age(rocket::time::Duration::seconds(max_age));
173 }
174 cookie.set_http_only(options.http_only);
175 cookie.set_secure(options.secure);
176
177 if let Some(same_site) = options.same_site {
178 use sa_token_adapter::context::SameSite as SaSameSite;
179 use rocket::http::SameSite;
180
181 let ss = match same_site {
182 SaSameSite::Strict => SameSite::Strict,
183 SaSameSite::Lax => SameSite::Lax,
184 SaSameSite::None => SameSite::None,
185 };
186 cookie.set_same_site(ss);
187 }
188
189 self.response.adjoin_header(cookie);
190 }
191
192 fn set_status(&mut self, status: u16) {
193 if let Some(status_code) = Status::from_code(status) {
194 self.response.set_status(status_code);
195 }
196 }
197
198 fn set_json_body<T: Serialize>(&mut self, body: T) -> Result<(), serde_json::Error> {
199 let json = serde_json::to_string(&body)?;
200 self.response.set_header(ContentType::JSON);
201 self.response.set_sized_body(Some(json.len()), std::io::Cursor::new(json));
202 Ok(())
203 }
204}
205
206fn parse_query_string(query: &str) -> HashMap<String, String> {
208 let mut params = HashMap::new();
209 for pair in query.split('&') {
210 if let Some((key, value)) = pair.split_once('=')
211 && let Ok(decoded_value) = urlencoding::decode(value) {
212 params.insert(key.to_string(), decoded_value.to_string());
213 }
214 }
215 params
216}