Skip to main content

h11/
_headers.rs

1use std::collections::HashSet;
2
3use crate::{
4    _abnf::{FIELD_NAME, FIELD_VALUE},
5    _events::Request,
6    _util::ProtocolError,
7};
8use lazy_static::lazy_static;
9use regex::bytes::Regex;
10
lazy_static! {
    // Whole-value validators, compiled once on first use. Each pattern is wrapped
    // in `^...$` so the entire byte string must match, not just a substring.
    // NOTE(review): if `FIELD_NAME` / `FIELD_VALUE` ever contain a top-level `|`,
    // the anchors would bind only to the outermost alternatives; wrapping them as
    // `^(?:...)$` would be safer — confirm against `_abnf`.

    // Content-Length: one or more ASCII digits and nothing else.
    static ref CONTENT_LENGTH_RE: Regex = Regex::new(r"^[0-9]+$").unwrap();
    // Header field name, per the crate's ABNF `FIELD_NAME` pattern.
    static ref FIELD_NAME_RE: Regex = Regex::new(&format!(r"^{}$", FIELD_NAME)).unwrap();
    // Header field value, per the crate's ABNF `FIELD_VALUE` pattern (dereferenced
    // here, presumably because it is itself a lazily built pattern — confirm).
    static ref FIELD_VALUE_RE: Regex = Regex::new(&format!(r"^{}$", *FIELD_VALUE)).unwrap();
}
16
/// Returns `value` without its leading and trailing ASCII whitespace bytes.
///
/// Whitespace is whatever `u8::is_ascii_whitespace` accepts (space, `\t`, `\n`,
/// form feed, `\r`). The result borrows from `value`; an input made only of
/// whitespace yields an empty slice.
fn trim_ascii_whitespace(value: &[u8]) -> &[u8] {
    let mut rest = value;
    // Peel whitespace off the front, one byte at a time.
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    // Then peel whitespace off the back.
    while let [head @ .., last] = rest {
        if !last.is_ascii_whitespace() {
            break;
        }
        rest = head;
    }
    rest
}
29
/// An ordered list of HTTP header fields.
///
/// Every entry is a `(raw_name, normalized_name, value)` triple: `raw_name` keeps
/// the casing the field was supplied with (retained for serialization),
/// `normalized_name` is its lowercase form (used for lookups), and `value` holds
/// the field value bytes.
#[derive(Clone, PartialEq, Eq, Hash, Default, PartialOrd, Ord)]
pub struct Headers(Vec<(Vec<u8>, Vec<u8>, Vec<u8>)>);

impl std::fmt::Debug for Headers {
    /// Formats as `Headers { RawName: "value", ... }`, keyed by the raw names.
    ///
    /// Names and values are decoded lossily as UTF-8, so non-UTF-8 bytes show up
    /// as replacement characters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Headers");
        for (raw_name, _normalized, value) in &self.0 {
            let label = String::from_utf8_lossy(raw_name);
            out.field(&label, &String::from_utf8_lossy(value));
        }
        out.finish()
    }
}
49
50impl Headers {
51    /// Returns normalized `(name, value)` pairs.
52    ///
53    /// Names are lowercase. Values preserve their stored bytes.
54    pub fn iter(&self) -> impl Iterator<Item = (Vec<u8>, Vec<u8>)> + '_ {
55        self.0
56            .iter()
57            .map(|(_, name, value)| ((*name).clone(), (*value).clone()))
58    }
59
60    /// Returns raw `(raw_name, normalized_name, value)` header triples.
61    pub fn raw_items(&self) -> Vec<&(Vec<u8>, Vec<u8>, Vec<u8>)> {
62        self.0.iter().collect()
63    }
64
65    /// Returns the number of header fields.
66    pub fn len(&self) -> usize {
67        self.0.len()
68    }
69
70    /// Returns true when the collection has no header fields.
71    pub fn is_empty(&self) -> bool {
72        self.0.is_empty()
73    }
74
75    /// Builds and validates a header collection from byte-like name/value pairs.
76    ///
77    /// This validates field syntax, normalizes names for lookup, preserves raw
78    /// name casing for output, and enforces `Content-Length` /
79    /// `Transfer-Encoding` consistency.
80    pub fn new<I, N, V>(headers: I) -> Result<Self, ProtocolError>
81    where
82        I: IntoIterator<Item = (N, V)>,
83        N: AsRef<[u8]>,
84        V: AsRef<[u8]>,
85    {
86        normalize_and_validate(
87            headers
88                .into_iter()
89                .map(|(name, value)| (name.as_ref().to_vec(), value.as_ref().to_vec()))
90                .collect(),
91            false,
92        )
93    }
94}
95
96impl From<Vec<(Vec<u8>, Vec<u8>)>> for Headers {
97    /// Builds headers from owned byte vectors.
98    ///
99    /// This conversion panics if the headers are invalid. Prefer
100    /// [`Headers::new`] when handling untrusted or fallible input.
101    fn from(value: Vec<(Vec<u8>, Vec<u8>)>) -> Self {
102        Headers::new(value)
103            .expect("invalid HTTP header list; use Headers::new for fallible construction")
104    }
105}
106
107/// Normalizes and validates HTTP header fields.
108///
109/// This is primarily used by parsers and [`Headers::new`]. The `_parsed`
110/// argument skips field syntax checks for already-parsed wire input.
111pub fn normalize_and_validate(
112    headers: Vec<(Vec<u8>, Vec<u8>)>,
113    _parsed: bool,
114) -> Result<Headers, ProtocolError> {
115    let mut new_headers = vec![];
116    let mut seen_content_length = None;
117    let mut saw_transfer_encoding = false;
118    for (name, value) in headers {
119        if !_parsed {
120            if !FIELD_NAME_RE.is_match(&name) {
121                return Err(ProtocolError::LocalProtocolError(
122                    format!("Illegal header name {:?}", &name).into(),
123                ));
124            }
125            if !FIELD_VALUE_RE.is_match(&value) {
126                return Err(ProtocolError::LocalProtocolError(
127                    format!("Illegal header value {:?}", &value).into(),
128                ));
129            }
130        }
131        let raw_name = name.clone();
132        let name = name.to_ascii_lowercase();
133        if name == b"content-length" {
134            let lengths: HashSet<Vec<u8>> = value
135                .split(|&b| b == b',')
136                .map(|length| trim_ascii_whitespace(length).to_vec())
137                .collect();
138            if lengths.len() != 1 {
139                return Err(ProtocolError::LocalProtocolError(
140                    "conflicting Content-Length headers".into(),
141                ));
142            }
143            let value = lengths.iter().next().unwrap();
144            if !CONTENT_LENGTH_RE.is_match(value) {
145                return Err(ProtocolError::LocalProtocolError(
146                    "bad Content-Length".into(),
147                ));
148            }
149            if seen_content_length.is_none() {
150                seen_content_length = Some(value.clone());
151                new_headers.push((raw_name, name, value.clone()));
152            } else if seen_content_length != Some(value.clone()) {
153                return Err(ProtocolError::LocalProtocolError(
154                    "conflicting Content-Length headers".into(),
155                ));
156            }
157        } else if name == b"transfer-encoding" {
158            // "A server that receives a request message with a transfer coding
159            // it does not understand SHOULD respond with 501 (Not
160            // Implemented)."
161            // https://www.rfc-editor.org/rfc/rfc9112.html#section-6.1
162            if saw_transfer_encoding {
163                return Err(ProtocolError::LocalProtocolError(
164                    ("multiple Transfer-Encoding headers", 501).into(),
165                ));
166            }
167            // "All transfer-coding names are case-insensitive"
168            // -- https://www.rfc-editor.org/rfc/rfc9112.html#section-7
169            let value = value.to_ascii_lowercase();
170            if value != b"chunked" {
171                return Err(ProtocolError::LocalProtocolError(
172                    ("Only Transfer-Encoding: chunked is supported", 501).into(),
173                ));
174            }
175            saw_transfer_encoding = true;
176            new_headers.push((raw_name, name, value));
177        } else {
178            new_headers.push((raw_name, name, value.to_vec()));
179        }
180    }
181
182    Ok(Headers(new_headers))
183}
184
185/// Reads a comma-separated header value as lowercase trimmed byte values.
186pub fn get_comma_header(headers: &Headers, name: &[u8]) -> Vec<Vec<u8>> {
187    let mut out: Vec<Vec<u8>> = vec![];
188    let name = name.to_ascii_lowercase();
189    for (found_name, found_value) in headers.iter() {
190        if found_name == name {
191            for found_split_value in found_value.to_ascii_lowercase().split(|&b| b == b',') {
192                let found_split_value = trim_ascii_whitespace(found_split_value);
193                if !found_split_value.is_empty() {
194                    out.push(found_split_value.to_vec());
195                }
196            }
197        }
198    }
199    out
200}
201
202/// Replaces all instances of a comma-separated header.
203pub fn set_comma_header(
204    headers: &Headers,
205    name: &[u8],
206    new_values: Vec<Vec<u8>>,
207) -> Result<Headers, ProtocolError> {
208    let mut new_headers = vec![];
209    for (found_name, found_value) in headers.iter() {
210        if found_name != name {
211            new_headers.push((found_name, found_value));
212        }
213    }
214    for new_value in new_values {
215        new_headers.push((name.to_vec(), new_value));
216    }
217    normalize_and_validate(new_headers, false)
218}
219
220/// Returns whether a request contains an active `Expect: 100-continue`.
221pub fn has_expect_100_continue(request: &Request) -> bool {
222    // https://www.rfc-editor.org/rfc/rfc9110.html#section-10.1.1
223    // "A server that receives a 100-continue expectation in an HTTP/1.0 request
224    // MUST ignore that expectation."
225    if request.http_version < b"1.1".to_vec() {
226        return false;
227    }
228    let expect = get_comma_header(&request.headers, b"expect");
229    expect.contains(&b"100-continue".to_vec())
230}
231
232#[cfg(test)]
233mod tests {
234    use super::*;
235
236    #[test]
237    fn test_headers_new_rejects_invalid_input() {
238        assert!(Headers::new(vec![(b"bad header".to_vec(), b"value".to_vec())]).is_err());
239    }
240
241    #[test]
242    fn test_non_utf8_comma_headers_do_not_panic() {
243        assert_eq!(
244            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"\xff".to_vec())], true)
245                .unwrap_err(),
246            ProtocolError::LocalProtocolError("bad Content-Length".into())
247        );
248
249        let headers = normalize_and_validate(
250            vec![(b"Connection".to_vec(), b"close, \xff".to_vec())],
251            true,
252        )
253        .unwrap();
254        assert_eq!(
255            get_comma_header(&headers, b"connection"),
256            vec![b"close".to_vec(), b"\xff".to_vec()]
257        );
258    }
259
260    #[test]
261    fn test_headers_new_accepts_borrowed_inputs() {
262        assert_eq!(
263            Headers::new([("Host", "example.com"), ("Accept", "*/*")]).unwrap(),
264            Headers(vec![
265                (b"Host".to_vec(), b"host".to_vec(), b"example.com".to_vec()),
266                (b"Accept".to_vec(), b"accept".to_vec(), b"*/*".to_vec()),
267            ])
268        );
269        assert_eq!(
270            Headers::new([(b"Host".as_slice(), b"example.com".as_slice())]).unwrap(),
271            Headers(vec![(
272                b"Host".to_vec(),
273                b"host".to_vec(),
274                b"example.com".to_vec()
275            )])
276        );
277    }
278
279    #[test]
280    fn test_normalize_and_validate() {
281        assert_eq!(
282            normalize_and_validate(vec![(b"foo".to_vec(), b"bar".to_vec())], false).unwrap(),
283            Headers(vec![(b"foo".to_vec(), b"foo".to_vec(), b"bar".to_vec())])
284        );
285
286        // no leading/trailing whitespace in names
287        assert_eq!(
288            normalize_and_validate(vec![(b"foo ".to_vec(), b"bar".to_vec())], false)
289                .expect_err("Expect ProtocolError::LocalProtocolError"),
290            ProtocolError::LocalProtocolError(
291                ("Illegal header name [102, 111, 111, 32]".to_string(), 400).into()
292            )
293        );
294        assert_eq!(
295            normalize_and_validate(vec![(b" foo".to_vec(), b"bar".to_vec())], false)
296                .expect_err("Expect ProtocolError::LocalProtocolError"),
297            ProtocolError::LocalProtocolError(
298                ("Illegal header name [32, 102, 111, 111]".to_string(), 400).into()
299            )
300        );
301
302        // no weird characters in names
303        assert_eq!(
304            normalize_and_validate(vec![(b"foo bar".to_vec(), b"baz".to_vec())], false)
305                .expect_err("Expect ProtocolError::LocalProtocolError"),
306            ProtocolError::LocalProtocolError(
307                (
308                    "Illegal header name [102, 111, 111, 32, 98, 97, 114]".to_string(),
309                    400
310                )
311                    .into()
312            )
313        );
314        assert_eq!(
315            normalize_and_validate(vec![(b"foo\x00bar".to_vec(), b"baz".to_vec())], false)
316                .expect_err("Expect ProtocolError::LocalProtocolError"),
317            ProtocolError::LocalProtocolError(
318                (
319                    "Illegal header name [102, 111, 111, 0, 98, 97, 114]".to_string(),
320                    400
321                )
322                    .into()
323            )
324        );
325        // Not even 8-bit characters:
326        assert_eq!(
327            normalize_and_validate(vec![(b"foo\xffbar".to_vec(), b"baz".to_vec())], false)
328                .expect_err("Expect ProtocolError::LocalProtocolError"),
329            ProtocolError::LocalProtocolError(
330                (
331                    "Illegal header name [102, 111, 111, 255, 98, 97, 114]".to_string(),
332                    400
333                )
334                    .into()
335            )
336        );
337        // And not even the control characters we allow in values:
338        assert_eq!(
339            normalize_and_validate(vec![(b"foo\x01bar".to_vec(), b"baz".to_vec())], false)
340                .expect_err("Expect ProtocolError::LocalProtocolError"),
341            ProtocolError::LocalProtocolError(
342                (
343                    "Illegal header name [102, 111, 111, 1, 98, 97, 114]".to_string(),
344                    400
345                )
346                    .into()
347            )
348        );
349
350        // no return or NUL characters in values
351        assert_eq!(
352            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\rbaz".to_vec())], false)
353                .expect_err("Expect ProtocolError::LocalProtocolError"),
354            ProtocolError::LocalProtocolError(
355                (
356                    "Illegal header value [98, 97, 114, 13, 98, 97, 122]".to_string(),
357                    400
358                )
359                    .into()
360            )
361        );
362        assert_eq!(
363            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\nbaz".to_vec())], false)
364                .expect_err("Expect ProtocolError::LocalProtocolError"),
365            ProtocolError::LocalProtocolError(
366                (
367                    "Illegal header value [98, 97, 114, 10, 98, 97, 122]".to_string(),
368                    400
369                )
370                    .into()
371            )
372        );
373        assert_eq!(
374            normalize_and_validate(vec![(b"foo".to_vec(), b"bar\x00baz".to_vec())], false)
375                .expect_err("Expect ProtocolError::LocalProtocolError"),
376            ProtocolError::LocalProtocolError(
377                (
378                    "Illegal header value [98, 97, 114, 0, 98, 97, 122]".to_string(),
379                    400
380                )
381                    .into()
382            )
383        );
384        // no leading/trailing whitespace
385        assert_eq!(
386            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz  ".to_vec())], false)
387                .expect_err("Expect ProtocolError::LocalProtocolError"),
388            ProtocolError::LocalProtocolError(
389                (
390                    "Illegal header value [98, 97, 114, 98, 97, 122, 32, 32]".to_string(),
391                    400
392                )
393                    .into()
394            )
395        );
396        assert_eq!(
397            normalize_and_validate(vec![(b"foo".to_vec(), b"  barbaz".to_vec())], false)
398                .expect_err("Expect ProtocolError::LocalProtocolError"),
399            ProtocolError::LocalProtocolError(
400                (
401                    "Illegal header value [32, 32, 98, 97, 114, 98, 97, 122]".to_string(),
402                    400
403                )
404                    .into()
405            )
406        );
407        assert_eq!(
408            normalize_and_validate(vec![(b"foo".to_vec(), b"barbaz\t".to_vec())], false)
409                .expect_err("Expect ProtocolError::LocalProtocolError"),
410            ProtocolError::LocalProtocolError(
411                (
412                    "Illegal header value [98, 97, 114, 98, 97, 122, 9]".to_string(),
413                    400
414                )
415                    .into()
416            )
417        );
418        assert_eq!(
419            normalize_and_validate(vec![(b"foo".to_vec(), b"\tbarbaz".to_vec())], false)
420                .expect_err("Expect ProtocolError::LocalProtocolError"),
421            ProtocolError::LocalProtocolError(
422                (
423                    "Illegal header value [9, 98, 97, 114, 98, 97, 122]".to_string(),
424                    400
425                )
426                    .into()
427            )
428        );
429
430        // content-length
431        assert_eq!(
432            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1".to_vec())], false)
433                .unwrap(),
434            Headers(vec![(
435                b"Content-Length".to_vec(),
436                b"content-length".to_vec(),
437                b"1".to_vec()
438            )])
439        );
440        assert_eq!(
441            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"asdf".to_vec())], false)
442                .expect_err("Expect ProtocolError::LocalProtocolError"),
443            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
444        );
445        assert_eq!(
446            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"1x".to_vec())], false)
447                .expect_err("Expect ProtocolError::LocalProtocolError"),
448            ProtocolError::LocalProtocolError(("bad Content-Length".to_string(), 400).into())
449        );
450        assert_eq!(
451            normalize_and_validate(
452                vec![
453                    (b"Content-Length".to_vec(), b"1".to_vec()),
454                    (b"Content-Length".to_vec(), b"2".to_vec())
455                ],
456                false
457            )
458            .expect_err("Expect ProtocolError::LocalProtocolError"),
459            ProtocolError::LocalProtocolError(
460                ("conflicting Content-Length headers".to_string(), 400).into()
461            )
462        );
463        assert_eq!(
464            normalize_and_validate(
465                vec![
466                    (b"Content-Length".to_vec(), b"0".to_vec()),
467                    (b"Content-Length".to_vec(), b"0".to_vec())
468                ],
469                false
470            )
471            .unwrap(),
472            Headers(vec![(
473                b"Content-Length".to_vec(),
474                b"content-length".to_vec(),
475                b"0".to_vec()
476            )])
477        );
478        assert_eq!(
479            normalize_and_validate(vec![(b"Content-Length".to_vec(), b"0 , 0".to_vec())], false)
480                .unwrap(),
481            Headers(vec![(
482                b"Content-Length".to_vec(),
483                b"content-length".to_vec(),
484                b"0".to_vec()
485            )])
486        );
487        assert_eq!(
488            normalize_and_validate(
489                vec![
490                    (b"Content-Length".to_vec(), b"1".to_vec()),
491                    (b"Content-Length".to_vec(), b"1".to_vec()),
492                    (b"Content-Length".to_vec(), b"2".to_vec())
493                ],
494                false
495            )
496            .expect_err("Expect ProtocolError::LocalProtocolError"),
497            ProtocolError::LocalProtocolError(
498                ("conflicting Content-Length headers".to_string(), 400).into()
499            )
500        );
501        assert_eq!(
502            normalize_and_validate(
503                vec![(b"Content-Length".to_vec(), b"1 , 1,2".to_vec())],
504                false
505            )
506            .expect_err("Expect ProtocolError::LocalProtocolError"),
507            ProtocolError::LocalProtocolError(
508                ("conflicting Content-Length headers".to_string(), 400).into()
509            )
510        );
511
512        // transfer-encoding
513        assert_eq!(
514            normalize_and_validate(
515                vec![(b"Transfer-Encoding".to_vec(), b"chunked".to_vec())],
516                false
517            )
518            .unwrap(),
519            Headers(vec![(
520                b"Transfer-Encoding".to_vec(),
521                b"transfer-encoding".to_vec(),
522                b"chunked".to_vec()
523            )])
524        );
525        assert_eq!(
526            normalize_and_validate(
527                vec![(b"Transfer-Encoding".to_vec(), b"cHuNkEd".to_vec())],
528                false
529            )
530            .unwrap(),
531            Headers(vec![(
532                b"Transfer-Encoding".to_vec(),
533                b"transfer-encoding".to_vec(),
534                b"chunked".to_vec()
535            )])
536        );
537        assert_eq!(
538            normalize_and_validate(
539                vec![(b"Transfer-Encoding".to_vec(), b"gzip".to_vec())],
540                false
541            )
542            .expect_err("Expect ProtocolError::LocalProtocolError"),
543            ProtocolError::LocalProtocolError(
544                (
545                    "Only Transfer-Encoding: chunked is supported".to_string(),
546                    501
547                )
548                    .into()
549            )
550        );
551        assert_eq!(
552            normalize_and_validate(
553                vec![
554                    (b"Transfer-Encoding".to_vec(), b"chunked".to_vec()),
555                    (b"Transfer-Encoding".to_vec(), b"gzip".to_vec())
556                ],
557                false
558            )
559            .expect_err("Expect ProtocolError::LocalProtocolError"),
560            ProtocolError::LocalProtocolError(
561                ("multiple Transfer-Encoding headers".to_string(), 501).into()
562            )
563        );
564    }
565
566    #[test]
567    fn test_get_set_comma_header() {
568        let headers = normalize_and_validate(
569            vec![
570                (b"Connection".to_vec(), b"close".to_vec()),
571                (b"whatever".to_vec(), b"something".to_vec()),
572                (b"connectiON".to_vec(), b"fOo,, , BAR".to_vec()),
573            ],
574            false,
575        )
576        .unwrap();
577
578        assert_eq!(
579            get_comma_header(&headers, b"connection"),
580            vec![b"close".to_vec(), b"foo".to_vec(), b"bar".to_vec()]
581        );
582
583        let headers =
584            set_comma_header(&headers, b"newthing", vec![b"a".to_vec(), b"b".to_vec()]).unwrap();
585
586        assert_eq!(
587            headers,
588            Headers(vec![
589                (
590                    b"connection".to_vec(),
591                    b"connection".to_vec(),
592                    b"close".to_vec()
593                ),
594                (
595                    b"whatever".to_vec(),
596                    b"whatever".to_vec(),
597                    b"something".to_vec()
598                ),
599                (
600                    b"connection".to_vec(),
601                    b"connection".to_vec(),
602                    b"fOo,, , BAR".to_vec()
603                ),
604                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
605                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
606            ])
607        );
608
609        let headers =
610            set_comma_header(&headers, b"whatever", vec![b"different thing".to_vec()]).unwrap();
611
612        assert_eq!(
613            headers,
614            Headers(vec![
615                (
616                    b"connection".to_vec(),
617                    b"connection".to_vec(),
618                    b"close".to_vec()
619                ),
620                (
621                    b"connection".to_vec(),
622                    b"connection".to_vec(),
623                    b"fOo,, , BAR".to_vec()
624                ),
625                (b"newthing".to_vec(), b"newthing".to_vec(), b"a".to_vec()),
626                (b"newthing".to_vec(), b"newthing".to_vec(), b"b".to_vec()),
627                (
628                    b"whatever".to_vec(),
629                    b"whatever".to_vec(),
630                    b"different thing".to_vec()
631                ),
632            ])
633        );
634    }
635
636    #[test]
637    fn test_has_100_continue() {
638        assert!(has_expect_100_continue(&Request {
639            method: b"GET".to_vec(),
640            target: b"/".to_vec(),
641            headers: normalize_and_validate(
642                vec![
643                    (b"Host".to_vec(), b"example.com".to_vec()),
644                    (b"Expect".to_vec(), b"100-continue".to_vec())
645                ],
646                false
647            )
648            .unwrap(),
649            http_version: b"1.1".to_vec(),
650        }));
651        assert!(!has_expect_100_continue(&Request {
652            method: b"GET".to_vec(),
653            target: b"/".to_vec(),
654            headers: normalize_and_validate(
655                vec![(b"Host".to_vec(), b"example.com".to_vec())],
656                false
657            )
658            .unwrap(),
659            http_version: b"1.1".to_vec(),
660        }));
661        // Case insensitive
662        assert!(has_expect_100_continue(&Request {
663            method: b"GET".to_vec(),
664            target: b"/".to_vec(),
665            headers: normalize_and_validate(
666                vec![
667                    (b"Host".to_vec(), b"example.com".to_vec()),
668                    (b"Expect".to_vec(), b"100-Continue".to_vec())
669                ],
670                false
671            )
672            .unwrap(),
673            http_version: b"1.1".to_vec(),
674        }));
675        // Doesn't work in HTTP/1.0
676        assert!(!has_expect_100_continue(&Request {
677            method: b"GET".to_vec(),
678            target: b"/".to_vec(),
679            headers: normalize_and_validate(
680                vec![
681                    (b"Host".to_vec(), b"example.com".to_vec()),
682                    (b"Expect".to_vec(), b"100-continue".to_vec())
683                ],
684                false
685            )
686            .unwrap(),
687            http_version: b"1.0".to_vec(),
688        }));
689    }
690}