1use std::collections::HashMap;
6use std::fmt;
7use std::path::Path;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9
10use chrono::{DateTime, TimeZone, Utc};
11use url::Url;
12
13use crate::error::{Error, Result};
14use crate::headers::Headers;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
20pub enum SameSite {
21 Strict,
23 Lax,
25 None,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub struct Cookie {
32 pub name: String,
33 pub value: String,
34 pub domain: String,
35 pub path: String,
36 pub secure: bool,
37 pub http_only: bool,
38 pub same_site: Option<SameSite>,
39 pub expires: Option<DateTime<Utc>>,
40 pub max_age: Option<i64>,
41 pub host_only: bool,
42 pub source_url: Option<String>,
43 pub raw_header: Option<String>,
44 pub creation_time: DateTime<Utc>,
46}
47
48impl Cookie {
49 pub fn new(
50 name: impl Into<String>,
51 value: impl Into<String>,
52 domain: impl Into<String>,
53 ) -> Self {
54 Self {
55 name: name.into(),
56 value: value.into(),
57 domain: normalize_domain(&domain.into()),
58 path: "/".to_string(),
59 secure: false,
60 http_only: false,
61 same_site: None,
62 expires: None,
63 max_age: None,
64 host_only: true,
65 source_url: None,
66 raw_header: None,
67 creation_time: Utc::now(),
68 }
69 }
70
71 pub fn with_path(mut self, path: impl Into<String>) -> Self {
73 self.path = path.into();
74 self
75 }
76
77 pub fn with_secure(mut self, secure: bool) -> Self {
79 self.secure = secure;
80 self
81 }
82
83 pub fn with_http_only(mut self, http_only: bool) -> Self {
85 self.http_only = http_only;
86 self
87 }
88
89 pub fn with_same_site(mut self, same_site: SameSite) -> Self {
91 self.same_site = Some(same_site);
92 self
93 }
94
95 pub fn with_expires(mut self, expires: DateTime<Utc>) -> Self {
97 self.expires = Some(expires);
98 self
99 }
100
101 pub fn with_host_only(mut self, host_only: bool) -> Self {
103 self.host_only = host_only;
104 self
105 }
106
107 pub fn from_set_cookie_header(header: &str, request_url: &str) -> Result<Self> {
108 let parsed_url = Url::parse(request_url).map_err(|e| Error::CookieParse(e.to_string()))?;
109 let request_domain = parsed_url
110 .host_str()
111 .ok_or_else(|| Error::CookieParse("No host in URL".to_string()))?;
112
113 let parts: Vec<&str> = header.split(';').map(str::trim).collect();
114 if parts.is_empty() {
115 return Err(Error::CookieParse("Empty cookie header".to_string()));
116 }
117
118 let (name, value) = match parts[0].split_once('=') {
119 Some((n, v)) => (n.trim().to_string(), v.trim().to_string()),
120 None => return Err(Error::CookieParse("No = in cookie".to_string())),
121 };
122
123 if name.is_empty() {
124 return Err(Error::CookieParse("Empty cookie name".to_string()));
125 }
126
127 let mut cookie = Cookie::new(name, value, request_domain);
128 cookie.raw_header = Some(header.to_string());
129 cookie.source_url = Some(request_url.to_string());
130
131 let mut domain_attr_present = false;
133
134 for attr in parts.iter().skip(1) {
135 let attr_lower = attr.to_lowercase();
136 if attr_lower == "secure" {
137 cookie.secure = true;
138 } else if attr_lower == "httponly" {
139 cookie.http_only = true;
140 } else if let Some((key, val)) = attr.split_once('=') {
141 match key.trim().to_lowercase().as_str() {
142 "domain" => {
143 cookie.domain = normalize_domain(val.trim());
144 domain_attr_present = true;
145 }
146 "path" => cookie.path = val.trim().to_string(),
147 "expires" => cookie.expires = parse_cookie_date(val.trim()),
148 "max-age" => cookie.max_age = val.trim().parse().ok(),
149 "samesite" => {
150 let ss_str = val.trim();
151 cookie.same_site = match ss_str.to_lowercase().as_str() {
152 "strict" => Some(SameSite::Strict),
153 "lax" => Some(SameSite::Lax),
154 "none" => Some(SameSite::None),
155 _ => None,
156 };
157 }
158 _ => {}
159 }
160 }
161 }
162
163 cookie.host_only = !domain_attr_present;
165
166 if let Some(max_age) = cookie.max_age {
169 if max_age > 0 {
170 cookie.expires = Some(Utc::now() + chrono::Duration::seconds(max_age));
171 } else {
172 cookie.expires = Some(Utc::now() - chrono::Duration::seconds(1));
174 }
175 }
176
177 if is_public_suffix(&cookie.domain) {
179 return Err(Error::CookieParse(format!(
180 "Cannot set cookie for public suffix: {}",
181 cookie.domain
182 )));
183 }
184
185 if cookie.same_site == Some(SameSite::None) && !cookie.secure {
187 return Err(Error::CookieParse(
188 "SameSite=None requires Secure attribute".to_string(),
189 ));
190 }
191
192 Ok(cookie)
193 }
194
195 pub fn matches_url(&self, url: &str) -> bool {
196 let parsed = match Url::parse(url) {
197 Ok(u) => u,
198 Err(_) => return false,
199 };
200 let request_domain = match parsed.host_str() {
201 Some(h) => h.to_lowercase(),
202 None => return false,
203 };
204
205 if self.secure && parsed.scheme() != "https" {
207 return false;
208 }
209
210 if let Some(expires) = self.expires {
212 if expires < Utc::now() {
213 return false;
214 }
215 }
216
217 if !self.domain_matches(&request_domain) {
219 return false;
220 }
221
222 let request_path = parsed.path();
224 if !self.path_matches(request_path) {
225 return false;
226 }
227
228 true
229 }
230
231 pub fn domain_matches(&self, request_domain: &str) -> bool {
234 let cookie_domain = self.domain.to_lowercase();
235 let request_domain_lower = request_domain.to_lowercase();
236
237 if self.host_only {
239 return request_domain_lower == cookie_domain;
240 }
241
242 if request_domain_lower == cookie_domain {
244 return true;
245 }
246
247 if request_domain_lower.len() > cookie_domain.len() {
251 let expected_suffix = format!(".{}", cookie_domain);
252 if request_domain_lower.ends_with(&expected_suffix) {
253 return true;
254 }
255 }
256
257 false
258 }
259
260 pub fn path_matches(&self, request_path: &str) -> bool {
263 let cookie_path = &self.path;
264
265 if request_path == cookie_path {
267 return true;
268 }
269
270 if !request_path.starts_with(cookie_path) {
272 return false;
273 }
274
275 if cookie_path.ends_with('/') {
277 return true;
278 }
279
280 if let Some(next_char) = request_path.chars().nth(cookie_path.len()) {
283 return next_char == '/';
284 }
285
286 false
287 }
288
289 pub fn to_netscape_line(&self) -> String {
290 format!(
292 "{}\t{}\t{}\t{}\t{}\t{}\t{}",
293 self.domain,
294 if self.host_only { "FALSE" } else { "TRUE" },
295 self.path,
296 if self.secure { "TRUE" } else { "FALSE" },
297 self.expires
298 .map(|dt| dt.timestamp().to_string())
299 .unwrap_or_else(|| "0".to_string()),
300 self.name,
301 self.value
302 )
303 }
304
305 pub fn from_netscape_line(line: &str) -> Result<Self> {
306 let parts: Vec<&str> = line.split('\t').collect();
307 if parts.len() < 7 {
308 return Err(Error::CookieParse(format!(
309 "Invalid Netscape format: expected 7 fields, got {}",
310 parts.len()
311 )));
312 }
313 let include_subdomains = parts[1].eq_ignore_ascii_case("true");
316 Ok(Cookie {
317 name: parts[5].to_string(),
318 value: parts[6].to_string(),
319 domain: normalize_domain(parts[0]),
320 path: parts[2].to_string(),
321 secure: parts[3].eq_ignore_ascii_case("true"),
322 http_only: false,
323 same_site: None,
324 expires: parts[4]
325 .parse::<i64>()
326 .ok()
327 .filter(|&ts| ts > 0)
328 .and_then(|ts| Utc.timestamp_opt(ts, 0).single()),
329 max_age: None,
330 host_only: !include_subdomains,
331 source_url: None,
332 raw_header: None,
333 creation_time: Utc::now(),
334 })
335 }
336
337 pub fn value_hash(&self) -> String {
338 hash_cookie_value(&self.value)
339 }
340}
341
342pub fn hash_cookie_value(value: &str) -> String {
355 use sha2::{Digest, Sha256};
356 let result = Sha256::digest(value.as_bytes());
357 result[..4].iter().map(|b| format!("{:02x}", b)).collect()
358}
359
360impl fmt::Display for Cookie {
361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362 write!(f, "{}={}", self.name, self.value)
363 }
364}
365
366#[derive(Debug, Default, Clone)]
368pub struct CookieJar {
369 cookies: HashMap<String, Vec<Cookie>>,
370}
371
372impl CookieJar {
373 pub fn new() -> Self {
374 Self::default()
375 }
376
377 pub fn store(&mut self, cookie: Cookie) {
378 let list = self.cookies.entry(cookie.domain.clone()).or_default();
379
380 if let Some(pos) = list
386 .iter()
387 .position(|c| c.name == cookie.name && c.path == cookie.path)
388 {
389 list[pos] = cookie;
390 } else {
391 list.push(cookie);
392 }
393 }
394
395 pub fn add(&mut self, cookie: Cookie) {
396 self.store(cookie);
397 }
398
399 pub fn cookies(&self) -> Vec<&Cookie> {
400 self.cookies.values().flat_map(|v| v.iter()).collect()
401 }
402
403 pub fn cookies_for_url(&self, url: &str) -> Vec<&Cookie> {
404 self.cookies
405 .values()
406 .flat_map(|v| v.iter())
407 .filter(|c| c.matches_url(url))
408 .collect()
409 }
410
411 pub fn build_cookie_header(&self, url: &str) -> Option<String> {
412 let mut cookies = self.cookies_for_url(url);
413 if cookies.is_empty() {
414 return None;
415 }
416
417 cookies.sort_by(|a, b| {
419 b.path
420 .len()
421 .cmp(&a.path.len())
422 .then_with(|| a.creation_time.cmp(&b.creation_time))
423 });
424
425 Some(
426 cookies
427 .iter()
428 .map(|c| format!("{}={}", c.name, c.value))
429 .collect::<Vec<_>>()
430 .join("; "),
431 )
432 }
433
434 pub fn store_from_headers(&mut self, headers: &Headers, request_url: &str) {
435 for (name, value) in headers.iter() {
436 if name.eq_ignore_ascii_case("set-cookie") {
437 if let Ok(cookie) = Cookie::from_set_cookie_header(value.trim(), request_url) {
438 self.store(cookie);
439 }
440 }
441 }
442 }
443
444 pub async fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
445 let mut file = tokio::fs::File::create(path).await.map_err(Error::Io)?;
446 file.write_all(b"# Netscape HTTP Cookie File\n")
447 .await
448 .map_err(Error::Io)?;
449 for cookies in self.cookies.values() {
450 for cookie in cookies {
451 let line = format!("{}\n", cookie.to_netscape_line());
452 file.write_all(line.as_bytes()).await.map_err(Error::Io)?;
453 }
454 }
455 Ok(())
456 }
457
458 pub async fn load_from_file(&mut self, path: impl AsRef<Path>) -> Result<()> {
459 let file = tokio::fs::File::open(path).await.map_err(Error::Io)?;
460 let mut reader = BufReader::new(file);
461 let mut line = String::new();
462 while reader.read_line(&mut line).await.map_err(Error::Io)? > 0 {
463 let trimmed = line.trim_end();
464 if !trimmed.is_empty() && !trimmed.starts_with('#') {
465 if let Ok(cookie) = Cookie::from_netscape_line(trimmed) {
466 self.store(cookie);
467 }
468 }
469 line.clear();
470 }
471 Ok(())
472 }
473
474 pub fn get(&self, domain: &str, name: &str) -> Option<&Cookie> {
475 self.cookies
477 .get(&normalize_domain(domain))?
478 .iter()
479 .find(|c| c.name == name)
480 }
481
482 pub fn remove(&mut self, domain: &str, name: &str) -> Option<Cookie> {
483 let list = self.cookies.get_mut(&normalize_domain(domain))?;
485 list.iter()
486 .position(|c| c.name == name)
487 .map(|pos| list.remove(pos))
488 }
489
490 pub fn clear(&mut self) {
491 self.cookies.clear();
492 }
493 pub fn len(&self) -> usize {
494 self.cookies.values().map(|v| v.len()).sum()
495 }
496 pub fn is_empty(&self) -> bool {
497 self.cookies.is_empty()
498 }
499}
500
/// Canonicalizes a domain for storage and lookup.
///
/// Every leading and trailing `.` is stripped and the remainder is lowercased.
fn normalize_domain(domain: &str) -> String {
    let without_dots = domain.trim_matches('.');
    without_dots.to_lowercase()
}
507
508fn parse_cookie_date(date_str: &str) -> Option<DateTime<Utc>> {
509 const FORMATS: &[&str] = &[
512 "%a, %d %b %Y %H:%M:%S GMT", "%A, %d-%b-%y %H:%M:%S GMT", "%a %b %e %H:%M:%S %Y", "%a, %d-%b-%Y %H:%M:%S GMT", "%d %b %Y %H:%M:%S GMT", "%a, %d %b %Y %H:%M:%S %z", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S%.fZ", ];
521
522 for fmt in FORMATS {
523 if let Ok(dt) = chrono::DateTime::parse_from_str(date_str, fmt) {
524 return Some(dt.with_timezone(&Utc));
525 }
526 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(date_str, fmt) {
527 return Some(chrono::TimeZone::from_utc_datetime(&Utc, &dt));
528 }
529 }
530
531 date_str
533 .parse::<i64>()
534 .ok()
535 .and_then(|ts| Utc.timestamp_opt(ts, 0).single())
536}
537
538fn is_public_suffix(domain: &str) -> bool {
541 let domain_clean = domain.strip_prefix('.').unwrap_or(domain);
543
544 psl::suffix(domain_clean.as_bytes())
546 .map(|suffix| {
547 suffix.is_known() && suffix.as_bytes() == domain_clean.as_bytes()
549 })
550 .unwrap_or(false)
551}