tower_async_http/cors/
allow_origin.rs1use std::{array, fmt, sync::Arc};
2
3use http::{
4 header::{self, HeaderName, HeaderValue},
5 request::Parts as RequestParts,
6};
7
8use super::{Any, WILDCARD};
9
10#[derive(Clone, Default)]
17#[must_use]
18pub struct AllowOrigin(OriginInner);
19
20impl AllowOrigin {
21 pub fn any() -> Self {
27 Self(OriginInner::Const(WILDCARD))
28 }
29
30 pub fn exact(origin: HeaderValue) -> Self {
36 Self(OriginInner::Const(origin))
37 }
38
39 #[allow(clippy::borrow_interior_mutable_const)]
49 pub fn list<I>(origins: I) -> Self
50 where
51 I: IntoIterator<Item = HeaderValue>,
52 {
53 let origins = origins.into_iter().collect::<Vec<_>>();
54 if origins.iter().any(|o| o == WILDCARD) {
55 panic!("Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. Use `AllowOrigin::any()` instead");
56 } else {
57 Self(OriginInner::List(origins))
58 }
59 }
60
61 pub fn predicate<F>(f: F) -> Self
67 where
68 F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
69 {
70 Self(OriginInner::Predicate(Arc::new(f)))
71 }
72
73 pub fn mirror_request() -> Self {
82 Self::predicate(|_, _| true)
83 }
84
85 #[allow(clippy::borrow_interior_mutable_const)]
86 pub(super) fn is_wildcard(&self) -> bool {
87 matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
88 }
89
90 pub(super) fn to_header(
91 &self,
92 origin: Option<&HeaderValue>,
93 parts: &RequestParts,
94 ) -> Option<(HeaderName, HeaderValue)> {
95 let allow_origin = match &self.0 {
96 OriginInner::Const(v) => v.clone(),
97 OriginInner::List(l) => origin.filter(|o| l.contains(o))?.clone(),
98 OriginInner::Predicate(c) => origin.filter(|origin| c(origin, parts))?.clone(),
99 };
100
101 Some((header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin))
102 }
103}
104
105impl fmt::Debug for AllowOrigin {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 match &self.0 {
108 OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
109 OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
110 OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
111 }
112 }
113}
114
115impl From<Any> for AllowOrigin {
116 fn from(_: Any) -> Self {
117 Self::any()
118 }
119}
120
121impl From<HeaderValue> for AllowOrigin {
122 fn from(val: HeaderValue) -> Self {
123 Self::exact(val)
124 }
125}
126
127impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
128 fn from(arr: [HeaderValue; N]) -> Self {
129 #[allow(deprecated)] Self::list(array::IntoIter::new(arr))
131 }
132}
133
134impl From<Vec<HeaderValue>> for AllowOrigin {
135 fn from(vec: Vec<HeaderValue>) -> Self {
136 Self::list(vec)
137 }
138}
139
140#[derive(Clone)]
141enum OriginInner {
142 Const(HeaderValue),
143 List(Vec<HeaderValue>),
144 Predicate(
145 Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
146 ),
147}
148
149impl Default for OriginInner {
150 fn default() -> Self {
151 Self::List(Vec::new())
152 }
153}