Skip to main content

tower_http/cors/
allow_origin.rs

1use http::{
2    header::{self, HeaderName, HeaderValue},
3    request::Parts as RequestParts,
4};
5use pin_project_lite::pin_project;
6use std::{
7    fmt,
8    future::Future,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13
14use super::{Any, WILDCARD};
15
16/// Holds configuration for how to set the [`Access-Control-Allow-Origin`][mdn] header.
17///
18/// See [`CorsLayer::allow_origin`] for more details.
19///
20/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
21/// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
22#[derive(Clone, Default)]
23#[must_use]
24pub struct AllowOrigin(OriginInner);
25
26impl AllowOrigin {
27    /// Allow any origin by sending a wildcard (`*`)
28    ///
29    /// See [`CorsLayer::allow_origin`] for more details.
30    ///
31    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
32    pub fn any() -> Self {
33        Self(OriginInner::Const(WILDCARD))
34    }
35
36    /// Set a single allowed origin
37    ///
38    /// See [`CorsLayer::allow_origin`] for more details.
39    ///
40    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
41    pub fn exact(origin: HeaderValue) -> Self {
42        Self(OriginInner::Const(origin))
43    }
44
45    /// Set multiple allowed origins
46    ///
47    /// See [`CorsLayer::allow_origin`] for more details.
48    ///
49    /// # Panics
50    ///
51    /// If the iterator contains a wildcard (`*`).
52    ///
53    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
54    #[allow(clippy::borrow_interior_mutable_const)]
55    pub fn list<I>(origins: I) -> Self
56    where
57        I: IntoIterator<Item = HeaderValue>,
58    {
59        let origins = origins.into_iter().collect::<Vec<_>>();
60        if origins.contains(&WILDCARD) {
61            panic!(
62                "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \
63                 Use `AllowOrigin::any()` instead"
64            );
65        }
66
67        Self(OriginInner::List(origins))
68    }
69
70    /// Set the allowed origins from a predicate
71    ///
72    /// See [`CorsLayer::allow_origin`] for more details.
73    ///
74    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
75    pub fn predicate<F>(f: F) -> Self
76    where
77        F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
78    {
79        Self(OriginInner::Predicate(Arc::new(f)))
80    }
81
82    /// Set the allowed origins from an async predicate
83    ///
84    /// See [`CorsLayer::allow_origin`] for more details.
85    ///
86    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
87    pub fn async_predicate<F, Fut>(f: F) -> Self
88    where
89        F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
90        Fut: Future<Output = bool> + Send + 'static,
91    {
92        Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
93            Box::pin((f.clone())(v, p))
94        })))
95    }
96
97    /// Allow any origin, by mirroring the request origin
98    ///
99    /// This is equivalent to
100    /// [`AllowOrigin::predicate(|_, _| true)`][Self::predicate].
101    ///
102    /// See [`CorsLayer::allow_origin`] for more details.
103    ///
104    /// [`CorsLayer::allow_origin`]: super::CorsLayer::allow_origin
105    pub fn mirror_request() -> Self {
106        Self::predicate(|_, _| true)
107    }
108
109    #[allow(clippy::borrow_interior_mutable_const)]
110    pub(super) fn is_wildcard(&self) -> bool {
111        matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
112    }
113
114    pub(super) fn varies_with_origin(&self) -> bool {
115        !matches!(&self.0, OriginInner::Const(_))
116    }
117
118    pub(super) fn to_future(
119        &self,
120        origin: Option<&HeaderValue>,
121        parts: &RequestParts,
122    ) -> AllowOriginFuture {
123        let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
124
125        match &self.0 {
126            OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
127            OriginInner::List(l) => {
128                AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
129            }
130            OriginInner::Predicate(c) => AllowOriginFuture::ok(
131                origin
132                    .filter(|origin| c(origin, parts))
133                    .map(|o| (name, o.clone())),
134            ),
135            OriginInner::AsyncPredicate(f) => {
136                if let Some(origin) = origin.cloned() {
137                    let fut = f(origin.clone(), parts);
138                    AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
139                } else {
140                    AllowOriginFuture::ok(None)
141                }
142            }
143        }
144    }
145}
146
147pin_project! {
148    #[project = AllowOriginFutureProj]
149    pub(super) enum AllowOriginFuture {
150        Ok{
151            res: Option<(HeaderName, HeaderValue)>
152        },
153        Future{
154            #[pin]
155            future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
156        },
157    }
158}
159
160impl AllowOriginFuture {
161    fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
162        Self::Ok { res }
163    }
164
165    fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
166        future: F,
167    ) -> Self {
168        Self::Future {
169            future: Box::pin(future),
170        }
171    }
172}
173
174impl Future for AllowOriginFuture {
175    type Output = Option<(HeaderName, HeaderValue)>;
176
177    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178        match self.project() {
179            AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
180            AllowOriginFutureProj::Future { future } => future.poll(cx),
181        }
182    }
183}
184
185impl fmt::Debug for AllowOrigin {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        match &self.0 {
188            OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
189            OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
190            OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
191            OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
192        }
193    }
194}
195
196impl From<Any> for AllowOrigin {
197    fn from(_: Any) -> Self {
198        Self::any()
199    }
200}
201
202impl From<HeaderValue> for AllowOrigin {
203    fn from(val: HeaderValue) -> Self {
204        Self::exact(val)
205    }
206}
207
208impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
209    fn from(arr: [HeaderValue; N]) -> Self {
210        Self::list(arr)
211    }
212}
213
214impl From<Vec<HeaderValue>> for AllowOrigin {
215    fn from(vec: Vec<HeaderValue>) -> Self {
216        Self::list(vec)
217    }
218}
219
220#[derive(Clone)]
221enum OriginInner {
222    Const(HeaderValue),
223    List(Vec<HeaderValue>),
224    Predicate(
225        Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
226    ),
227    AsyncPredicate(
228        Arc<
229            dyn for<'a> Fn(
230                    HeaderValue,
231                    &'a RequestParts,
232                ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
233                + Send
234                + Sync
235                + 'static,
236        >,
237    ),
238}
239
240impl Default for OriginInner {
241    fn default() -> Self {
242        Self::List(Vec::new())
243    }
244}