Skip to main content

specter/pool/
alt_svc.rs

1//! Alt-Svc header parsing and caching for HTTP/3 discovery.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Instant;
6use tokio::sync::RwLock;
7
/// A single alternative service parsed from an `Alt-Svc` header (RFC 7838).
///
/// Derives `PartialEq`/`Eq` so entries can be compared directly (every field,
/// including `Instant`, is `Eq`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AltSvcEntry {
    /// Protocol identifier as it appeared in the header (e.g. "h3", "h3-29", "h2").
    pub protocol: String,
    /// Alternative host; `None` means the same host as the origin.
    pub host: Option<String>,
    /// Alternative port.
    pub port: u16,
    /// Freshness lifetime in seconds, counted from `received_at`.
    pub max_age: u64,
    /// When this entry was received.
    pub received_at: Instant,
    /// Whether the `persist=1` parameter was present (persist across network changes/restarts).
    pub persist: bool,
}

impl AltSvcEntry {
    /// Returns `true` once at least `max_age` whole seconds have elapsed since
    /// `received_at`. An entry with `max_age == 0` is expired immediately.
    pub fn is_expired(&self) -> bool {
        self.received_at.elapsed().as_secs() >= self.max_age
    }

    /// Returns `true` if the protocol is HTTP/3: the final `h3` identifier or a
    /// draft identifier of the form `h3-<suffix>`.
    pub fn is_h3(&self) -> bool {
        self.protocol == "h3" || self.protocol.starts_with("h3-")
    }
}
37
38/// Alt-Svc cache for HTTP/3 discovery
39pub struct AltSvcCache {
40    entries: Arc<RwLock<HashMap<String, Vec<AltSvcEntry>>>>,
41    default_max_age: u64,
42}
43
44impl AltSvcCache {
45    pub fn new() -> Self {
46        Self {
47            entries: Arc::new(RwLock::new(HashMap::new())),
48            default_max_age: 86400, // 24 hours default
49        }
50    }
51
52    /// Parse Alt-Svc header and store entries for origin
53    pub async fn parse_and_store(&self, origin: &str, header: &str) -> Vec<AltSvcEntry> {
54        // Handle "clear" directive
55        if header.trim() == "clear" {
56            self.clear_origin(origin).await;
57            return vec![];
58        }
59
60        let entries = parse_alt_svc(header, self.default_max_age);
61
62        if !entries.is_empty() {
63            let mut cache = self.entries.write().await;
64            cache.insert(origin.to_string(), entries.clone());
65        }
66
67        entries
68    }
69
70    /// Get best HTTP/3 alternative for origin
71    pub async fn get_h3_alternative(&self, origin: &str) -> Option<AltSvcEntry> {
72        let cache = self.entries.read().await;
73        cache.get(origin).and_then(|entries| {
74            entries
75                .iter()
76                .find(|e| e.is_h3() && !e.is_expired())
77                .cloned()
78        })
79    }
80
81    /// Clear entries for an origin
82    pub async fn clear_origin(&self, origin: &str) {
83        let mut cache = self.entries.write().await;
84        cache.remove(origin);
85    }
86
87    /// Remove expired entries from cache
88    pub async fn cleanup_expired(&self) {
89        let mut cache = self.entries.write().await;
90        for entries in cache.values_mut() {
91            entries.retain(|e| !e.is_expired());
92        }
93        cache.retain(|_, entries| !entries.is_empty());
94    }
95}
96
97/// Parse Alt-Svc header value into a vector of entries
98///
99/// # Examples
100///
101/// ```
102/// use specter::pool::alt_svc::parse_alt_svc;
103///
104/// let header = r#"h3=":443"; ma=86400, h3-29="alt.com:8443"; persist=1"#;
105/// let entries = parse_alt_svc(header, 3600);
106/// ```
107pub fn parse_alt_svc(header: &str, default_max_age: u64) -> Vec<AltSvcEntry> {
108    let mut entries = Vec::new();
109    let received_at = Instant::now();
110
111    // Split by commas to get individual alternatives
112    let alternatives: Vec<&str> = header.split(',').collect();
113
114    for alt in alternatives {
115        let alt = alt.trim();
116        if alt.is_empty() {
117            continue;
118        }
119
120        // Split into protocol=value and parameters
121        let parts: Vec<&str> = alt.split(';').collect();
122        if parts.is_empty() {
123            continue;
124        }
125
126        let main_part = parts[0].trim();
127
128        // Parse protocol=value
129        let Some(equals_pos) = main_part.find('=') else {
130            continue; // Skip malformed entries without =
131        };
132
133        let protocol = main_part[..equals_pos].trim();
134        if protocol.is_empty() {
135            continue;
136        }
137
138        let value_part = main_part[equals_pos + 1..].trim();
139
140        // Extract host:port from quoted value
141        let (host, port) = match parse_quoted_value(value_part) {
142            Some((h, p)) => (h, p),
143            None => continue, // Skip if value parsing fails
144        };
145
146        // Parse parameters (ma, persist)
147        let mut max_age = default_max_age;
148        let mut persist = false;
149
150        for param_part in parts.iter().skip(1) {
151            let param_part = param_part.trim();
152            if param_part.is_empty() {
153                continue;
154            }
155
156            // Parse key=value parameter
157            if let Some(param_equals) = param_part.find('=') {
158                let key = param_part[..param_equals].trim();
159                let value = param_part[param_equals + 1..].trim();
160
161                match key {
162                    "ma" => {
163                        if let Ok(age) = value.parse::<u64>() {
164                            max_age = age;
165                        }
166                    }
167                    "persist" => {
168                        persist = value == "1" || value.eq_ignore_ascii_case("true");
169                    }
170                    _ => {
171                        // Unknown parameter, ignore
172                    }
173                }
174            }
175        }
176
177        entries.push(AltSvcEntry {
178            protocol: protocol.to_string(),
179            host,
180            port,
181            max_age,
182            received_at,
183            persist,
184        });
185    }
186
187    entries
188}
189
190/// Parse quoted value to extract host and port
191///
192/// Handles formats:
193/// - `":443"` -> (None, 443)
194/// - `"alt.example.com:443"` -> (Some("alt.example.com"), 443)
195/// - `"alt.example.com"` -> (Some("alt.example.com"), 443) [default port]
196fn parse_quoted_value(value: &str) -> Option<(Option<String>, u16)> {
197    let value = value.trim();
198
199    // Remove quotes if present
200    let unquoted = if value.starts_with('"') && value.ends_with('"') {
201        &value[1..value.len() - 1]
202    } else {
203        value
204    };
205
206    let unquoted = unquoted.trim();
207
208    // Check if it starts with ':' (same-origin case)
209    if let Some(port_str) = unquoted.strip_prefix(':') {
210        if let Ok(port) = port_str.parse::<u16>() {
211            return Some((None, port));
212        }
213        return None;
214    }
215
216    // Check if it's a pure numeric string (treat as :port for robustness)
217    if unquoted.parse::<u16>().is_ok() && unquoted.chars().all(|c| c.is_ascii_digit()) {
218        if let Ok(port) = unquoted.parse::<u16>() {
219            return Some((None, port));
220        }
221    }
222
223    // Parse host:port
224    if let Some(colon_pos) = unquoted.rfind(':') {
225        let host = unquoted[..colon_pos].trim();
226        let port_str = unquoted[colon_pos + 1..].trim();
227
228        if host.is_empty() {
229            // Handle ":port" case (should have been caught above, but double-check)
230            if let Ok(port) = port_str.parse::<u16>() {
231                return Some((None, port));
232            }
233            return None;
234        }
235
236        if let Ok(port) = port_str.parse::<u16>() {
237            return Some((Some(host.to_string()), port));
238        }
239    } else {
240        // No port specified, assume default HTTPS port
241        if !unquoted.is_empty() {
242            return Some((Some(unquoted.to_string()), 443));
243        }
244    }
245
246    None
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_h3() {
        let header = r#"h3=":443"; ma=86400"#;
        let entries = parse_alt_svc(header, 3600);

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].protocol, "h3");
        assert_eq!(entries[0].host, None);
        assert_eq!(entries[0].port, 443);
        assert_eq!(entries[0].max_age, 86400);
        assert!(entries[0].is_h3());
    }

    #[test]
    fn test_parse_with_host() {
        let header = r#"h3="alt.example.com:443"; ma=3600; persist=1"#;
        let entries = parse_alt_svc(header, 86400);

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].protocol, "h3");
        assert_eq!(entries[0].host, Some("alt.example.com".to_string()));
        assert_eq!(entries[0].port, 443);
        assert_eq!(entries[0].max_age, 3600);
        assert!(entries[0].persist);
    }

    #[test]
    fn test_parse_multiple_alternatives() {
        let header = r#"h3=":443"; ma=86400, h3-29=":443"; ma=86400"#;
        let entries = parse_alt_svc(header, 3600);

        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].protocol, "h3");
        assert_eq!(entries[1].protocol, "h3-29");
        assert!(entries[0].is_h3());
        assert!(entries[1].is_h3());
    }

    #[test]
    fn test_parse_mixed_protocols() {
        let header = r#"h3=":443", h2=":443""#;
        let entries = parse_alt_svc(header, 86400);

        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].protocol, "h3");
        assert_eq!(entries[1].protocol, "h2");
        assert!(entries[0].is_h3());
        assert!(!entries[1].is_h3());
    }

    #[test]
    fn test_parse_without_quotes() {
        // Some servers may omit quotes, handle gracefully
        let header = r#"h3=:443; ma=86400"#;
        let entries = parse_alt_svc(header, 3600);

        // Should still parse (unquoted value)
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].protocol, "h3");
        assert_eq!(entries[0].port, 443);
    }

    #[test]
    fn test_parse_default_max_age() {
        let header = r#"h3=":443""#;
        let entries = parse_alt_svc(header, 7200);

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].max_age, 7200); // Uses default
    }

    #[test]
    fn test_parse_persist_false() {
        let header = r#"h3=":443"; persist=0"#;
        let entries = parse_alt_svc(header, 86400);

        assert_eq!(entries.len(), 1);
        assert!(!entries[0].persist);
    }

    #[test]
    fn test_parse_persist_true() {
        let header = r#"h3=":443"; persist=1"#;
        let entries = parse_alt_svc(header, 86400);

        assert_eq!(entries.len(), 1);
        assert!(entries[0].persist);
    }

    #[test]
    fn test_parse_custom_port() {
        let header = r#"h3="alt.com:8443"; ma=86400"#;
        let entries = parse_alt_svc(header, 3600);

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].host, Some("alt.com".to_string()));
        assert_eq!(entries[0].port, 8443);
    }

    #[test]
    fn test_parse_host_without_port() {
        let header = r#"h3="alt.example.com""#;
        let entries = parse_alt_svc(header, 86400);

        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].host, Some("alt.example.com".to_string()));
        assert_eq!(entries[0].port, 443); // Default HTTPS port
    }

    #[test]
    fn test_parse_malformed_entries() {
        // Missing protocol
        let header = r#"=":443""#;
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 0);

        // Missing equals
        let header = r#"h3":443""#;
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 0);

        // Invalid port
        let header = r#"h3=":99999""#;
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 0);
    }

    #[test]
    fn test_parse_empty_and_whitespace() {
        let header = "";
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 0);

        let header = "   ";
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 0);

        let header = r#"h3=":443", , h2=":443""#;
        let entries = parse_alt_svc(header, 86400);
        assert_eq!(entries.len(), 2); // Empty alternative skipped
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let cache = AltSvcCache::new();

        // Store entries
        let header = r#"h3=":443"; ma=3600"#;
        let entries = cache.parse_and_store("https://example.com", header).await;
        assert_eq!(entries.len(), 1);

        // Retrieve H3 alternative
        let h3_entry = cache.get_h3_alternative("https://example.com").await;
        assert!(h3_entry.is_some());
        assert_eq!(h3_entry.unwrap().protocol, "h3");

        // Clear origin
        cache.clear_origin("https://example.com").await;
        let h3_entry = cache.get_h3_alternative("https://example.com").await;
        assert!(h3_entry.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear_directive() {
        let cache = AltSvcCache::new();

        // Store entries first
        let header = r#"h3=":443"; ma=3600"#;
        cache.parse_and_store("https://example.com", header).await;

        // Clear directive
        let entries = cache.parse_and_store("https://example.com", "clear").await;
        assert_eq!(entries.len(), 0);

        // Verify cleared
        let h3_entry = cache.get_h3_alternative("https://example.com").await;
        assert!(h3_entry.is_none());
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let cache = AltSvcCache::new();

        // A long-lived entry stays available.
        cache
            .parse_and_store("https://fresh.example", r#"h3=":443"; ma=3600"#)
            .await;
        assert!(cache
            .get_h3_alternative("https://fresh.example")
            .await
            .is_some());

        // `ma=0` makes the entry stale immediately, so the test needs no sleep.
        cache
            .parse_and_store("https://stale.example", r#"h3=":443"; ma=0"#)
            .await;
        assert!(cache
            .get_h3_alternative("https://stale.example")
            .await
            .is_none());

        // Cleanup must actually drop the stale origin while keeping the fresh one.
        cache.cleanup_expired().await;
        {
            let entries = cache.entries.read().await;
            assert!(!entries.contains_key("https://stale.example"));
            assert!(entries.contains_key("https://fresh.example"));
        }
        assert!(cache
            .get_h3_alternative("https://fresh.example")
            .await
            .is_some());
    }
}