tower_http/cors/
allow_origin.rs1use http::{
2 header::{self, HeaderName, HeaderValue},
3 request::Parts as RequestParts,
4};
5use pin_project_lite::pin_project;
6use std::{
7 fmt,
8 future::Future,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13
14use super::{Any, WILDCARD};
15
16#[derive(Clone, Default)]
23#[must_use]
24pub struct AllowOrigin(OriginInner);
25
26impl AllowOrigin {
27 pub fn any() -> Self {
33 Self(OriginInner::Const(WILDCARD))
34 }
35
36 pub fn exact(origin: HeaderValue) -> Self {
42 Self(OriginInner::Const(origin))
43 }
44
45 #[allow(clippy::borrow_interior_mutable_const)]
55 pub fn list<I>(origins: I) -> Self
56 where
57 I: IntoIterator<Item = HeaderValue>,
58 {
59 let origins = origins.into_iter().collect::<Vec<_>>();
60 if origins.contains(&WILDCARD) {
61 panic!(
62 "Wildcard origin (`*`) cannot be passed to `AllowOrigin::list`. \
63 Use `AllowOrigin::any()` instead"
64 );
65 }
66
67 Self(OriginInner::List(origins))
68 }
69
70 pub fn predicate<F>(f: F) -> Self
76 where
77 F: Fn(&HeaderValue, &RequestParts) -> bool + Send + Sync + 'static,
78 {
79 Self(OriginInner::Predicate(Arc::new(f)))
80 }
81
82 pub fn async_predicate<F, Fut>(f: F) -> Self
88 where
89 F: FnOnce(HeaderValue, &RequestParts) -> Fut + Send + Sync + 'static + Clone,
90 Fut: Future<Output = bool> + Send + 'static,
91 {
92 Self(OriginInner::AsyncPredicate(Arc::new(move |v, p| {
93 Box::pin((f.clone())(v, p))
94 })))
95 }
96
97 pub fn mirror_request() -> Self {
106 Self::predicate(|_, _| true)
107 }
108
109 #[allow(clippy::borrow_interior_mutable_const)]
110 pub(super) fn is_wildcard(&self) -> bool {
111 matches!(&self.0, OriginInner::Const(v) if v == WILDCARD)
112 }
113
114 pub(super) fn varies_with_origin(&self) -> bool {
115 !matches!(&self.0, OriginInner::Const(_))
116 }
117
118 pub(super) fn to_future(
119 &self,
120 origin: Option<&HeaderValue>,
121 parts: &RequestParts,
122 ) -> AllowOriginFuture {
123 let name = header::ACCESS_CONTROL_ALLOW_ORIGIN;
124
125 match &self.0 {
126 OriginInner::Const(v) => AllowOriginFuture::ok(Some((name, v.clone()))),
127 OriginInner::List(l) => {
128 AllowOriginFuture::ok(origin.filter(|o| l.contains(o)).map(|o| (name, o.clone())))
129 }
130 OriginInner::Predicate(c) => AllowOriginFuture::ok(
131 origin
132 .filter(|origin| c(origin, parts))
133 .map(|o| (name, o.clone())),
134 ),
135 OriginInner::AsyncPredicate(f) => {
136 if let Some(origin) = origin.cloned() {
137 let fut = f(origin.clone(), parts);
138 AllowOriginFuture::fut(async move { fut.await.then_some((name, origin)) })
139 } else {
140 AllowOriginFuture::ok(None)
141 }
142 }
143 }
144 }
145}
146
147pin_project! {
148 #[project = AllowOriginFutureProj]
149 pub(super) enum AllowOriginFuture {
150 Ok{
151 res: Option<(HeaderName, HeaderValue)>
152 },
153 Future{
154 #[pin]
155 future: Pin<Box<dyn Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>>
156 },
157 }
158}
159
160impl AllowOriginFuture {
161 fn ok(res: Option<(HeaderName, HeaderValue)>) -> Self {
162 Self::Ok { res }
163 }
164
165 fn fut<F: Future<Output = Option<(HeaderName, HeaderValue)>> + Send + 'static>(
166 future: F,
167 ) -> Self {
168 Self::Future {
169 future: Box::pin(future),
170 }
171 }
172}
173
174impl Future for AllowOriginFuture {
175 type Output = Option<(HeaderName, HeaderValue)>;
176
177 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
178 match self.project() {
179 AllowOriginFutureProj::Ok { res } => Poll::Ready(res.take()),
180 AllowOriginFutureProj::Future { future } => future.poll(cx),
181 }
182 }
183}
184
185impl fmt::Debug for AllowOrigin {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 match &self.0 {
188 OriginInner::Const(inner) => f.debug_tuple("Const").field(inner).finish(),
189 OriginInner::List(inner) => f.debug_tuple("List").field(inner).finish(),
190 OriginInner::Predicate(_) => f.debug_tuple("Predicate").finish(),
191 OriginInner::AsyncPredicate(_) => f.debug_tuple("AsyncPredicate").finish(),
192 }
193 }
194}
195
196impl From<Any> for AllowOrigin {
197 fn from(_: Any) -> Self {
198 Self::any()
199 }
200}
201
202impl From<HeaderValue> for AllowOrigin {
203 fn from(val: HeaderValue) -> Self {
204 Self::exact(val)
205 }
206}
207
208impl<const N: usize> From<[HeaderValue; N]> for AllowOrigin {
209 fn from(arr: [HeaderValue; N]) -> Self {
210 Self::list(arr)
211 }
212}
213
214impl From<Vec<HeaderValue>> for AllowOrigin {
215 fn from(vec: Vec<HeaderValue>) -> Self {
216 Self::list(vec)
217 }
218}
219
220#[derive(Clone)]
221enum OriginInner {
222 Const(HeaderValue),
223 List(Vec<HeaderValue>),
224 Predicate(
225 Arc<dyn for<'a> Fn(&'a HeaderValue, &'a RequestParts) -> bool + Send + Sync + 'static>,
226 ),
227 AsyncPredicate(
228 Arc<
229 dyn for<'a> Fn(
230 HeaderValue,
231 &'a RequestParts,
232 ) -> Pin<Box<dyn Future<Output = bool> + Send + 'static>>
233 + Send
234 + Sync
235 + 'static,
236 >,
237 ),
238}
239
240impl Default for OriginInner {
241 fn default() -> Self {
242 Self::List(Vec::new())
243 }
244}