// torch_web/extractors/cookies.rs

//! Cookie extraction.
//!
//! Extract and parse HTTP cookies from incoming requests, and build
//! `Set-Cookie` header values with [`CookieBuilder`].
use std::collections::HashMap;
use std::fmt::Write as _;
use std::future::Future;
use std::pin::Pin;

use crate::{
    extractors::{ExtractionError, FromRequestParts},
    Request,
};
/// Extractor that parses the request's `Cookie` header into a map of
/// cookie name to cookie value.
///
/// Values are percent-decoded. A request with no `Cookie` header, or one
/// whose value cannot be read as a string, yields an empty map. Extraction
/// fails with an [`ExtractionError`] if a value's percent-encoding is invalid.
///
/// # Example
///
/// ```rust,no_run
/// use torch_web::extractors::Cookies;
///
/// async fn handler(Cookies(cookies): Cookies) {
///     if let Some(_session_id) = cookies.get("session_id") {
///         // Handle session
///     }
/// }
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Cookies(pub HashMap<String, String>);
26impl FromRequestParts for Cookies {
27    type Error = ExtractionError;
28
29    fn from_request_parts(
30        req: &mut Request,
31    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
32        let cookie_header = req.headers()
33            .get("cookie")
34            .and_then(|v| v.to_str().ok())
35            .unwrap_or("")
36            .to_string();
37        
38        Box::pin(async move {
39            let cookies = parse_cookies(&cookie_header)?;
40            Ok(Cookies(cookies))
41        })
42    }
43}
44
/// Look up a single cookie by name in the map produced by the [`Cookies`]
/// extractor.
///
/// The lookup is an exact, case-sensitive match on the cookie name. Returns
/// `None` when no cookie with that name was sent.
///
/// # Example
///
/// ```rust,no_run
/// use torch_web::extractors::{Cookies, get_cookie};
///
/// async fn handler(Cookies(cookies): Cookies) {
///     match get_cookie(&cookies, "session_id") {
///         Some(_id) => {
///             // Handle session
///         }
///         None => {
///             // No session cookie
///         }
///     }
/// }
/// ```
pub fn get_cookie<'a>(cookies: &'a HashMap<String, String>, name: &str) -> Option<&'a String> {
    HashMap::get(cookies, name)
}
69/// Helper function to get a required cookie by name
70pub fn get_required_cookie<'a>(cookies: &'a std::collections::HashMap<String, String>, name: &str) -> Result<&'a String, ExtractionError> {
71    cookies.get(name).ok_or_else(|| ExtractionError::MissingHeader(
72        format!("Required cookie '{}' not found", name)
73    ))
74}
75
/// Extractor for a session identifier cookie.
///
/// Looks for a cookie named `session_id`, then `sessionid`, then `SESSIONID`,
/// and holds the value of the first one present. It holds `None` when none of
/// them were sent, including when the request has no `Cookie` header at all.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionCookie(pub Option<String>);
79impl FromRequestParts for SessionCookie {
80    type Error = ExtractionError;
81
82    fn from_request_parts(
83        req: &mut Request,
84    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
85        let cookie_header = req.headers()
86            .get("cookie")
87            .and_then(|v| v.to_str().ok())
88            .unwrap_or("")
89            .to_string();
90        
91        Box::pin(async move {
92            let cookies = parse_cookies(&cookie_header)?;
93            let session = cookies.get("session_id")
94                .or_else(|| cookies.get("sessionid"))
95                .or_else(|| cookies.get("SESSIONID"))
96                .cloned();
97            Ok(SessionCookie(session))
98        })
99    }
100}
101
102/// Parse cookie header string into a HashMap
103fn parse_cookies(cookie_header: &str) -> Result<HashMap<String, String>, ExtractionError> {
104    let mut cookies = HashMap::new();
105    
106    if cookie_header.is_empty() {
107        return Ok(cookies);
108    }
109
110    for cookie_pair in cookie_header.split(';') {
111        let cookie_pair = cookie_pair.trim();
112        if let Some((name, value)) = cookie_pair.split_once('=') {
113            let name = name.trim().to_string();
114            let value = value.trim().to_string();
115            
116            // Basic URL decoding for cookie values
117            let decoded_value = urlencoding::decode(&value)
118                .map_err(|e| ExtractionError::InvalidHeader(format!("Invalid cookie encoding: {}", e)))?
119                .into_owned();
120            
121            cookies.insert(name, decoded_value);
122        } else {
123            // Handle cookies without values (rare but possible)
124            let name = cookie_pair.to_string();
125            cookies.insert(name, String::new());
126        }
127    }
128
129    Ok(cookies)
130}
131
132/// Cookie builder for creating Set-Cookie headers
133#[derive(Debug, Clone)]
134pub struct CookieBuilder {
135    name: String,
136    value: String,
137    domain: Option<String>,
138    path: Option<String>,
139    max_age: Option<i64>,
140    expires: Option<String>,
141    secure: bool,
142    http_only: bool,
143    same_site: Option<SameSite>,
144}
145
/// Value of the `SameSite` cookie attribute.
///
/// Each variant is rendered by [`CookieBuilder::build`] as `SameSite=Strict`,
/// `SameSite=Lax` or `SameSite=None` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SameSite {
    /// Send the cookie only on same-site requests.
    Strict,
    /// Send the cookie on same-site requests and top-level cross-site
    /// navigations.
    Lax,
    /// Send the cookie on all requests. Browsers require such cookies to
    /// also be marked `Secure`.
    None,
}
153impl CookieBuilder {
154    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
155        Self {
156            name: name.into(),
157            value: value.into(),
158            domain: None,
159            path: None,
160            max_age: None,
161            expires: None,
162            secure: false,
163            http_only: false,
164            same_site: None,
165        }
166    }
167
168    pub fn domain(mut self, domain: impl Into<String>) -> Self {
169        self.domain = Some(domain.into());
170        self
171    }
172
173    pub fn path(mut self, path: impl Into<String>) -> Self {
174        self.path = Some(path.into());
175        self
176    }
177
178    pub fn max_age(mut self, seconds: i64) -> Self {
179        self.max_age = Some(seconds);
180        self
181    }
182
183    pub fn expires(mut self, expires: impl Into<String>) -> Self {
184        self.expires = Some(expires.into());
185        self
186    }
187
188    pub fn secure(mut self, secure: bool) -> Self {
189        self.secure = secure;
190        self
191    }
192
193    pub fn http_only(mut self, http_only: bool) -> Self {
194        self.http_only = http_only;
195        self
196    }
197
198    pub fn same_site(mut self, same_site: SameSite) -> Self {
199        self.same_site = Some(same_site);
200        self
201    }
202
203    pub fn build(self) -> String {
204        let mut cookie = format!("{}={}", self.name, urlencoding::encode(&self.value));
205
206        if let Some(domain) = self.domain {
207            cookie.push_str(&format!("; Domain={}", domain));
208        }
209
210        if let Some(path) = self.path {
211            cookie.push_str(&format!("; Path={}", path));
212        }
213
214        if let Some(max_age) = self.max_age {
215            cookie.push_str(&format!("; Max-Age={}", max_age));
216        }
217
218        if let Some(expires) = self.expires {
219            cookie.push_str(&format!("; Expires={}", expires));
220        }
221
222        if self.secure {
223            cookie.push_str("; Secure");
224        }
225
226        if self.http_only {
227            cookie.push_str("; HttpOnly");
228        }
229
230        if let Some(same_site) = self.same_site {
231            let same_site_str = match same_site {
232                SameSite::Strict => "Strict",
233                SameSite::Lax => "Lax",
234                SameSite::None => "None",
235            };
236            cookie.push_str(&format!("; SameSite={}", same_site_str));
237        }
238
239        cookie
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `header`, failing the test if the parser reports an error.
    fn parse_ok(header: &str) -> HashMap<String, String> {
        parse_cookies(header).expect("cookie header should parse")
    }

    /// Assert that `cookies` maps `name` to exactly `expected`.
    fn assert_cookie(cookies: &HashMap<String, String>, name: &str, expected: &str) {
        assert_eq!(cookies.get(name).map(String::as_str), Some(expected));
    }

    #[test]
    fn test_parse_empty_cookies() {
        assert!(parse_ok("").is_empty());
    }

    #[test]
    fn test_parse_single_cookie() {
        let cookies = parse_ok("session_id=abc123");
        assert_cookie(&cookies, "session_id", "abc123");
    }

    #[test]
    fn test_parse_multiple_cookies() {
        let cookies = parse_ok("session_id=abc123; user_id=456; theme=dark");
        for &(name, value) in &[("session_id", "abc123"), ("user_id", "456"), ("theme", "dark")] {
            assert_cookie(&cookies, name, value);
        }
    }

    #[test]
    fn test_parse_cookies_with_spaces() {
        let cookies = parse_ok(" session_id = abc123 ; user_id = 456 ");
        assert_cookie(&cookies, "session_id", "abc123");
        assert_cookie(&cookies, "user_id", "456");
    }

    #[test]
    fn test_cookie_builder() {
        let header = CookieBuilder::new("session_id", "abc123")
            .domain("example.com")
            .path("/")
            .max_age(3600)
            .secure(true)
            .http_only(true)
            .same_site(SameSite::Lax)
            .build();

        let expected_fragments = [
            "session_id=abc123",
            "Domain=example.com",
            "Path=/",
            "Max-Age=3600",
            "Secure",
            "HttpOnly",
            "SameSite=Lax",
        ];
        for fragment in &expected_fragments {
            assert!(
                header.contains(*fragment),
                "`{}` missing from {:?}",
                fragment,
                header
            );
        }
    }
}