torch_web/extractors/
cookies.rs1use std::pin::Pin;
6use std::future::Future;
7use std::collections::HashMap;
8use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
9
/// All cookies sent with a request, keyed by cookie name.
///
/// Produced by this type's `FromRequestParts` impl, which parses the request's `Cookie`
/// header; values are percent-decoded during parsing.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Cookies(pub HashMap<String, String>);
25
26impl FromRequestParts for Cookies {
27 type Error = ExtractionError;
28
29 fn from_request_parts(
30 req: &mut Request,
31 ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
32 let cookie_header = req.headers()
33 .get("cookie")
34 .and_then(|v| v.to_str().ok())
35 .unwrap_or("")
36 .to_string();
37
38 Box::pin(async move {
39 let cookies = parse_cookies(&cookie_header)?;
40 Ok(Cookies(cookies))
41 })
42 }
43}
44
/// Looks up the cookie called `name` in an already-parsed cookie map.
///
/// Returns `None` when no cookie with that exact (case-sensitive) name is present.
pub fn get_cookie<'a>(cookies: &'a HashMap<String, String>, name: &str) -> Option<&'a String> {
    HashMap::get(cookies, name)
}
68
69pub fn get_required_cookie<'a>(cookies: &'a std::collections::HashMap<String, String>, name: &str) -> Result<&'a String, ExtractionError> {
71 cookies.get(name).ok_or_else(|| ExtractionError::MissingHeader(
72 format!("Required cookie '{}' not found", name)
73 ))
74}
75
/// The session identifier carried in the request's cookies, if any.
///
/// This type's `FromRequestParts` impl takes the first of the `session_id`, `sessionid`
/// and `SESSIONID` cookies that is present; `None` means none of them were sent.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct SessionCookie(pub Option<String>);
78
79impl FromRequestParts for SessionCookie {
80 type Error = ExtractionError;
81
82 fn from_request_parts(
83 req: &mut Request,
84 ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
85 let cookie_header = req.headers()
86 .get("cookie")
87 .and_then(|v| v.to_str().ok())
88 .unwrap_or("")
89 .to_string();
90
91 Box::pin(async move {
92 let cookies = parse_cookies(&cookie_header)?;
93 let session = cookies.get("session_id")
94 .or_else(|| cookies.get("sessionid"))
95 .or_else(|| cookies.get("SESSIONID"))
96 .cloned();
97 Ok(SessionCookie(session))
98 })
99 }
100}
101
102fn parse_cookies(cookie_header: &str) -> Result<HashMap<String, String>, ExtractionError> {
104 let mut cookies = HashMap::new();
105
106 if cookie_header.is_empty() {
107 return Ok(cookies);
108 }
109
110 for cookie_pair in cookie_header.split(';') {
111 let cookie_pair = cookie_pair.trim();
112 if let Some((name, value)) = cookie_pair.split_once('=') {
113 let name = name.trim().to_string();
114 let value = value.trim().to_string();
115
116 let decoded_value = urlencoding::decode(&value)
118 .map_err(|e| ExtractionError::InvalidHeader(format!("Invalid cookie encoding: {}", e)))?
119 .into_owned();
120
121 cookies.insert(name, decoded_value);
122 } else {
123 let name = cookie_pair.to_string();
125 cookies.insert(name, String::new());
126 }
127 }
128
129 Ok(cookies)
130}
131
132#[derive(Debug, Clone)]
134pub struct CookieBuilder {
135 name: String,
136 value: String,
137 domain: Option<String>,
138 path: Option<String>,
139 max_age: Option<i64>,
140 expires: Option<String>,
141 secure: bool,
142 http_only: bool,
143 same_site: Option<SameSite>,
144}
145
/// Value of a cookie's `SameSite` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SameSite {
    /// Serialized as `SameSite=Strict`.
    Strict,
    /// Serialized as `SameSite=Lax`.
    Lax,
    /// Serialized as `SameSite=None`. Modern browsers reject such a cookie unless the
    /// `Secure` flag is also set.
    None,
}
152
153impl CookieBuilder {
154 pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
155 Self {
156 name: name.into(),
157 value: value.into(),
158 domain: None,
159 path: None,
160 max_age: None,
161 expires: None,
162 secure: false,
163 http_only: false,
164 same_site: None,
165 }
166 }
167
168 pub fn domain(mut self, domain: impl Into<String>) -> Self {
169 self.domain = Some(domain.into());
170 self
171 }
172
173 pub fn path(mut self, path: impl Into<String>) -> Self {
174 self.path = Some(path.into());
175 self
176 }
177
178 pub fn max_age(mut self, seconds: i64) -> Self {
179 self.max_age = Some(seconds);
180 self
181 }
182
183 pub fn expires(mut self, expires: impl Into<String>) -> Self {
184 self.expires = Some(expires.into());
185 self
186 }
187
188 pub fn secure(mut self, secure: bool) -> Self {
189 self.secure = secure;
190 self
191 }
192
193 pub fn http_only(mut self, http_only: bool) -> Self {
194 self.http_only = http_only;
195 self
196 }
197
198 pub fn same_site(mut self, same_site: SameSite) -> Self {
199 self.same_site = Some(same_site);
200 self
201 }
202
203 pub fn build(self) -> String {
204 let mut cookie = format!("{}={}", self.name, urlencoding::encode(&self.value));
205
206 if let Some(domain) = self.domain {
207 cookie.push_str(&format!("; Domain={}", domain));
208 }
209
210 if let Some(path) = self.path {
211 cookie.push_str(&format!("; Path={}", path));
212 }
213
214 if let Some(max_age) = self.max_age {
215 cookie.push_str(&format!("; Max-Age={}", max_age));
216 }
217
218 if let Some(expires) = self.expires {
219 cookie.push_str(&format!("; Expires={}", expires));
220 }
221
222 if self.secure {
223 cookie.push_str("; Secure");
224 }
225
226 if self.http_only {
227 cookie.push_str("; HttpOnly");
228 }
229
230 if let Some(same_site) = self.same_site {
231 let same_site_str = match same_site {
232 SameSite::Strict => "Strict",
233 SameSite::Lax => "Lax",
234 SameSite::None => "None",
235 };
236 cookie.push_str(&format!("; SameSite={}", same_site_str));
237 }
238
239 cookie
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_empty_cookies() {
        let result = parse_cookies("");
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_parse_single_cookie() {
        let cookies = parse_cookies("session_id=abc123").unwrap();
        assert_eq!(cookies.get("session_id"), Some(&"abc123".to_string()));
    }

    #[test]
    fn test_parse_multiple_cookies() {
        let cookies = parse_cookies("session_id=abc123; user_id=456; theme=dark").unwrap();
        assert_eq!(cookies.get("session_id"), Some(&"abc123".to_string()));
        assert_eq!(cookies.get("user_id"), Some(&"456".to_string()));
        assert_eq!(cookies.get("theme"), Some(&"dark".to_string()));
    }

    #[test]
    fn test_parse_cookies_with_spaces() {
        let cookies = parse_cookies(" session_id = abc123 ; user_id = 456 ").unwrap();
        assert_eq!(cookies.get("session_id"), Some(&"abc123".to_string()));
        assert_eq!(cookies.get("user_id"), Some(&"456".to_string()));
    }

    #[test]
    fn test_parse_percent_encoded_value() {
        let cookies = parse_cookies("greeting=hello%20world").unwrap();
        assert_eq!(cookies.get("greeting"), Some(&"hello world".to_string()));
    }

    #[test]
    fn test_parse_cookie_without_value() {
        let cookies = parse_cookies("flag; theme=dark").unwrap();
        assert_eq!(cookies.get("flag"), Some(&String::new()));
        assert_eq!(cookies.get("theme"), Some(&"dark".to_string()));
    }

    #[test]
    fn test_parse_trailing_semicolon() {
        let cookies = parse_cookies("a=1;").unwrap();
        assert_eq!(cookies.get("a"), Some(&"1".to_string()));
    }

    #[test]
    fn test_parse_invalid_utf8_encoding_is_error() {
        // `%FF` decodes to a lone 0xFF byte, which is not valid UTF-8.
        assert!(parse_cookies("a=%FF").is_err());
    }

    #[test]
    fn test_get_cookie_helpers() {
        let cookies = parse_cookies("a=1").unwrap();
        assert_eq!(get_cookie(&cookies, "a"), Some(&"1".to_string()));
        assert_eq!(get_cookie(&cookies, "b"), None);
        assert!(get_required_cookie(&cookies, "a").is_ok());
        assert!(get_required_cookie(&cookies, "b").is_err());
    }

    #[test]
    fn test_cookie_builder() {
        let cookie = CookieBuilder::new("session_id", "abc123")
            .domain("example.com")
            .path("/")
            .max_age(3600)
            .secure(true)
            .http_only(true)
            .same_site(SameSite::Lax)
            .build();

        assert!(cookie.contains("session_id=abc123"));
        assert!(cookie.contains("Domain=example.com"));
        assert!(cookie.contains("Path=/"));
        assert!(cookie.contains("Max-Age=3600"));
        assert!(cookie.contains("Secure"));
        assert!(cookie.contains("HttpOnly"));
        assert!(cookie.contains("SameSite=Lax"));
    }

    #[test]
    fn test_cookie_builder_minimal() {
        assert_eq!(CookieBuilder::new("n", "v").build(), "n=v");
    }

    #[test]
    fn test_cookie_builder_encodes_value() {
        assert_eq!(CookieBuilder::new("n", "a b").build(), "n=a%20b");
    }

    #[test]
    fn test_cookie_builder_exact_output() {
        let cookie = CookieBuilder::new("id", "1")
            .domain("example.com")
            .path("/")
            .max_age(60)
            .expires("Wed, 21 Oct 2015 07:28:00 GMT")
            .secure(true)
            .http_only(true)
            .same_site(SameSite::Strict)
            .build();
        assert_eq!(
            cookie,
            "id=1; Domain=example.com; Path=/; Max-Age=60; \
             Expires=Wed, 21 Oct 2015 07:28:00 GMT; Secure; HttpOnly; SameSite=Strict"
        );
    }
}