1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6use tokio::sync::RwLock;
7
/// One alternative service advertised by an origin through an `Alt-Svc` header.
#[derive(Clone, Debug)]
pub struct AltSvcEntry {
    /// Protocol id of the alternative, e.g. `h3` or `h3-29`.
    pub protocol: String,
    /// Alternative host, or `None` when the authority had no host part (e.g. `":443"`).
    pub host: Option<String>,
    /// Port of the alternative service.
    pub port: u16,
    /// Freshness lifetime in seconds, measured from `received_at`.
    pub max_age: u64,
    /// Instant at which the advertisement was parsed.
    pub received_at: Instant,
    /// Value of the `persist` parameter as interpreted by the parser (`false` when absent).
    pub persist: bool,
}

impl AltSvcEntry {
    /// Returns `true` once at least `max_age` whole seconds have passed since `received_at`.
    ///
    /// An entry with `max_age == 0` is therefore expired immediately.
    pub fn is_expired(&self) -> bool {
        let elapsed_secs = self.received_at.elapsed().as_secs();
        self.max_age <= elapsed_secs
    }

    /// Returns `true` for HTTP/3 protocol ids: exactly `h3`, or any `h3-...` draft variant.
    pub fn is_h3(&self) -> bool {
        match self.protocol.strip_prefix("h3") {
            Some(suffix) => suffix.is_empty() || suffix.starts_with('-'),
            None => false,
        }
    }
}
37
38pub struct AltSvcCache {
40 entries: Arc<RwLock<HashMap<String, Vec<AltSvcEntry>>>>,
41 default_max_age: u64,
42}
43
44impl AltSvcCache {
45 pub fn new() -> Self {
46 Self {
47 entries: Arc::new(RwLock::new(HashMap::new())),
48 default_max_age: 86400, }
50 }
51
52 pub async fn parse_and_store(&self, origin: &str, header: &str) -> Vec<AltSvcEntry> {
54 if header.trim() == "clear" {
56 self.clear_origin(origin).await;
57 return vec![];
58 }
59
60 let entries = parse_alt_svc(header, self.default_max_age);
61
62 if !entries.is_empty() {
63 let mut cache = self.entries.write().await;
64 cache.insert(origin.to_string(), entries.clone());
65 }
66
67 entries
68 }
69
70 pub async fn get_h3_alternative(&self, origin: &str) -> Option<AltSvcEntry> {
72 let cache = self.entries.read().await;
73 cache.get(origin).and_then(|entries| {
74 entries
75 .iter()
76 .find(|e| e.is_h3() && !e.is_expired())
77 .cloned()
78 })
79 }
80
81 pub async fn clear_origin(&self, origin: &str) {
83 let mut cache = self.entries.write().await;
84 cache.remove(origin);
85 }
86
87 pub async fn cleanup_expired(&self) {
89 let mut cache = self.entries.write().await;
90 for entries in cache.values_mut() {
91 entries.retain(|e| !e.is_expired());
92 }
93 cache.retain(|_, entries| !entries.is_empty());
94 }
95}
96
97pub fn parse_alt_svc(header: &str, default_max_age: u64) -> Vec<AltSvcEntry> {
108 let mut entries = Vec::new();
109 let received_at = Instant::now();
110
111 let alternatives: Vec<&str> = header.split(',').collect();
113
114 for alt in alternatives {
115 let alt = alt.trim();
116 if alt.is_empty() {
117 continue;
118 }
119
120 let parts: Vec<&str> = alt.split(';').collect();
122 if parts.is_empty() {
123 continue;
124 }
125
126 let main_part = parts[0].trim();
127
128 let Some(equals_pos) = main_part.find('=') else {
130 continue; };
132
133 let protocol = main_part[..equals_pos].trim();
134 if protocol.is_empty() {
135 continue;
136 }
137
138 let value_part = main_part[equals_pos + 1..].trim();
139
140 let (host, port) = match parse_quoted_value(value_part) {
142 Some((h, p)) => (h, p),
143 None => continue, };
145
146 let mut max_age = default_max_age;
148 let mut persist = false;
149
150 for param_part in parts.iter().skip(1) {
151 let param_part = param_part.trim();
152 if param_part.is_empty() {
153 continue;
154 }
155
156 if let Some(param_equals) = param_part.find('=') {
158 let key = param_part[..param_equals].trim();
159 let value = param_part[param_equals + 1..].trim();
160
161 match key {
162 "ma" => {
163 if let Ok(age) = value.parse::<u64>() {
164 max_age = age;
165 }
166 }
167 "persist" => {
168 persist = value == "1" || value.eq_ignore_ascii_case("true");
169 }
170 _ => {
171 }
173 }
174 }
175 }
176
177 entries.push(AltSvcEntry {
178 protocol: protocol.to_string(),
179 host,
180 port,
181 max_age,
182 received_at,
183 persist,
184 });
185 }
186
187 entries
188}
189
/// Parses an alt-authority (the text after `protocol-id=`) into `(host, port)`.
///
/// Surrounding whitespace is ignored, and one pair of double quotes is stripped.
/// Accepted shapes:
///
/// - `":8443"` or a bare in-range number `"8443"` gives `(None, 8443)`, meaning
///   no host was given.
/// - `"alt.example.com:8443"` gives `(Some("alt.example.com"), 8443)`. The split
///   is at the last `:`.
/// - `"alt.example.com"` gives `(Some("alt.example.com"), 443)`. This is a lenient
///   extension, since RFC 7838 requires the port.
///
/// Returns `None` for an empty authority, an unterminated opening quote, or a
/// port that is not a valid `u16`.
fn parse_quoted_value(value: &str) -> Option<(Option<String>, u16)> {
    let value = value.trim();

    // An opening quote must be matched by a closing one. Slicing `value[1..len - 1]`
    // would panic on a lone `"`; strip_prefix/strip_suffix cannot.
    let unquoted = match value.strip_prefix('"') {
        Some(inner) => inner.strip_suffix('"')?,
        None => value,
    }
    .trim();

    // ":port" has no host part.
    if let Some(port_str) = unquoted.strip_prefix(':') {
        return port_str.parse::<u16>().ok().map(|port| (None, port));
    }

    // A bare in-range number is also read as a port with no host. Digit strings
    // that overflow `u16` fall through and are treated as a host name below.
    if unquoted.bytes().all(|b| b.is_ascii_digit()) {
        if let Ok(port) = unquoted.parse::<u16>() {
            return Some((None, port));
        }
    }

    match unquoted.rsplit_once(':') {
        Some((host, port_str)) => {
            let port = port_str.trim().parse::<u16>().ok()?;
            let host = host.trim();
            Some(((!host.is_empty()).then(|| host.to_string()), port))
        }
        None if unquoted.is_empty() => None,
        // Host without a port: default to 443.
        None => Some((Some(unquoted.to_string()), 443)),
    }
}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Origin string shared by the cache tests.
    const ORIGIN: &str = "https://example.com";

    /// Parses `header`, asserts it yielded exactly one entry, and returns that entry.
    fn parse_single(header: &str, default_max_age: u64) -> AltSvcEntry {
        let mut parsed = parse_alt_svc(header, default_max_age);
        assert_eq!(parsed.len(), 1, "expected exactly one entry for {header:?}");
        parsed.remove(0)
    }

    /// Protocol ids of `entries`, in header order.
    fn protocols(entries: &[AltSvcEntry]) -> Vec<&str> {
        entries.iter().map(|e| e.protocol.as_str()).collect()
    }

    #[test]
    fn test_parse_simple_h3() {
        let entry = parse_single(r#"h3=":443"; ma=86400"#, 3600);
        assert_eq!(entry.protocol, "h3");
        assert!(entry.host.is_none());
        assert_eq!(entry.port, 443);
        assert_eq!(entry.max_age, 86400);
        assert!(entry.is_h3());
    }

    #[test]
    fn test_parse_with_host() {
        let entry = parse_single(r#"h3="alt.example.com:443"; ma=3600; persist=1"#, 86400);
        assert_eq!(entry.protocol, "h3");
        assert_eq!(entry.host.as_deref(), Some("alt.example.com"));
        assert_eq!(entry.port, 443);
        assert_eq!(entry.max_age, 3600);
        assert!(entry.persist);
    }

    #[test]
    fn test_parse_multiple_alternatives() {
        let entries = parse_alt_svc(r#"h3=":443"; ma=86400, h3-29=":443"; ma=86400"#, 3600);
        assert_eq!(protocols(&entries), ["h3", "h3-29"]);
        assert!(entries.iter().all(AltSvcEntry::is_h3));
    }

    #[test]
    fn test_parse_mixed_protocols() {
        let entries = parse_alt_svc(r#"h3=":443", h2=":443""#, 86400);
        assert_eq!(protocols(&entries), ["h3", "h2"]);
        let h3_flags: Vec<bool> = entries.iter().map(AltSvcEntry::is_h3).collect();
        assert_eq!(h3_flags, [true, false]);
    }

    #[test]
    fn test_parse_without_quotes() {
        // The authority is accepted even when its quotes are omitted.
        let entry = parse_single(r#"h3=:443; ma=86400"#, 3600);
        assert_eq!(entry.protocol, "h3");
        assert_eq!(entry.port, 443);
    }

    #[test]
    fn test_parse_default_max_age() {
        let entry = parse_single(r#"h3=":443""#, 7200);
        assert_eq!(entry.max_age, 7200);
    }

    #[test]
    fn test_parse_persist_false() {
        assert!(!parse_single(r#"h3=":443"; persist=0"#, 86400).persist);
    }

    #[test]
    fn test_parse_persist_true() {
        assert!(parse_single(r#"h3=":443"; persist=1"#, 86400).persist);
    }

    #[test]
    fn test_parse_custom_port() {
        let entry = parse_single(r#"h3="alt.com:8443"; ma=86400"#, 3600);
        assert_eq!(entry.host.as_deref(), Some("alt.com"));
        assert_eq!(entry.port, 8443);
    }

    #[test]
    fn test_parse_host_without_port() {
        // A missing port falls back to 443.
        let entry = parse_single(r#"h3="alt.example.com""#, 86400);
        assert_eq!(entry.host.as_deref(), Some("alt.example.com"));
        assert_eq!(entry.port, 443);
    }

    #[test]
    fn test_parse_malformed_entries() {
        // Empty protocol id, missing `=`, and out-of-range port, respectively.
        for header in [r#"=":443""#, r#"h3":443""#, r#"h3=":99999""#] {
            assert!(
                parse_alt_svc(header, 86400).is_empty(),
                "{header:?} should be rejected"
            );
        }
    }

    #[test]
    fn test_parse_empty_and_whitespace() {
        assert!(parse_alt_svc("", 86400).is_empty());
        assert!(parse_alt_svc(" ", 86400).is_empty());
        // Empty list members between commas are skipped.
        assert_eq!(parse_alt_svc(r#"h3=":443", , h2=":443""#, 86400).len(), 2);
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let cache = AltSvcCache::new();

        let stored = cache.parse_and_store(ORIGIN, r#"h3=":443"; ma=3600"#).await;
        assert_eq!(stored.len(), 1);

        let found = cache
            .get_h3_alternative(ORIGIN)
            .await
            .expect("h3 alternative should be cached");
        assert_eq!(found.protocol, "h3");

        cache.clear_origin(ORIGIN).await;
        assert!(cache.get_h3_alternative(ORIGIN).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear_directive() {
        let cache = AltSvcCache::new();
        cache.parse_and_store(ORIGIN, r#"h3=":443"; ma=3600"#).await;

        let cleared = cache.parse_and_store(ORIGIN, "clear").await;
        assert!(cleared.is_empty());
        assert!(cache.get_h3_alternative(ORIGIN).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = AltSvcCache::new();
        cache.parse_and_store(ORIGIN, r#"h3=":443"; ma=1"#).await;
        assert!(cache.get_h3_alternative(ORIGIN).await.is_some());

        // Outlive the one-second freshness lifetime.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
        assert!(cache.get_h3_alternative(ORIGIN).await.is_none());

        // Purging expired entries must still leave the origin without an alternative.
        cache.cleanup_expired().await;
        assert!(cache.get_h3_alternative(ORIGIN).await.is_none());
    }
}