tower_async_http/cors/allow_origin.rs

1use std::{array, fmt, sync::Arc};
2
3use http::{
4    header::{self, HeaderName, HeaderValue},
5    request::Parts as RequestParts,
6};
7
8use super::{Any, WILDCARD};
9
10/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header.
11///
12/// See [`CorsLayer::allow_origin`] for more details.
13///
14/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
15/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
16#[derive(Clone, Default)]
17#[must_use]
18pub struct AllowOrigin(OriginInner);
19
20impl AllowOrigin {
21    /// Allow any origin by sending a wildcard (`*`)
22    ///
23    /// See [`CorsLayer::allow_origin`] for more details.
24    ///
25    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
26    pub fn any() -> Self {
27        Self(OriginInner::Const(WILDCARD))
28    }
29
30    /// Set a single allowed origin
31    ///
32    /// See [`CorsLayer::allow_origin`] for more details.
33    ///
34    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
35    pub fn exact(origin: HeaderValue) -> Self {
36        Self(OriginInner::Const(origin))
37    }
38
39    /// Set multiple allowed origins
40    ///
41    /// See [`CorsLayer::allow_origin`] for more details.
42    ///
43    /// # Panics
44    ///
45    /// If the iterator contains a wildcard (`*`).
46    ///
47    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
48    #[allow(clippy::borrow_interior_mutable_const)]
49    pub fn list<I>(origins: I) -> Self
50    where
51        I: IntoIterator<Item = HeaderValue>,
52    {
53        let origins = origins.into_iter().collect::<Vec<_>>();
54        if origins.iter().any(|o| o == WILDCARD) {
55            panic!("Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. Use `AllowOrigin::any()` instead");
56        } else {
57            Self(OriginInner::List(origins))
58        }
59    }
60
61    /// Set the allowed origins from a predicate
62    ///
63    /// See [`CorsLayer::allow_origin`] for more details.
64    ///
65    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
66    pub fn predicate<F>(f: F) -> Self
67    where
68        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
69    {
70        Self(OriginInner::Predicate(Arc::new(f)))
71    }
72
73    /// Allow any origin, by mirroring the request origin
74    ///
75    /// This is equivalent to
76    /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate].
77    ///
78    /// See [`CorsLayer::allow_origin`] for more details.
79    ///
80    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
81    pub fn mirror_request() -> Self {
82        Self::predicate(|_, _| true)
83    }
84
85    #[allow(clippy::borrow_interior_mutable_const)]
86    pub(super) fn is_wildcard(&self) -> bool {
87        matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
88    }
89
90    pub(super) fn to_header(
91        &self,
92        origin: Option<&HeaderValue>,
93        parts: &RequestParts,
94    ) -> Option<(HeaderName, HeaderValue)> {
95        let allow_origin = match &self.0 {
96            OriginInner::Const(v) => v.clone(),
97            OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(),
98            OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(),
99        };
100
101        Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin))
102    }
103}
104
105impl fmt::Debug for AllowOrigin {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match &self.0 {
108            OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
109            OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
110            OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
111        }
112    }
113}
114
115impl From<Any> for AllowOrigin {
116    fn from(_: Any) -> Self {
117        Self::any()
118    }
119}
120
121impl From<HeaderValue> for AllowOrigin {
122    fn from(val: HeaderValue) -> Self {
123        Self::exact(val)
124    }
125}
126
127impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
128    fn from(arr: [HeaderValue; N]) -> Self {
129        #[allow(deprecated)] // Can be changed when MSRV >= 1.53
130        Self::list(array::IntoIter::new(arr))
131    }
132}
133
134impl From<Vec<HeaderValue>> for AllowOrigin {
135    fn from(vec: Vec<HeaderValue>) -> Self {
136        Self::list(vec)
137    }
138}
139
140#[derive(Clone)]
141enum OriginInner {
142    Const(HeaderValue),
143    List(Vec<HeaderValue>),
144    Predicate(
145        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
146    ),
147}
148
149impl Default for OriginInner {
150    fn default() -> Self {
151        Self::List(Vec::new())
152    }
153}