Skip to main content

tower_http/cors/
mod.rs

1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower::{ServiceBuilder, ServiceExt, Service};
10//! use tower_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//!     Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//!     // allow `GET` and `POST` when accessing the resource
21//!     .allow_methods([Method::GET, Method::POST])
22//!     // allow requests from any origin
23//!     .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//!     .layer(cors)
27//!     .service_fn(handle);
28//!
29//! let request = Request::builder()
30//!     .header(header::ORIGIN, "https://example.com")
31//!     .body(Full::default())
32//!     .unwrap();
33//!
34//! let response = service
35//!     .ready()
36//!     .await?
37//!     .call(request)
38//!     .await?;
39//!
40//! assert_eq!(
41//!     response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
42//!     "*",
43//! );
44//! # Ok(())
45//! # }
46//! ```
47//!
48//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
49
50#![allow(clippy::enum_variant_names)]
51
52use allow_origin::AllowOriginFuture;
53use bytes::{BufMut, BytesMut};
54use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response};
55use pin_project_lite::pin_project;
56use std::{
57    future::Future,
58    mem,
59    pin::Pin,
60    task::{ready, Context, Poll},
61};
62use tower_layer::Layer;
63use tower_service::Service;
64
65mod allow_credentials;
66mod allow_headers;
67mod allow_methods;
68mod allow_origin;
69mod allow_private_network;
70mod expose_headers;
71mod max_age;
72mod vary;
73
74#[cfg(test)]
75mod tests;
76
77pub use self::{
78    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
79    allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
80    expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
81};
82
83/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
84///
85/// See the [module docs](crate::cors) for an example.
86///
87/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
88#[derive(Debug, Clone)]
89#[must_use]
90pub struct CorsLayer {
91    allow_credentials: AllowCredentials,
92    allow_headers: AllowHeaders,
93    allow_methods: AllowMethods,
94    allow_origin: AllowOrigin,
95    allow_private_network: AllowPrivateNetwork,
96    expose_headers: ExposeHeaders,
97    max_age: MaxAge,
98    vary: Vary,
99    is_vary_custom: bool,
100}
101
/// The literal `*` header value.
// NOTE(review): not referenced anywhere in this file; presumably used by the
// submodules declared above (e.g. `allow_origin`) — confirm before removing.
// The clippy lint is allowed because clippy flags `const` items whose type it
// considers interior-mutable; `HeaderValue` presumably falls into that
// category through its internal byte storage — TODO confirm.
#[allow(clippy::declare_interior_mutable_const)]
const WILDCARD: HeaderValue = HeaderValue::from_static("*");
104
105impl CorsLayer {
106    /// Create a new `CorsLayer`.
107    ///
108    /// No headers are sent by default. Use the builder methods to customize
109    /// the behavior.
110    ///
111    /// You need to set at least an allowed origin for browsers to make
112    /// successful cross-origin requests to your service.
113    pub fn new() -> Self {
114        Self {
115            allow_credentials: Default::default(),
116            allow_headers: Default::default(),
117            allow_methods: Default::default(),
118            allow_origin: Default::default(),
119            allow_private_network: Default::default(),
120            expose_headers: Default::default(),
121            max_age: Default::default(),
122            vary: Default::default(),
123            is_vary_custom: false,
124        }
125    }
126
127    /// A permissive configuration:
128    ///
129    /// - All request headers allowed.
130    /// - All methods allowed.
131    /// - All origins allowed.
132    /// - All headers exposed.
133    pub fn permissive() -> Self {
134        Self::new()
135            .allow_headers(Any)
136            .allow_methods(Any)
137            .allow_origin(Any)
138            .expose_headers(Any)
139    }
140
141    /// A very permissive configuration:
142    ///
143    /// - **Credentials allowed.**
144    /// - The method received in `Access-Control-Request-Method` is sent back
145    ///   as an allowed method.
146    /// - The origin of the preflight request is sent back as an allowed origin.
147    /// - The header names received in `Access-Control-Request-Headers` are sent
148    ///   back as allowed headers.
149    /// - No headers are currently exposed, but this may change in the future.
150    pub fn very_permissive() -> Self {
151        Self::new()
152            .allow_credentials(true)
153            .allow_headers(AllowHeaders::mirror_request())
154            .allow_methods(AllowMethods::mirror_request())
155            .allow_origin(AllowOrigin::mirror_request())
156    }
157
158    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
159    ///
160    /// ```
161    /// use tower_http::cors::CorsLayer;
162    ///
163    /// let layer = CorsLayer::new().allow_credentials(true);
164    /// ```
165    ///
166    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
167    pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
168    where
169        T: Into<AllowCredentials>,
170    {
171        self.allow_credentials = allow_credentials.into();
172        self
173    }
174
175    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
176    ///
177    /// ```
178    /// use tower_http::cors::CorsLayer;
179    /// use http::header::{AUTHORIZATION, ACCEPT};
180    ///
181    /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
182    /// ```
183    ///
184    /// All headers can be allowed with
185    ///
186    /// ```
187    /// use tower_http::cors::{Any, CorsLayer};
188    ///
189    /// let layer = CorsLayer::new().allow_headers(Any);
190    /// ```
191    ///
192    /// Note that multiple calls to this method will override any previous
193    /// calls.
194    ///
195    /// Also note that `Access-Control-Allow-Headers` is required for requests that have
196    /// `Access-Control-Request-Headers`.
197    ///
198    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
199    pub fn allow_headers<T>(mut self, headers: T) -> Self
200    where
201        T: Into<AllowHeaders>,
202    {
203        self.allow_headers = headers.into();
204        self
205    }
206
207    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
208    ///
209    /// ```
210    /// use std::time::Duration;
211    /// use tower_http::cors::CorsLayer;
212    ///
213    /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
214    /// ```
215    ///
216    /// By default the header will not be set which disables caching and will
217    /// require a preflight call for all requests.
218    ///
219    /// Note that each browser has a maximum internal value that takes
220    /// precedence when the Access-Control-Max-Age is greater. For more details
221    /// see [mdn].
222    ///
223    /// If you need more flexibility, you can use supply a function which can
224    /// dynamically decide the max-age based on the origin and other parts of
225    /// each preflight request:
226    ///
227    /// ```
228    /// # struct MyServerConfig { cors_max_age: Duration }
229    /// use std::time::Duration;
230    ///
231    /// use http::{request::Parts as RequestParts, HeaderValue};
232    /// use tower_http::cors::{CorsLayer, MaxAge};
233    ///
234    /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
235    ///     |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
236    ///         // Let's say you want to be able to reload your config at
237    ///         // runtime and have another middleware that always inserts
238    ///         // the current config into the request extensions
239    ///         let config = parts.extensions.get::<MyServerConfig>().unwrap();
240    ///         config.cors_max_age
241    ///     },
242    /// ));
243    /// ```
244    ///
245    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
246    pub fn max_age<T>(mut self, max_age: T) -> Self
247    where
248        T: Into<MaxAge>,
249    {
250        self.max_age = max_age.into();
251        self
252    }
253
254    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
255    ///
256    /// ```
257    /// use tower_http::cors::CorsLayer;
258    /// use http::Method;
259    ///
260    /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
261    /// ```
262    ///
263    /// All methods can be allowed with
264    ///
265    /// ```
266    /// use tower_http::cors::{Any, CorsLayer};
267    ///
268    /// let layer = CorsLayer::new().allow_methods(Any);
269    /// ```
270    ///
271    /// Note that multiple calls to this method will override any previous
272    /// calls.
273    ///
274    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
275    pub fn allow_methods<T>(mut self, methods: T) -> Self
276    where
277        T: Into<AllowMethods>,
278    {
279        self.allow_methods = methods.into();
280        self
281    }
282
283    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
284    ///
285    /// ```
286    /// use http::HeaderValue;
287    /// use tower_http::cors::CorsLayer;
288    ///
289    /// let layer = CorsLayer::new().allow_origin(
290    ///     "http://example.com".parse::<HeaderValue>().unwrap(),
291    /// );
292    /// ```
293    ///
294    /// Multiple origins can be allowed with
295    ///
296    /// ```
297    /// use tower_http::cors::CorsLayer;
298    ///
299    /// let origins = [
300    ///     "http://example.com".parse().unwrap(),
301    ///     "http://api.example.com".parse().unwrap(),
302    /// ];
303    ///
304    /// let layer = CorsLayer::new().allow_origin(origins);
305    /// ```
306    ///
307    /// All origins can be allowed with
308    ///
309    /// ```
310    /// use tower_http::cors::{Any, CorsLayer};
311    ///
312    /// let layer = CorsLayer::new().allow_origin(Any);
313    /// ```
314    ///
315    /// You can also use a closure
316    ///
317    /// ```
318    /// use tower_http::cors::{CorsLayer, AllowOrigin};
319    /// use http::{request::Parts as RequestParts, HeaderValue};
320    ///
321    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
322    ///     |origin: &HeaderValue, _request_parts: &RequestParts| {
323    ///         origin.as_bytes().ends_with(b".rust-lang.org")
324    ///     },
325    /// ));
326    /// ```
327    ///
328    /// You can also use an async closure:
329    ///
330    /// ```
331    /// # #[derive(Clone)]
332    /// # struct Client;
333    /// # fn get_api_client() -> Client {
334    /// #     Client
335    /// # }
336    /// # impl Client {
337    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
338    /// #         vec![HeaderValue::from_static("http://example.com")]
339    /// #     }
340    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
341    /// #         vec![HeaderValue::from_static("http://example.com")]
342    /// #     }
343    /// # }
344    /// use tower_http::cors::{CorsLayer, AllowOrigin};
345    /// use http::{request::Parts as RequestParts, HeaderValue};
346    ///
347    /// let client = get_api_client();
348    ///
349    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
350    ///     |origin: HeaderValue, _request_parts: &RequestParts| async move {
351    ///         // fetch list of origins that are allowed
352    ///         let origins = client.fetch_allowed_origins().await;
353    ///         origins.contains(&origin)
354    ///     },
355    /// ));
356    ///
357    /// let client = get_api_client();
358    ///
359    /// // if using &RequestParts, make sure all the values are owned
360    /// // before passing into the future
361    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::async_predicate(
362    ///     |origin: HeaderValue, parts: &RequestParts| {
363    ///         let path = parts.uri.path().to_owned();
364    ///
365    ///         async move {
366    ///             // fetch list of origins that are allowed for this path
367    ///             let origins = client.fetch_allowed_origins_for_path(path).await;
368    ///             origins.contains(&origin)
369    ///         }
370    ///     },
371    /// ));
372    /// ```
373    ///
374    /// Note that multiple calls to this method will override any previous
375    /// calls.
376    ///
377    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
378    pub fn allow_origin<T>(mut self, origin: T) -> Self
379    where
380        T: Into<AllowOrigin>,
381    {
382        self.allow_origin = origin.into();
383        self
384    }
385
386    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
387    ///
388    /// ```
389    /// use tower_http::cors::CorsLayer;
390    /// use http::header::CONTENT_ENCODING;
391    ///
392    /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
393    /// ```
394    ///
395    /// All headers can be allowed with
396    ///
397    /// ```
398    /// use tower_http::cors::{Any, CorsLayer};
399    ///
400    /// let layer = CorsLayer::new().expose_headers(Any);
401    /// ```
402    ///
403    /// Note that multiple calls to this method will override any previous
404    /// calls.
405    ///
406    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
407    pub fn expose_headers<T>(mut self, headers: T) -> Self
408    where
409        T: Into<ExposeHeaders>,
410    {
411        self.expose_headers = headers.into();
412        self
413    }
414
415    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
416    ///
417    /// ```
418    /// use tower_http::cors::CorsLayer;
419    ///
420    /// let layer = CorsLayer::new().allow_private_network(true);
421    /// ```
422    ///
423    /// [wicg]: https://wicg.github.io/private-network-access/
424    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
425    where
426        T: Into<AllowPrivateNetwork>,
427    {
428        self.allow_private_network = allow_private_network.into();
429        self
430    }
431
432    /// Set the value(s) of the [`Vary`][mdn] header.
433    ///
434    /// By default, this value is derived from whether CORS response headers are
435    /// request-dependent:
436    ///
437    /// - `Origin` is included when `Access-Control-Allow-Origin` depends on the
438    ///   request's `Origin` header (for example, origin lists or predicates).
439    /// - `Access-Control-Request-Method` is included when
440    ///   `Access-Control-Allow-Methods` mirrors `Access-Control-Request-Method`.
441    /// - `Access-Control-Request-Headers` is included when
442    ///   `Access-Control-Allow-Headers` mirrors `Access-Control-Request-Headers`.
443    /// - If none of those values are request-dependent, no `Vary` header is
444    ///   added.
445    ///
446    /// Calling this method sets `Vary` explicitly and pins it to the provided
447    /// value, regardless of future changes to those other CORS settings.
448    ///
449    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
450    pub fn vary<T>(mut self, headers: T) -> Self
451    where
452        T: Into<Vary>,
453    {
454        self.vary = headers.into();
455        self.is_vary_custom = true;
456        self
457    }
458
459    /// Recomputes the `Vary` header, if it hasn't been set explicitly.
460    fn update_vary_header(&mut self) {
461        if !self.is_vary_custom {
462            let vary_origin = self.allow_origin.varies_with_origin();
463            let vary_method = self.allow_methods.varies_with_request_method();
464            let vary_headers = self.allow_headers.varies_with_request_headers();
465
466            if !(vary_origin || vary_method || vary_headers) {
467                self.vary = Vary::list([]);
468            } else {
469                let mut vary_header_names = Vec::new();
470                if vary_origin {
471                    vary_header_names.push(header::ORIGIN);
472                }
473                if vary_method {
474                    vary_header_names.push(header::ACCESS_CONTROL_REQUEST_METHOD);
475                }
476                if vary_headers {
477                    vary_header_names.push(header::ACCESS_CONTROL_REQUEST_HEADERS);
478                }
479                self.vary = Vary::list(vary_header_names);
480            }
481        }
482    }
483}
484
/// Wildcard marker (`*`) accepted by several CORS builder methods, for
/// example [`CorsLayer::allow_methods`] and [`CorsLayer::allow_origin`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;

/// Returns the [`Any`] wildcard marker.
///
/// Kept for backwards compatibility; write the unit struct `Any` directly.
#[deprecated = "Use Any as a unit struct literal instead"]
pub fn any() -> Any {
    Any
}
497
498fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
499where
500    I: Iterator<Item = HeaderValue>,
501{
502    match iter.next() {
503        Some(fst) => {
504            let mut result = BytesMut::from(fst.as_bytes());
505            for val in iter {
506                result.reserve(val.len() + 1);
507                result.put_u8(b',');
508                result.extend_from_slice(val.as_bytes());
509            }
510
511            Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
512        }
513        None => None,
514    }
515}
516
517impl Default for CorsLayer {
518    fn default() -> Self {
519        Self::new()
520    }
521}
522
523impl<S> Layer<S> for CorsLayer {
524    type Service = Cors<S>;
525
526    fn layer(&self, inner: S) -> Self::Service {
527        ensure_usable_cors_rules(self);
528
529        // Clone the layer to modify Vary header logic
530        let mut layer = self.clone();
531
532        layer.update_vary_header();
533
534        Cors { inner, layer }
535    }
536}
537
538/// Middleware which adds headers for [CORS][mdn].
539///
540/// See the [module docs](crate::cors) for an example.
541///
542/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
543#[derive(Debug, Clone)]
544#[must_use]
545pub struct Cors<S> {
546    inner: S,
547    layer: CorsLayer,
548}
549
550impl<S> Cors<S> {
551    /// Create a new `Cors`.
552    ///
553    /// See [`CorsLayer::new`] for more details.
554    pub fn new(inner: S) -> Self {
555        Self {
556            inner,
557            layer: CorsLayer::new(),
558        }
559    }
560
561    /// A permissive configuration.
562    ///
563    /// See [`CorsLayer::permissive`] for more details.
564    pub fn permissive(inner: S) -> Self {
565        Self {
566            inner,
567            layer: CorsLayer::permissive(),
568        }
569    }
570
571    /// A very permissive configuration.
572    ///
573    /// See [`CorsLayer::very_permissive`] for more details.
574    pub fn very_permissive(inner: S) -> Self {
575        Self {
576            inner,
577            layer: CorsLayer::very_permissive(),
578        }
579    }
580
581    define_inner_service_accessors!();
582
583    /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
584    ///
585    /// [`Layer`]: tower_layer::Layer
586    pub fn layer() -> CorsLayer {
587        CorsLayer::new()
588    }
589
590    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
591    ///
592    /// See [`CorsLayer::allow_credentials`] for more details.
593    ///
594    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
595    pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
596    where
597        T: Into<AllowCredentials>,
598    {
599        self.map_layer(|layer| layer.allow_credentials(allow_credentials))
600    }
601
602    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
603    ///
604    /// See [`CorsLayer::allow_headers`] for more details.
605    ///
606    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
607    pub fn allow_headers<T>(self, headers: T) -> Self
608    where
609        T: Into<AllowHeaders>,
610    {
611        self.map_layer(|layer| layer.allow_headers(headers))
612    }
613
614    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
615    ///
616    /// See [`CorsLayer::max_age`] for more details.
617    ///
618    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
619    pub fn max_age<T>(self, max_age: T) -> Self
620    where
621        T: Into<MaxAge>,
622    {
623        self.map_layer(|layer| layer.max_age(max_age))
624    }
625
626    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
627    ///
628    /// See [`CorsLayer::allow_methods`] for more details.
629    ///
630    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
631    pub fn allow_methods<T>(self, methods: T) -> Self
632    where
633        T: Into<AllowMethods>,
634    {
635        self.map_layer(|layer| layer.allow_methods(methods))
636    }
637
638    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
639    ///
640    /// See [`CorsLayer::allow_origin`] for more details.
641    ///
642    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
643    pub fn allow_origin<T>(self, origin: T) -> Self
644    where
645        T: Into<AllowOrigin>,
646    {
647        self.map_layer(|layer| layer.allow_origin(origin))
648    }
649
650    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
651    ///
652    /// See [`CorsLayer::expose_headers`] for more details.
653    ///
654    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
655    pub fn expose_headers<T>(self, headers: T) -> Self
656    where
657        T: Into<ExposeHeaders>,
658    {
659        self.map_layer(|layer| layer.expose_headers(headers))
660    }
661
662    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
663    ///
664    /// See [`CorsLayer::allow_private_network`] for more details.
665    ///
666    /// [wicg]: https://wicg.github.io/private-network-access/
667    pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
668    where
669        T: Into<AllowPrivateNetwork>,
670    {
671        self.map_layer(|layer| layer.allow_private_network(allow_private_network))
672    }
673
674    fn map_layer<F>(mut self, f: F) -> Self
675    where
676        F: FnOnce(CorsLayer) -> CorsLayer,
677    {
678        self.layer = f(self.layer);
679
680        self.layer.update_vary_header();
681        self
682    }
683}
684
685impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
686where
687    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
688    ResBody: Default,
689{
690    type Response = S::Response;
691    type Error = S::Error;
692    type Future = ResponseFuture<S::Future>;
693
694    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
695        ensure_usable_cors_rules(&self.layer);
696        self.inner.poll_ready(cx)
697    }
698
699    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
700        let (parts, body) = req.into_parts();
701        let origin = parts.headers.get(&header::ORIGIN);
702
703        let mut headers = HeaderMap::new();
704
705        // These headers are applied to both preflight and subsequent regular CORS requests:
706        // https://fetch.spec.whatwg.org/#http-responses
707
708        headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
709        headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
710        headers.extend(self.layer.vary.to_header());
711
712        let allow_origin_future = self.layer.allow_origin.to_future(origin, &parts);
713
714        // Return results immediately upon preflight request
715        if parts.method == Method::OPTIONS {
716            // These headers are applied only to preflight requests
717            headers.extend(self.layer.allow_methods.to_header(&parts));
718            headers.extend(self.layer.allow_headers.to_header(&parts));
719            headers.extend(self.layer.max_age.to_header(origin, &parts));
720
721            ResponseFuture {
722                inner: Kind::PreflightCall {
723                    allow_origin_future,
724                    headers,
725                },
726            }
727        } else {
728            // This header is applied only to non-preflight requests
729            headers.extend(self.layer.expose_headers.to_header(&parts));
730
731            let req = Request::from_parts(parts, body);
732            ResponseFuture {
733                inner: Kind::CorsCall {
734                    allow_origin_future,
735                    allow_origin_complete: false,
736                    future: self.inner.call(req),
737                    headers,
738                },
739            }
740        }
741    }
742}
743
pin_project! {
    /// Response future for [`Cors`].
    ///
    /// Resolves either to the wrapped service's response with CORS headers
    /// added, or, for `OPTIONS` preflight requests, to a response built by the
    /// middleware itself.
    pub struct ResponseFuture<F> {
        // Which of the two request flows this future is driving.
        #[pin]
        inner: Kind<F>,
    }
}
751
pin_project! {
    // Internal state of `ResponseFuture`; the variant is chosen in `Cors::call`
    // based on whether the request method is `OPTIONS`.
    #[project = KindProj]
    enum Kind<F> {
        // Non-preflight request: first resolve the allow-origin headers, then
        // await the wrapped service's response and merge `headers` into it.
        CorsCall {
            #[pin]
            allow_origin_future: AllowOriginFuture,
            // Set once `allow_origin_future` has produced its headers, so it
            // is not polled again on later wake-ups.
            allow_origin_complete: bool,
            #[pin]
            future: F,
            // CORS headers to be added to the wrapped service's response.
            headers: HeaderMap,
        },
        // Preflight (`OPTIONS`) request answered without calling the wrapped
        // service: resolve the allow-origin headers, then return a response
        // with a default body carrying `headers`.
        PreflightCall {
            #[pin]
            allow_origin_future: AllowOriginFuture,
            headers: HeaderMap,
        },
    }
}
770
771impl<F, B, E> Future for ResponseFuture<F>
772where
773    F: Future<Output = Result<Response<B>, E>>,
774    B: Default,
775{
776    type Output = Result<Response<B>, E>;
777
778    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
779        match self.project().inner.project() {
780            KindProj::CorsCall {
781                allow_origin_future,
782                allow_origin_complete,
783                future,
784                headers,
785            } => {
786                if !*allow_origin_complete {
787                    headers.extend(ready!(allow_origin_future.poll(cx)));
788                    *allow_origin_complete = true;
789                }
790
791                let mut response: Response<B> = ready!(future.poll(cx))?;
792
793                let response_headers = response.headers_mut();
794
795                // vary header can have multiple values, don't overwrite
796                // previously-set value(s).
797                if let Some(vary) = headers.remove(header::VARY) {
798                    response_headers.append(header::VARY, vary);
799                }
800                // extend will overwrite previous headers of remaining names
801                response_headers.extend(headers.drain());
802
803                Poll::Ready(Ok(response))
804            }
805            KindProj::PreflightCall {
806                allow_origin_future,
807                headers,
808            } => {
809                headers.extend(ready!(allow_origin_future.poll(cx)));
810
811                let mut response = Response::new(B::default());
812                mem::swap(response.headers_mut(), headers);
813
814                Poll::Ready(Ok(response))
815            }
816        }
817    }
818}
819
820fn ensure_usable_cors_rules(layer: &CorsLayer) {
821    if layer.allow_credentials.is_true() {
822        assert!(
823            !layer.allow_headers.is_wildcard(),
824            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
825             with `Access-Control-Allow-Headers: *`"
826        );
827
828        assert!(
829            !layer.allow_methods.is_wildcard(),
830            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
831             with `Access-Control-Allow-Methods: *`"
832        );
833
834        assert!(
835            !layer.allow_origin.is_wildcard(),
836            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
837             with `Access-Control-Allow-Origin: *`"
838        );
839
840        assert!(
841            !layer.expose_headers.is_wildcard(),
842            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
843             with `Access-Control-Expose-Headers: *`"
844        );
845    }
846}
847
848/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
849///
850/// This is the default set of header names returned in the `vary` header
851pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
852    IntoIterator::into_iter([
853        header::ORIGIN,
854        header::ACCESS_CONTROL_REQUEST_METHOD,
855        header::ACCESS_CONTROL_REQUEST_HEADERS,
856    ])
857}