Skip to main content

typespec/http/
headers.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4//! HTTP headers.
5
6// cspell:ignore hasher traceparent tracestate
7
8use crate::error::{Error, ErrorKind, ResultExt};
9use std::{
10    borrow::Cow, collections::HashSet, convert::Infallible, fmt, str::FromStr, sync::LazyLock,
11};
12
/// Header names whose values may be shown verbatim in diagnostic output.
///
/// Names are stored lowercase. The value of any header whose name is not in this set is
/// redacted when a [`Headers`] collection is `Debug`-formatted.
pub static DEFAULT_ALLOWED_HEADER_NAMES: LazyLock<HashSet<Cow<'static, str>>> =
    LazyLock::new(|| {
        const ALLOWED: &[&str] = &[
            "accept",
            "cache-control",
            "connection",
            "content-length",
            "content-type",
            "date",
            "etag",
            "expires",
            "if-match",
            "if-modified-since",
            "if-none-match",
            "if-unmodified-since",
            "last-modified",
            "ms-cv",
            "pragma",
            "request-id",
            "retry-after",
            "server",
            "traceparent",
            "tracestate",
            "transfer-encoding",
            "user-agent",
            "www-authenticate",
            "x-ms-request-id",
            "x-ms-client-request-id",
            "x-ms-return-client-request-id",
        ];
        ALLOWED.iter().copied().map(Cow::Borrowed).collect()
    });
48
49/// A trait for converting a type into request headers.
50pub trait AsHeaders {
51    /// The error type which can occur when converting the type into headers.
52    type Error: std::error::Error + Send + Sync + 'static;
53
54    /// The iterator type which yields header name/value pairs.
55    type Iter: Iterator<Item = (HeaderName, HeaderValue)>;
56
57    /// Iterate over all the header name/value pairs.
58    fn as_headers(&self) -> Result<Self::Iter, Self::Error>;
59}
60
61impl<T> AsHeaders for T
62where
63    T: Header,
64{
65    type Error = Infallible;
66    type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>;
67
68    /// Iterate over all the header name/value pairs.
69    fn as_headers(&self) -> Result<Self::Iter, Self::Error> {
70        Ok(vec![(self.name(), self.value())].into_iter())
71    }
72}
73
74impl<T> AsHeaders for Option<T>
75where
76    T: AsHeaders<Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>>,
77{
78    type Error = T::Error;
79    type Iter = T::Iter;
80
81    /// Iterate over all the header name/value pairs.
82    fn as_headers(&self) -> Result<Self::Iter, T::Error> {
83        match self {
84            Some(h) => h.as_headers(),
85            None => Ok(vec![].into_iter()),
86        }
87    }
88}
89
90/// Extract a value from the [`Headers`] collection.
91///
92/// The [`FromHeaders::from_headers()`] method is usually used implicitly, through [`Headers::get()`] or [`Headers::get_optional()`].
93pub trait FromHeaders: Sized {
94    /// The error type which can occur when extracting the value from headers.
95    type Error: std::error::Error + Send + Sync + 'static;
96
97    /// Gets a list of the header names that [`FromHeaders::from_headers`] expects.
98    ///
99    /// Used by [`Headers::get()`] to generate an error if the headers are not present.
100    fn header_names() -> &'static [&'static str];
101
102    /// Extracts the value from the provided [`Headers`] collection.
103    ///
104    /// This method returns one of the following three values:
105    /// * `Ok(Some(...))` if the relevant headers are present and could be parsed into the value.
106    /// * `Ok(None)` if the relevant headers are not present, so no attempt to parse them can be made.
107    /// * `Err(...)` if an error occurred when trying to parse the headers. This likely indicates that the headers are present but invalid.
108    fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error>;
109}
110
111/// View a type as an HTTP header.
112///
113// Ad interim there are two default functions: `add_to_builder` and `add_to_request`.
114//
115// While not restricted by the type system, please add HTTP headers only. In particular, do not
116// interact with the body of the request.
117//
118// As soon as the migration to the pipeline architecture will be complete we will phase out
119// `add_to_builder`.
120pub trait Header {
121    /// Get the name of the header.
122    fn name(&self) -> HeaderName;
123    /// Get the value of the header.
124    fn value(&self) -> HeaderValue;
125}
126
127/// A collection of headers.
128#[derive(Clone, PartialEq, Eq, Default)]
129pub struct Headers(std::collections::HashMap<HeaderName, HeaderValue>);
130
131impl Headers {
132    /// Create a new headers collection.
133    pub fn new() -> Self {
134        Self::default()
135    }
136
137    /// Gets the headers represented by `H`, or return an error if the header is not found.
138    pub fn get<H: FromHeaders>(&self) -> crate::Result<H> {
139        match H::from_headers(self) {
140            Ok(Some(x)) => Ok(x),
141            Ok(None) => Err(crate::Error::with_message_fn(
142                ErrorKind::DataConversion,
143                || {
144                    let required_headers = H::header_names();
145                    format!(
146                        "required header(s) not found: {}",
147                        required_headers.join(", ")
148                    )
149                },
150            )),
151            Err(e) => Err(crate::Error::new(ErrorKind::DataConversion, e)),
152        }
153    }
154
155    /// Gets the headers represented by `H`, if they are present.
156    ///
157    /// This method returns one of the following three values:
158    /// * `Ok(Some(...))` if the relevant headers are present and could be parsed into the value.
159    /// * `Ok(None)` if the relevant headers are not present, so no attempt to parse them can be made.
160    /// * `Err(...)` if an error occurred when trying to parse the headers. This likely indicates that the headers are present but invalid.
161    pub fn get_optional<H: FromHeaders>(&self) -> Result<Option<H>, H::Error> {
162        H::from_headers(self)
163    }
164
165    /// Optionally get a header value as a `String`.
166    pub fn get_optional_string(&self, key: &HeaderName) -> Option<String> {
167        self.get_as(key).ok()
168    }
169
170    /// Get a header value as a `str`, or err if it is not found.
171    pub fn get_str(&self, key: &HeaderName) -> crate::Result<&str> {
172        self.get_with(key, |s| crate::Result::Ok(s.as_str()))
173    }
174
175    /// Optionally get a header value as a `str`.
176    pub fn get_optional_str(&self, key: &HeaderName) -> Option<&str> {
177        self.get_str(key).ok()
178    }
179
180    /// Get a header value parsing it as the type, or err if it's not found or it fails to parse.
181    pub fn get_as<V, E>(&self, key: &HeaderName) -> crate::Result<V>
182    where
183        V: FromStr<Err = E>,
184        E: std::error::Error + Send + Sync + 'static,
185    {
186        self.get_with(key, |s| s.as_str().parse())
187    }
188
189    /// Optionally get a header value parsing it as the type, or err if it fails to parse.
190    pub fn get_optional_as<V, E>(&self, key: &HeaderName) -> crate::Result<Option<V>>
191    where
192        V: FromStr<Err = E>,
193        E: std::error::Error + Send + Sync + 'static,
194    {
195        self.get_optional_with(key, |s| s.as_str().parse())
196    }
197
198    /// Get a header value using the parser, or err if it is not found or fails to parse.
199    pub fn get_with<'a, V, F, E>(&'a self, key: &HeaderName, parser: F) -> crate::Result<V>
200    where
201        F: FnOnce(&'a HeaderValue) -> Result<V, E>,
202        E: std::error::Error + Send + Sync + 'static,
203    {
204        self.get_optional_with(key, parser)?.ok_or_else(|| {
205            Error::with_message_fn(ErrorKind::DataConversion, || {
206                format!("header not found {}", key.as_str())
207            })
208        })
209    }
210
211    /// Optionally get a header value using the parser, or err if it fails to parse.
212    pub fn get_optional_with<'a, V, F, E>(
213        &'a self,
214        key: &HeaderName,
215        parser: F,
216    ) -> crate::Result<Option<V>>
217    where
218        F: FnOnce(&'a HeaderValue) -> Result<V, E>,
219        E: std::error::Error + Send + Sync + 'static,
220    {
221        self.0
222            .get(key)
223            .map(|v: &HeaderValue| {
224                parser(v).with_context_fn(ErrorKind::DataConversion, || {
225                    let ty = std::any::type_name::<V>();
226                    format!("unable to parse header '{key:?}: {v:?}' into {ty}",)
227                })
228            })
229            .transpose()
230    }
231
232    /// Insert a header name/value pair.
233    pub fn insert<K, V>(&mut self, key: K, value: V)
234    where
235        K: Into<HeaderName>,
236        V: Into<HeaderValue>,
237    {
238        self.0.insert(key.into(), value.into());
239    }
240
241    /// Add headers to the headers collection.
242    ///
243    /// ## Errors
244    ///
245    /// The error this returns depends on the type `H`.
246    /// Many header types are infallible, return a `Result` with [`Infallible`] as the error type.
247    /// In this case, you can safely `.unwrap()` the value without risking a panic.
248    pub fn add<H>(&mut self, header: H) -> Result<(), H::Error>
249    where
250        H: AsHeaders,
251    {
252        for (key, value) in header.as_headers()? {
253            self.insert(key, value);
254        }
255        Ok(())
256    }
257
258    /// Iterate over all the header name/value pairs.
259    pub fn iter(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
260        self.0.iter()
261    }
262
263    /// Remove a header by name, returning the previous value if present.
264    pub fn remove<K>(&mut self, key: K) -> Option<HeaderValue>
265    where
266        K: Into<HeaderName>,
267    {
268        self.0.remove(&key.into())
269    }
270}
271
272impl fmt::Debug for Headers {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        // TODO: Sanitize all but safe headers.
275        f.debug_map()
276            .entries(self.0.iter().map(|(k, v)| {
277                (
278                    k.as_str(),
279                    if DEFAULT_ALLOWED_HEADER_NAMES.contains(k.as_str()) {
280                        v.as_str()
281                    } else {
282                        super::REDACTED_PATTERN
283                    },
284                )
285            }))
286            .finish()
287    }
288}
289
290impl IntoIterator for Headers {
291    type Item = (HeaderName, HeaderValue);
292
293    type IntoIter = std::collections::hash_map::IntoIter<HeaderName, HeaderValue>;
294
295    fn into_iter(self) -> Self::IntoIter {
296        self.0.into_iter()
297    }
298}
299
300impl From<std::collections::HashMap<HeaderName, HeaderValue>> for Headers {
301    fn from(c: std::collections::HashMap<HeaderName, HeaderValue>) -> Self {
302        Self(c)
303    }
304}
305
/// A header name.
///
/// Names are kept lowercase. The `const` constructors reject ASCII uppercase letters,
/// `From<&'static str>` rejects uppercase characters, and `From<String>` lowercases its input
/// first.
///
/// Equality, ordering, and hashing all compare the name ASCII-case-insensitively and ignore the
/// standard-header marker. This keeps the three mutually consistent, as `Ord` and `Hash` require.
#[derive(Clone, Debug, Eq)]
pub struct HeaderName {
    /// Name of the header.
    name: Cow<'static, str>,

    /// Marker indicating if the header is a standard header or not.
    /// Note that this field is not part of equality, ordering, or hashing.
    pub(crate) is_standard: bool,
}

impl HeaderName {
    /// Create a non-standard header name from a static `str`.
    ///
    /// # Panics
    ///
    /// Panics if `s` contains an ASCII uppercase letter. When evaluated in a `const` context,
    /// this surfaces as a compile-time error.
    pub const fn from_static(s: &'static str) -> Self {
        ensure_no_uppercase(s);
        Self {
            name: Cow::Borrowed(s),
            is_standard: false,
        }
    }

    /// Create a header name from a static `str` and mark it as a standard HTTP header.
    ///
    /// # Panics
    ///
    /// Same as [`HeaderName::from_static`].
    pub const fn from_static_standard(s: &'static str) -> Self {
        ensure_no_uppercase(s);
        Self {
            name: Cow::Borrowed(s),
            is_standard: true,
        }
    }

    /// Create a non-standard header name from a borrowed or owned string.
    ///
    /// # Panics
    ///
    /// Panics if the name contains any uppercase character.
    fn from_cow<C>(c: C) -> Self
    where
        C: Into<Cow<'static, str>>,
    {
        let c = c.into();
        // Reject uppercase characters instead of requiring every letter to be lowercase.
        // Caseless letters (e.g. CJK) are neither uppercase nor lowercase and are unchanged by
        // `to_lowercase`, so the stricter check made `From<String>` panic on them with a
        // misleading "must be lowercase" message.
        assert!(
            !c.chars().any(char::is_uppercase),
            "header names must be lowercase: {c}"
        );
        Self {
            name: c,
            is_standard: false,
        }
    }

    /// Get a header name as a `str`.
    pub fn as_str(&self) -> &str {
        self.name.as_ref()
    }

    /// Get whether the header was defined as a standard HTTP header.
    pub fn is_standard(&self) -> bool {
        self.is_standard
    }
}

impl PartialEq for HeaderName {
    fn eq(&self, other: &Self) -> bool {
        self.name.eq_ignore_ascii_case(&other.name)
    }
}

impl PartialEq<&str> for HeaderName {
    fn eq(&self, other: &&str) -> bool {
        self.name.eq_ignore_ascii_case(other)
    }
}

impl PartialOrd for HeaderName {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeaderName {
    /// Orders names by their ASCII-lowercased bytes.
    ///
    /// This agrees with `PartialEq` (`eq_ignore_ascii_case`) and ignores `is_standard`. A
    /// derived implementation would tie-break on `is_standard`, making names that are equal
    /// under `==` compare unequal.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let lhs = self.name.bytes().map(|b| b.to_ascii_lowercase());
        let rhs = other.name.bytes().map(|b| b.to_ascii_lowercase());
        lhs.cmp(rhs)
    }
}

impl std::hash::Hash for HeaderName {
    /// Hashes the ASCII-lowercased name so that names equal under `PartialEq` hash equally.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        for byte in self.name.bytes() {
            state.write_u8(byte.to_ascii_lowercase());
        }
        // Terminator byte (never valid in UTF-8, also used by `str`'s hash) so adjacent fields
        // in a composite key cannot produce prefix collisions.
        state.write_u8(0xff);
    }
}

/// Ensures the supplied string does not contain any uppercase ASCII characters.
///
/// # Panics
///
/// Panics if `s` contains a byte in `b'A'..=b'Z'`.
const fn ensure_no_uppercase(s: &str) {
    let bytes = s.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        assert!(
            !bytes[i].is_ascii_uppercase(),
            "header names must not contain uppercase letters"
        );
        i += 1;
    }
}

/// Converts a static string without lowercasing it.
///
/// # Panics
///
/// Panics if `s` contains an uppercase character.
impl From<&'static str> for HeaderName {
    fn from(s: &'static str) -> Self {
        Self::from_cow(s)
    }
}

/// Lowercases `s` before storing it, so mixed-case owned strings are accepted.
impl From<String> for HeaderName {
    fn from(s: String) -> Self {
        Self::from_cow(s.to_lowercase())
    }
}
406
/// A header value.
///
/// The `Debug` output is deliberately opaque (always `HeaderValue`). Use
/// [`HeaderValue::as_str`] to read the contents.
#[derive(Clone, Eq, PartialEq)]
pub struct HeaderValue(Cow<'static, str>);

impl HeaderValue {
    /// Wrap a `&'static str` without allocating. Usable in `const` contexts.
    pub const fn from_static(s: &'static str) -> Self {
        HeaderValue(Cow::Borrowed(s))
    }

    /// Build a header value from anything convertible into a `Cow<'static, str>`.
    pub fn from_cow<C>(c: C) -> Self
    where
        C: Into<Cow<'static, str>>,
    {
        let inner: Cow<'static, str> = c.into();
        HeaderValue(inner)
    }

    /// Borrow the header value as a `str`.
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl fmt::Debug for HeaderValue {
    /// Writes only the type name. The contents are intentionally not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HeaderValue")
    }
}

/// Borrows the static string without allocating.
impl From<&'static str> for HeaderValue {
    fn from(s: &'static str) -> Self {
        HeaderValue(Cow::Borrowed(s))
    }
}

/// Takes ownership of the string.
impl From<String> for HeaderValue {
    fn from(s: String) -> Self {
        HeaderValue(Cow::Owned(s))
    }
}

/// Copies the referenced string into an owned value.
impl From<&String> for HeaderValue {
    fn from(s: &String) -> Self {
        HeaderValue(Cow::Owned(s.to_owned()))
    }
}
454
#[cfg(test)]
mod tests {
    use super::{FromHeaders, HeaderName, Headers};
    use crate::error::ErrorKind;
    use url::Url;

    /// Test-only [`FromHeaders`] implementation that parses the `content-location` header.
    ///
    /// The "ForTest" suffix avoids a clash should a real `ContentLocation` type be added later.
    #[derive(Debug)]
    struct ContentLocationForTest(Url);

    impl FromHeaders for ContentLocationForTest {
        type Error = url::ParseError;

        fn header_names() -> &'static [&'static str] {
            &["content-location"]
        }

        fn from_headers(headers: &Headers) -> Result<Option<Self>, Self::Error> {
            headers
                .get_optional_str(&HeaderName::from("content-location"))
                .map(|raw| raw.parse().map(ContentLocationForTest))
                .transpose()
        }
    }

    /// Builds a collection holding a single `content-location` header with `value`.
    fn with_content_location(value: &'static str) -> Headers {
        let mut headers = Headers::new();
        headers.insert("content-location", value);
        headers
    }

    #[test]
    fn headers_get_optional_returns_ok_some_if_header_present_and_valid() {
        let headers = with_content_location("https://example.com");
        let parsed: ContentLocationForTest = headers.get_optional().unwrap().unwrap();
        assert_eq!("https://example.com/", parsed.0.as_str());
    }

    #[test]
    fn headers_get_optional_returns_ok_none_if_header_not_present() {
        let parsed: Option<ContentLocationForTest> = Headers::new().get_optional().unwrap();
        assert!(parsed.is_none());
    }

    #[test]
    fn headers_get_optional_returns_err_if_conversion_fails() {
        let headers = with_content_location("not a URL");
        let error = headers
            .get_optional::<ContentLocationForTest>()
            .unwrap_err();
        assert_eq!(url::ParseError::RelativeUrlWithoutBase, error);
    }

    #[test]
    fn headers_get_returns_ok_if_header_present_and_valid() {
        let parsed: ContentLocationForTest =
            with_content_location("https://example.com").get().unwrap();
        assert_eq!("https://example.com/", parsed.0.as_str());
    }

    #[test]
    fn headers_get_returns_err_if_header_not_present() {
        let error = Headers::new().get::<ContentLocationForTest>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());

        // `Display` is the canonical way to obtain an error's message.
        assert_eq!(
            "required header(s) not found: content-location",
            error.to_string()
        );
    }

    #[test]
    fn headers_get_returns_err_if_header_requiring_multiple_headers_not_present() {
        /// Declares two required headers but never finds them.
        #[derive(Debug)]
        struct HasTwoHeaders;

        impl FromHeaders for HasTwoHeaders {
            type Error = std::convert::Infallible;

            fn header_names() -> &'static [&'static str] {
                &["header-a", "header-b"]
            }

            fn from_headers(_: &Headers) -> Result<Option<Self>, Self::Error> {
                Ok(None)
            }
        }

        let error = Headers::new().get::<HasTwoHeaders>().unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());

        // `Display` is the canonical way to obtain an error's message.
        assert_eq!(
            "required header(s) not found: header-a, header-b",
            error.to_string()
        );
    }

    #[test]
    fn headers_get_returns_err_if_conversion_fails() {
        let error = with_content_location("not a URL")
            .get::<ContentLocationForTest>()
            .unwrap_err();
        assert_eq!(&ErrorKind::DataConversion, error.kind());

        let source: Box<url::ParseError> = error.into_inner().unwrap().downcast().unwrap();
        assert_eq!(url::ParseError::RelativeUrlWithoutBase, *source);
    }

    #[test]
    fn headers_remove_existing_header_returns_value() {
        let mut headers = Headers::new();
        headers.insert("test-header", "test-value");
        let key = HeaderName::from("test-header");

        // The header is present before removal...
        assert_eq!(headers.get_optional_str(&key), Some("test-value"));

        // ...removal hands back the previous value...
        let removed = headers.remove("test-header");
        assert_eq!(removed.as_ref().map(|v| v.as_str()), Some("test-value"));

        // ...and the header is gone afterwards.
        assert_eq!(headers.get_optional_str(&key), None);
    }

    #[test]
    fn headers_remove_nonexistent_header_returns_none() {
        let mut headers = Headers::new();

        // Nothing was inserted, so there is nothing to remove.
        assert!(headers.remove("nonexistent-header").is_none());
    }

    #[test]
    fn headers_remove_works_with_different_key_types() {
        let mut headers = Headers::new();

        // `&'static str` key.
        headers.insert("test-header", "test-value");
        let removed = headers.remove("test-header");
        assert_eq!(removed.as_ref().map(|v| v.as_str()), Some("test-value"));

        // `HeaderName` key.
        headers.insert("test-header", "test-value");
        let removed = headers.remove(HeaderName::from("test-header"));
        assert_eq!(removed.as_ref().map(|v| v.as_str()), Some("test-value"));

        // `String` key.
        headers.insert("test-header", "test-value");
        let removed = headers.remove("test-header".to_string());
        assert_eq!(removed.as_ref().map(|v| v.as_str()), Some("test-value"));
    }
}