tower_http/cors/mod.rs
1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower::{ServiceBuilder, ServiceExt, Service};
10//! use tower_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//! Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//! // allow `GET` and `POST` when accessing the resource
21//! .allow_methods([Method::GET, Method::POST])
22//! // allow requests from any origin
23//! .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//! .layer(cors)
27//! .service_fn(handle);
28//!
29//! let request = Request::builder()
30//! .header(header::ORIGIN, "https://example.com")
31//! .body(Full::default())
32//! .unwrap();
33//!
34//! let response = service
35//! .ready()
36//! .await?
37//! .call(request)
38//! .await?;
39//!
40//! assert_eq!(
41//! response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
42//! "*",
43//! );
44//! # Ok(())
45//! # }
46//! ```
47//!
48//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
49
50#![allow(clippy::enum_variant_names)]
51
52use allow_origin::AllowOriginFuture;
53use bytes::{BufMut, BytesMut};
54use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response};
55use pin_project_lite::pin_project;
56use std::{
57 future::Future,
58 mem,
59 pin::Pin,
60 task::{ready, Context, Poll},
61};
62use tower_layer::Layer;
63use tower_service::Service;
64
65mod allow_credentials;
66mod allow_headers;
67mod allow_methods;
68mod allow_origin;
69mod allow_private_network;
70mod expose_headers;
71mod max_age;
72mod vary;
73
74#[cfg(test)]
75mod tests;
76
77pub use self::{
78 allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
79 allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
80 expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
81};
82
/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// See the [module docs](crate::cors) for an example.
///
/// Each field holds the configuration for one response header; the builder
/// methods on this type replace the corresponding field.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Debug, Clone)]
#[must_use]
pub struct CorsLayer {
    /// Configuration for `Access-Control-Allow-Credentials`.
    allow_credentials: AllowCredentials,
    /// Configuration for `Access-Control-Allow-Headers` (preflight only).
    allow_headers: AllowHeaders,
    /// Configuration for `Access-Control-Allow-Methods` (preflight only).
    allow_methods: AllowMethods,
    /// Configuration for `Access-Control-Allow-Origin`.
    allow_origin: AllowOrigin,
    /// Configuration for `Access-Control-Allow-Private-Network`.
    allow_private_network: AllowPrivateNetwork,
    /// Configuration for `Access-Control-Expose-Headers` (non-preflight only).
    expose_headers: ExposeHeaders,
    /// Configuration for `Access-Control-Max-Age` (preflight only).
    max_age: MaxAge,
    /// Value(s) of the `Vary` header added to responses.
    vary: Vary,
    /// Set to `true` by [`CorsLayer::vary`]; while `true`,
    /// `update_vary_header` leaves `vary` untouched.
    is_vary_custom: bool,
}
101
// The `*` wildcard as a header value.
// NOTE(review): the Clippy lint is presumably silenced because this `const`
// is only ever used by value (each use site gets its own copy) — confirm.
#[allow(clippy::declare_interior_mutable_const)]
const WILDCARD: HeaderValue = HeaderValue::from_static("*");
104
105impl CorsLayer {
106 /// Create a new `CorsLayer`.
107 ///
108 /// No headers are sent by default. Use the builder methods to customize
109 /// the behavior.
110 ///
111 /// You need to set at least an allowed origin for browsers to make
112 /// successful cross-origin requests to your service.
113 pub fn new() -> Self {
114 Self {
115 allow_credentials: Default::default(),
116 allow_headers: Default::default(),
117 allow_methods: Default::default(),
118 allow_origin: Default::default(),
119 allow_private_network: Default::default(),
120 expose_headers: Default::default(),
121 max_age: Default::default(),
122 vary: Default::default(),
123 is_vary_custom: false,
124 }
125 }
126
127 /// A permissive configuration:
128 ///
129 /// - All request headers allowed.
130 /// - All methods allowed.
131 /// - All origins allowed.
132 /// - All headers exposed.
133 pub fn permissive() -> Self {
134 Self::new()
135 .allow_headers(Any)
136 .allow_methods(Any)
137 .allow_origin(Any)
138 .expose_headers(Any)
139 }
140
141 /// A very permissive configuration:
142 ///
143 /// - **Credentials allowed.**
144 /// - The method received in `Access-Control-Request-Method` is sent back
145 /// as an allowed method.
146 /// - The origin of the preflight request is sent back as an allowed origin.
147 /// - The header names received in `Access-Control-Request-Headers` are sent
148 /// back as allowed headers.
149 /// - No headers are currently exposed, but this may change in the future.
150 pub fn very_permissive() -> Self {
151 Self::new()
152 .allow_credentials(true)
153 .allow_headers(AllowHeaders::mirror_request())
154 .allow_methods(AllowMethods::mirror_request())
155 .allow_origin(AllowOrigin::mirror_request())
156 }
157
158 /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
159 ///
160 /// ```
161 /// use tower_http::cors::CorsLayer;
162 ///
163 /// let layer = CorsLayer::new().allow_credentials(true);
164 /// ```
165 ///
166 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
167 pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
168 where
169 T: Into<AllowCredentials>,
170 {
171 self.allow_credentials = allow_credentials.into();
172 self
173 }
174
175 /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
176 ///
177 /// ```
178 /// use tower_http::cors::CorsLayer;
179 /// use http::header::{AUTHORIZATION, ACCEPT};
180 ///
181 /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
182 /// ```
183 ///
184 /// All headers can be allowed with
185 ///
186 /// ```
187 /// use tower_http::cors::{Any, CorsLayer};
188 ///
189 /// let layer = CorsLayer::new().allow_headers(Any);
190 /// ```
191 ///
192 /// Note that multiple calls to this method will override any previous
193 /// calls.
194 ///
195 /// Also note that `Access-Control-Allow-Headers` is required for requests that have
196 /// `Access-Control-Request-Headers`.
197 ///
198 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
199 pub fn allow_headers<T>(mut self, headers: T) -> Self
200 where
201 T: Into<AllowHeaders>,
202 {
203 self.allow_headers = headers.into();
204 self
205 }
206
207 /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
208 ///
209 /// ```
210 /// use std::time::Duration;
211 /// use tower_http::cors::CorsLayer;
212 ///
213 /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
214 /// ```
215 ///
216 /// By default the header will not be set which disables caching and will
217 /// require a preflight call for all requests.
218 ///
219 /// Note that each browser has a maximum internal value that takes
220 /// precedence when the Access-Control-Max-Age is greater. For more details
221 /// see [mdn].
222 ///
/// If you need more flexibility, you can supply a function which can
224 /// dynamically decide the max-age based on the origin and other parts of
225 /// each preflight request:
226 ///
227 /// ```
228 /// # struct MyServerConfig { cors_max_age: Duration }
229 /// use std::time::Duration;
230 ///
231 /// use http::{request::Parts as RequestParts, HeaderValue};
232 /// use tower_http::cors::{CorsLayer, MaxAge};
233 ///
234 /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
235 /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
236 /// // Let's say you want to be able to reload your config at
237 /// // runtime and have another middleware that always inserts
238 /// // the current config into the request extensions
239 /// let config = parts.extensions.get::<MyServerConfig>().unwrap();
240 /// config.cors_max_age
241 /// },
242 /// ));
243 /// ```
244 ///
245 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
246 pub fn max_age<T>(mut self, max_age: T) -> Self
247 where
248 T: Into<MaxAge>,
249 {
250 self.max_age = max_age.into();
251 self
252 }
253
254 /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
255 ///
256 /// ```
257 /// use tower_http::cors::CorsLayer;
258 /// use http::Method;
259 ///
260 /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
261 /// ```
262 ///
263 /// All methods can be allowed with
264 ///
265 /// ```
266 /// use tower_http::cors::{Any, CorsLayer};
267 ///
268 /// let layer = CorsLayer::new().allow_methods(Any);
269 /// ```
270 ///
271 /// Note that multiple calls to this method will override any previous
272 /// calls.
273 ///
274 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
275 pub fn allow_methods<T>(mut self, methods: T) -> Self
276 where
277 T: Into<AllowMethods>,
278 {
279 self.allow_methods = methods.into();
280 self
281 }
282
283 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
284 ///
285 /// ```
286 /// use http::HeaderValue;
287 /// use tower_http::cors::CorsLayer;
288 ///
289 /// let layer = CorsLayer::new().allow_origin(
290 /// "http://example.com".parse::<HeaderValue>().unwrap(),
291 /// );
292 /// ```
293 ///
294 /// Multiple origins can be allowed with
295 ///
296 /// ```
297 /// use tower_http::cors::CorsLayer;
298 ///
299 /// let origins = [
300 /// "http://example.com".parse().unwrap(),
301 /// "http://api.example.com".parse().unwrap(),
302 /// ];
303 ///
304 /// let layer = CorsLayer::new().allow_origin(origins);
305 /// ```
306 ///
307 /// All origins can be allowed with
308 ///
309 /// ```
310 /// use tower_http::cors::{Any, CorsLayer};
311 ///
312 /// let layer = CorsLayer::new().allow_origin(Any);
313 /// ```
314 ///
315 /// You can also use a closure
316 ///
317 /// ```
318 /// use tower_http::cors::{CorsLayer, AllowOrigin};
319 /// use http::{request::Parts as RequestParts, HeaderValue};
320 ///
321 /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
322 /// |origin: &HeaderValue, _request_parts: &RequestParts| {
323 /// origin.as_bytes().ends_with(b".rust-lang.org")
324 /// },
325 /// ));
326 /// ```
327 ///
328 /// You can also use an async closure:
329 ///
330 /// ```
331 /// # #[derive(Clone)]
332 /// # struct Client;
333 /// # fn get_api_client() -> Client {
334 /// # Client
335 /// # }
336 /// # impl Client {
337 /// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
338 /// # vec![HeaderValue::from_static("http://example.com")]
339 /// # }
340 /// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
341 /// # vec![HeaderValue::from_static("http://example.com")]
342 /// # }
343 /// # }
344 /// use tower_http::cors::{CorsLayer, AllowOrigin};
345 /// use http::{request::Parts as RequestParts, HeaderValue};
346 ///
347 /// let client = get_api_client();
348 ///
349 /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
350 /// |origin: HeaderValue, _request_parts: &RequestParts| async move {
351 /// // fetch list of origins that are allowed
352 /// let origins = client.fetch_allowed_origins().await;
353 /// origins.contains(&origin)
354 /// },
355 /// ));
356 ///
357 /// let client = get_api_client();
358 ///
359 /// // if using &RequestParts, make sure all the values are owned
360 /// // before passing into the future
361 /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
362 /// |origin: HeaderValue, parts: &RequestParts| {
363 /// let path = parts.uri.path().to_owned();
364 ///
365 /// async move {
366 /// // fetch list of origins that are allowed for this path
367 /// let origins = client.fetch_allowed_origins_for_path(path).await;
368 /// origins.contains(&origin)
369 /// }
370 /// },
371 /// ));
372 /// ```
373 ///
374 /// Note that multiple calls to this method will override any previous
375 /// calls.
376 ///
377 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
378 pub fn allow_origin<T>(mut self, origin: T) -> Self
379 where
380 T: Into<AllowOrigin>,
381 {
382 self.allow_origin = origin.into();
383 self
384 }
385
386 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
387 ///
388 /// ```
389 /// use tower_http::cors::CorsLayer;
390 /// use http::header::CONTENT_ENCODING;
391 ///
392 /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
393 /// ```
394 ///
/// All headers can be exposed with
396 ///
397 /// ```
398 /// use tower_http::cors::{Any, CorsLayer};
399 ///
400 /// let layer = CorsLayer::new().expose_headers(Any);
401 /// ```
402 ///
403 /// Note that multiple calls to this method will override any previous
404 /// calls.
405 ///
406 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
407 pub fn expose_headers<T>(mut self, headers: T) -> Self
408 where
409 T: Into<ExposeHeaders>,
410 {
411 self.expose_headers = headers.into();
412 self
413 }
414
415 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
416 ///
417 /// ```
418 /// use tower_http::cors::CorsLayer;
419 ///
420 /// let layer = CorsLayer::new().allow_private_network(true);
421 /// ```
422 ///
423 /// [wicg]: https://wicg.github.io/private-network-access/
424 pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
425 where
426 T: Into<AllowPrivateNetwork>,
427 {
428 self.allow_private_network = allow_private_network.into();
429 self
430 }
431
432 /// Set the value(s) of the [`Vary`][mdn] header.
433 ///
434 /// By default, this value is derived from whether CORS response headers are
435 /// request-dependent:
436 ///
437 /// - `Origin` is included when `Access-Control-Allow-Origin` depends on the
438 /// request's `Origin` header (for example, origin lists or predicates).
439 /// - `Access-Control-Request-Method` is included when
440 /// `Access-Control-Allow-Methods` mirrors `Access-Control-Request-Method`.
441 /// - `Access-Control-Request-Headers` is included when
442 /// `Access-Control-Allow-Headers` mirrors `Access-Control-Request-Headers`.
443 /// - If none of those values are request-dependent, no `Vary` header is
444 /// added.
445 ///
446 /// Calling this method sets `Vary` explicitly and pins it to the provided
447 /// value, regardless of future changes to those other CORS settings.
448 ///
449 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
450 pub fn vary<T>(mut self, headers: T) -> Self
451 where
452 T: Into<Vary>,
453 {
454 self.vary = headers.into();
455 self.is_vary_custom = true;
456 self
457 }
458
459 /// Recomputes the `Vary` header, if it hasn't been set explicitly.
460 fn update_vary_header(&mut self) {
461 if !self.is_vary_custom {
462 let vary_origin = self.allow_origin.varies_with_origin();
463 let vary_method = self.allow_methods.varies_with_request_method();
464 let vary_headers = self.allow_headers.varies_with_request_headers();
465
466 if !(vary_origin || vary_method || vary_headers) {
467 self.vary = Vary::list([]);
468 } else {
469 let mut vary_header_names = Vec::new();
470 if vary_origin {
471 vary_header_names.push(header::ORIGIN);
472 }
473 if vary_method {
474 vary_header_names.push(header::ACCESS_CONTROL_REQUEST_METHOD);
475 }
476 if vary_headers {
477 vary_header_names.push(header::ACCESS_CONTROL_REQUEST_HEADERS);
478 }
479 self.vary = Vary::list(vary_header_names);
480 }
481 }
482 }
483}
484
/// The wildcard value (`*`), accepted by several CORS settings such as
/// [`CorsLayer::allow_methods`].
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct Any;

/// Returns the wildcard value (`*`), accepted by several CORS settings such
/// as [`CorsLayer::allow_methods`].
#[deprecated(note = "Use Any as a unit struct literal instead")]
pub fn any() -> Any {
    Any
}
497
498fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
499where
500 I: Iterator<Item = HeaderValue>,
501{
502 match iter.next() {
503 Some(fst) => {
504 let mut result = BytesMut::from(fst.as_bytes());
505 for val in iter {
506 result.reserve(val.len() + 1);
507 result.put_u8(b',');
508 result.extend_from_slice(val.as_bytes());
509 }
510
511 Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
512 }
513 None => None,
514 }
515}
516
517impl Default for CorsLayer {
518 fn default() -> Self {
519 Self::new()
520 }
521}
522
523impl<S> Layer<S> for CorsLayer {
524 type Service = Cors<S>;
525
526 fn layer(&self, inner: S) -> Self::Service {
527 ensure_usable_cors_rules(self);
528
529 // Clone the layer to modify Vary header logic
530 let mut layer = self.clone();
531
532 layer.update_vary_header();
533
534 Cors { inner, layer }
535 }
536}
537
/// Middleware which adds headers for [CORS][mdn].
///
/// See the [module docs](crate::cors) for an example.
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Debug, Clone)]
#[must_use]
pub struct Cors<S> {
    /// The wrapped service; non-preflight requests are forwarded to it.
    inner: S,
    /// The CORS configuration used to build the response headers.
    layer: CorsLayer,
}
549
550impl<S> Cors<S> {
551 /// Create a new `Cors`.
552 ///
553 /// See [`CorsLayer::new`] for more details.
554 pub fn new(inner: S) -> Self {
555 Self {
556 inner,
557 layer: CorsLayer::new(),
558 }
559 }
560
561 /// A permissive configuration.
562 ///
563 /// See [`CorsLayer::permissive`] for more details.
564 pub fn permissive(inner: S) -> Self {
565 Self {
566 inner,
567 layer: CorsLayer::permissive(),
568 }
569 }
570
571 /// A very permissive configuration.
572 ///
573 /// See [`CorsLayer::very_permissive`] for more details.
574 pub fn very_permissive(inner: S) -> Self {
575 Self {
576 inner,
577 layer: CorsLayer::very_permissive(),
578 }
579 }
580
581 define_inner_service_accessors!();
582
583 /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
584 ///
585 /// [`Layer`]: tower_layer::Layer
586 pub fn layer() -> CorsLayer {
587 CorsLayer::new()
588 }
589
590 /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
591 ///
592 /// See [`CorsLayer::allow_credentials`] for more details.
593 ///
594 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
595 pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
596 where
597 T: Into<AllowCredentials>,
598 {
599 self.map_layer(|layer| layer.allow_credentials(allow_credentials))
600 }
601
602 /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
603 ///
604 /// See [`CorsLayer::allow_headers`] for more details.
605 ///
606 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
607 pub fn allow_headers<T>(self, headers: T) -> Self
608 where
609 T: Into<AllowHeaders>,
610 {
611 self.map_layer(|layer| layer.allow_headers(headers))
612 }
613
614 /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
615 ///
616 /// See [`CorsLayer::max_age`] for more details.
617 ///
618 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
619 pub fn max_age<T>(self, max_age: T) -> Self
620 where
621 T: Into<MaxAge>,
622 {
623 self.map_layer(|layer| layer.max_age(max_age))
624 }
625
626 /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
627 ///
628 /// See [`CorsLayer::allow_methods`] for more details.
629 ///
630 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
631 pub fn allow_methods<T>(self, methods: T) -> Self
632 where
633 T: Into<AllowMethods>,
634 {
635 self.map_layer(|layer| layer.allow_methods(methods))
636 }
637
638 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
639 ///
640 /// See [`CorsLayer::allow_origin`] for more details.
641 ///
642 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
643 pub fn allow_origin<T>(self, origin: T) -> Self
644 where
645 T: Into<AllowOrigin>,
646 {
647 self.map_layer(|layer| layer.allow_origin(origin))
648 }
649
650 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
651 ///
652 /// See [`CorsLayer::expose_headers`] for more details.
653 ///
654 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
655 pub fn expose_headers<T>(self, headers: T) -> Self
656 where
657 T: Into<ExposeHeaders>,
658 {
659 self.map_layer(|layer| layer.expose_headers(headers))
660 }
661
662 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
663 ///
664 /// See [`CorsLayer::allow_private_network`] for more details.
665 ///
666 /// [wicg]: https://wicg.github.io/private-network-access/
667 pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
668 where
669 T: Into<AllowPrivateNetwork>,
670 {
671 self.map_layer(|layer| layer.allow_private_network(allow_private_network))
672 }
673
674 fn map_layer<F>(mut self, f: F) -> Self
675 where
676 F: FnOnce(CorsLayer) -> CorsLayer,
677 {
678 self.layer = f(self.layer);
679
680 self.layer.update_vary_header();
681 self
682 }
683}
684
685impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
686where
687 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
688 ResBody: Default,
689{
690 type Response = S::Response;
691 type Error = S::Error;
692 type Future = ResponseFuture<S::Future>;
693
694 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
695 ensure_usable_cors_rules(&self.layer);
696 self.inner.poll_ready(cx)
697 }
698
699 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
700 let (parts, body) = req.into_parts();
701 let origin = parts.headers.get(&header::ORIGIN);
702
703 let mut headers = HeaderMap::new();
704
705 // These headers are applied to both preflight and subsequent regular CORS requests:
706 // https://fetch.spec.whatwg.org/#http-responses
707
708 headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
709 headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
710 headers.extend(self.layer.vary.to_header());
711
712 let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);
713
714 // Return results immediately upon preflight request
715 if parts.method == Method::OPTIONS {
716 // These headers are applied only to preflight requests
717 headers.extend(self.layer.allow_methods.to_header(&parts));
718 headers.extend(self.layer.allow_headers.to_header(&parts));
719 headers.extend(self.layer.max_age.to_header(origin, &parts));
720
721 ResponseFuture {
722 inner: Kind::PreflightCall {
723 allow_origin_future,
724 headers,
725 },
726 }
727 } else {
728 // This header is applied only to non-preflight requests
729 headers.extend(self.layer.expose_headers.to_header(&parts));
730
731 let req = Request::from_parts(parts, body);
732 ResponseFuture {
733 inner: Kind::CorsCall {
734 allow_origin_future,
735 allow_origin_complete: false,
736 future: self.inner.call(req),
737 headers,
738 },
739 }
740 }
741 }
742}
743
pin_project! {
    /// Response future for [`Cors`].
    pub struct ResponseFuture<F> {
        // Which kind of request is being answered; see `Kind`.
        #[pin]
        inner: Kind<F>,
    }
}
751
// State of a `ResponseFuture`: either a regular request that is forwarded to
// the inner service, or a preflight request that is answered directly.
pin_project! {
    #[project = KindProj]
    enum Kind<F> {
        // Non-preflight request.
        CorsCall {
            // Resolves to the allow-origin header(s) to add.
            #[pin]
            allow_origin_future: AllowOriginFuture,
            // Set once `allow_origin_future` has completed so that it is
            // not polled again.
            allow_origin_complete: bool,
            // The inner service's response future.
            #[pin]
            future: F,
            // CORS headers to add to the inner service's response.
            headers: HeaderMap,
        },
        // Preflight (`OPTIONS`) request; the inner service is not called.
        PreflightCall {
            // Resolves to the allow-origin header(s) to add.
            #[pin]
            allow_origin_future: AllowOriginFuture,
            // Headers of the generated preflight response.
            headers: HeaderMap,
        },
    }
}
770
771impl<F, B, E> Future for ResponseFuture<F>
772where
773 F: Future<Output = Result<Response<B>, E>>,
774 B: Default,
775{
776 type Output = Result<Response<B>, E>;
777
778 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
779 match self.project().inner.project() {
780 KindProj::CorsCall {
781 allow_origin_future,
782 allow_origin_complete,
783 future,
784 headers,
785 } => {
786 if !*allow_origin_complete {
787 headers.extend(ready!(allow_origin_future.poll(cx)));
788 *allow_origin_complete = true;
789 }
790
791 let mut response: Response<B> = ready!(future.poll(cx))?;
792
793 let response_headers = response.headers_mut();
794
795 // vary header can have multiple values, don't overwrite
796 // previously-set value(s).
797 if let Some(vary) = headers.remove(header::VARY) {
798 response_headers.append(header::VARY, vary);
799 }
800 // extend will overwrite previous headers of remaining names
801 response_headers.extend(headers.drain());
802
803 Poll::Ready(Ok(response))
804 }
805 KindProj::PreflightCall {
806 allow_origin_future,
807 headers,
808 } => {
809 headers.extend(ready!(allow_origin_future.poll(cx)));
810
811 let mut response = Response::new(B::default());
812 mem::swap(response.headers_mut(), headers);
813
814 Poll::Ready(Ok(response))
815 }
816 }
817 }
818}
819
820fn ensure_usable_cors_rules(layer: &CorsLayer) {
821 if layer.allow_credentials.is_true() {
822 assert!(
823 !layer.allow_headers.is_wildcard(),
824 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
825 with `Access-Control-Allow-Headers: *`"
826 );
827
828 assert!(
829 !layer.allow_methods.is_wildcard(),
830 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
831 with `Access-Control-Allow-Methods: *`"
832 );
833
834 assert!(
835 !layer.allow_origin.is_wildcard(),
836 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
837 with `Access-Control-Allow-Origin: *`"
838 );
839
840 assert!(
841 !layer.expose_headers.is_wildcard(),
842 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
843 with `Access-Control-Expose-Headers: *`"
844 );
845 }
846}
847
848/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
849///
850/// This is the default set of header names returned in the `vary` header
851pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
852 IntoIterator::into_iter([
853 header::ORIGIN,
854 header::ACCESS_CONTROL_REQUEST_METHOD,
855 header::ACCESS_CONTROL_REQUEST_HEADERS,
856 ])
857}