tower_async_http/cors/
mod.rs

1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower_async::{ServiceBuilder, ServiceExt, Service};
10//! use tower_async_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//!     Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//!     // allow `GET` and `POST` when accessing the resource
21//!     .allow_methods([Method::GET, Method::POST])
22//!     // allow requests from any origin
23//!     .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//!     .layer(cors)
27//!     .service_fn(handle);
28//!
29//! let request = Request::builder()
30//!     .header(header::ORIGIN, "https://example.com")
31//!     .body(Full::<Bytes>::default())
32//!     .unwrap();
33//!
34//! let response = service
35//!     .call(request)
36//!     .await?;
37//!
38//! assert_eq!(
39//!     response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
40//!     "*",
41//! );
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
47
48#![allow(clippy::enum_variant_names)]
49
50use bytes::{BufMut, BytesMut};
51use http::{
52    header::{self, HeaderName},
53    HeaderMap, HeaderValue, Method, Request, Response,
54};
55use std::{array, mem};
56use tower_async_layer::Layer;
57use tower_async_service::Service;
58
59mod allow_credentials;
60mod allow_headers;
61mod allow_methods;
62mod allow_origin;
63mod allow_private_network;
64mod expose_headers;
65mod max_age;
66mod vary;
67
68pub use self::{
69    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
70    allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
71    expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
72};
73
74/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
75///
76/// See the [module docs](crate::cors) for an example.
77///
78/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
79#[derive(Debug, Clone)]
80#[must_use]
81pub struct CorsLayer {
82    allow_credentials: AllowCredentials,
83    allow_headers: AllowHeaders,
84    allow_methods: AllowMethods,
85    allow_origin: AllowOrigin,
86    allow_private_network: AllowPrivateNetwork,
87    expose_headers: ExposeHeaders,
88    max_age: MaxAge,
89    vary: Vary,
90}
91
92#[allow(clippy::declare_interior_mutable_const)]
93const WILDCARD: HeaderValue = HeaderValue::from_static("*");
94
95impl CorsLayer {
96    /// Create a new `CorsLayer`.
97    ///
98    /// No headers are sent by default. Use the builder methods to customize
99    /// the behavior.
100    ///
101    /// You need to set at least an allowed origin for browsers to make
102    /// successful cross-origin requests to your service.
103    pub fn new() -> Self {
104        Self {
105            allow_credentials: Default::default(),
106            allow_headers: Default::default(),
107            allow_methods: Default::default(),
108            allow_origin: Default::default(),
109            allow_private_network: Default::default(),
110            expose_headers: Default::default(),
111            max_age: Default::default(),
112            vary: Default::default(),
113        }
114    }
115
116    /// A permissive configuration:
117    ///
118    /// - All request headers allowed.
119    /// - All methods allowed.
120    /// - All origins allowed.
121    /// - All headers exposed.
122    pub fn permissive() -> Self {
123        Self::new()
124            .allow_headers(Any)
125            .allow_methods(Any)
126            .allow_origin(Any)
127            .expose_headers(Any)
128    }
129
130    /// A very permissive configuration:
131    ///
132    /// - **Credentials allowed.**
133    /// - The method received in `Access-Control-Request-Method` is sent back
134    ///   as an allowed method.
135    /// - The origin of the preflight request is sent back as an allowed origin.
136    /// - The header names received in `Access-Control-Request-Headers` are sent
137    ///   back as allowed headers.
138    /// - No headers are currently exposed, but this may change in the future.
139    pub fn very_permissive() -> Self {
140        Self::new()
141            .allow_credentials(true)
142            .allow_headers(AllowHeaders::mirror_request())
143            .allow_methods(AllowMethods::mirror_request())
144            .allow_origin(AllowOrigin::mirror_request())
145    }
146
147    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
148    ///
149    /// ```
150    /// use tower_async_http::cors::CorsLayer;
151    ///
152    /// let layer = CorsLayer::new().allow_credentials(true);
153    /// ```
154    ///
155    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
156    pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
157    where
158        T: Into<AllowCredentials>,
159    {
160        self.allow_credentials = allow_credentials.into();
161        self
162    }
163
164    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
165    ///
166    /// ```
167    /// use tower_async_http::cors::CorsLayer;
168    /// use http::header::{AUTHORIZATION, ACCEPT};
169    ///
170    /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
171    /// ```
172    ///
173    /// All headers can be allowed with
174    ///
175    /// ```
176    /// use tower_async_http::cors::{Any, CorsLayer};
177    ///
178    /// let layer = CorsLayer::new().allow_headers(Any);
179    /// ```
180    ///
181    /// Note that multiple calls to this method will override any previous
182    /// calls.
183    ///
184    /// Also note that `Access-Control-Allow-Headers` is required for requests that have
185    /// `Access-Control-Request-Headers`.
186    ///
187    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
188    pub fn allow_headers<T>(mut self, headers: T) -> Self
189    where
190        T: Into<AllowHeaders>,
191    {
192        self.allow_headers = headers.into();
193        self
194    }
195
196    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
197    ///
198    /// ```
199    /// use std::time::Duration;
200    /// use tower_async_http::cors::CorsLayer;
201    ///
202    /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
203    /// ```
204    ///
205    /// By default the header will not be set which disables caching and will
206    /// require a preflight call for all requests.
207    ///
208    /// Note that each browser has a maximum internal value that takes
209    /// precedence when the Access-Control-Max-Age is greater. For more details
210    /// see [mdn].
211    ///
212    /// If you need more flexibility, you can use supply a function which can
213    /// dynamically decide the max-age based on the origin and other parts of
214    /// each preflight request:
215    ///
216    /// ```
217    /// # struct MyServerConfig { cors_max_age: Duration }
218    /// use std::time::Duration;
219    ///
220    /// use http::{request::Parts as RequestParts, HeaderValue};
221    /// use tower_async_http::cors::{CorsLayer, MaxAge};
222    ///
223    /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
224    ///     |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
225    ///         // Let's say you want to be able to reload your config at
226    ///         // runtime and have another middleware that always inserts
227    ///         // the current config into the request extensions
228    ///         let config = parts.extensions.get::<MyServerConfig>().unwrap();
229    ///         config.cors_max_age
230    ///     },
231    /// ));
232    /// ```
233    ///
234    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
235    pub fn max_age<T>(mut self, max_age: T) -> Self
236    where
237        T: Into<MaxAge>,
238    {
239        self.max_age = max_age.into();
240        self
241    }
242
243    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
244    ///
245    /// ```
246    /// use tower_async_http::cors::CorsLayer;
247    /// use http::Method;
248    ///
249    /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
250    /// ```
251    ///
252    /// All methods can be allowed with
253    ///
254    /// ```
255    /// use tower_async_http::cors::{Any, CorsLayer};
256    ///
257    /// let layer = CorsLayer::new().allow_methods(Any);
258    /// ```
259    ///
260    /// Note that multiple calls to this method will override any previous
261    /// calls.
262    ///
263    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
264    pub fn allow_methods<T>(mut self, methods: T) -> Self
265    where
266        T: Into<AllowMethods>,
267    {
268        self.allow_methods = methods.into();
269        self
270    }
271
272    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
273    ///
274    /// ```
275    /// use http::HeaderValue;
276    /// use tower_async_http::cors::CorsLayer;
277    ///
278    /// let layer = CorsLayer::new().allow_origin(
279    ///     "http://example.com".parse::<HeaderValue>().unwrap(),
280    /// );
281    /// ```
282    ///
283    /// Multiple origins can be allowed with
284    ///
285    /// ```
286    /// use tower_async_http::cors::CorsLayer;
287    ///
288    /// let origins = [
289    ///     "http://example.com".parse().unwrap(),
290    ///     "http://api.example.com".parse().unwrap(),
291    /// ];
292    ///
293    /// let layer = CorsLayer::new().allow_origin(origins);
294    /// ```
295    ///
296    /// All origins can be allowed with
297    ///
298    /// ```
299    /// use tower_async_http::cors::{Any, CorsLayer};
300    ///
301    /// let layer = CorsLayer::new().allow_origin(Any);
302    /// ```
303    ///
304    /// You can also use a closure
305    ///
306    /// ```
307    /// use tower_async_http::cors::{CorsLayer, AllowOrigin};
308    /// use http::{request::Parts as RequestParts, HeaderValue};
309    ///
310    /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
311    ///     |origin: &HeaderValue, _request_parts: &RequestParts| {
312    ///         origin.as_bytes().ends_with(b".rust-lang.org")
313    ///     },
314    /// ));
315    /// ```
316    ///
317    /// Note that multiple calls to this method will override any previous
318    /// calls.
319    ///
320    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
321    pub fn allow_origin<T>(mut self, origin: T) -> Self
322    where
323        T: Into<AllowOrigin>,
324    {
325        self.allow_origin = origin.into();
326        self
327    }
328
329    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
330    ///
331    /// ```
332    /// use tower_async_http::cors::CorsLayer;
333    /// use http::header::CONTENT_ENCODING;
334    ///
335    /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
336    /// ```
337    ///
338    /// All headers can be allowed with
339    ///
340    /// ```
341    /// use tower_async_http::cors::{Any, CorsLayer};
342    ///
343    /// let layer = CorsLayer::new().expose_headers(Any);
344    /// ```
345    ///
346    /// Note that multiple calls to this method will override any previous
347    /// calls.
348    ///
349    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
350    pub fn expose_headers<T>(mut self, headers: T) -> Self
351    where
352        T: Into<ExposeHeaders>,
353    {
354        self.expose_headers = headers.into();
355        self
356    }
357
358    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
359    ///
360    /// ```
361    /// use tower_async_http::cors::CorsLayer;
362    ///
363    /// let layer = CorsLayer::new().allow_private_network(true);
364    /// ```
365    ///
366    /// [wicg]: https://wicg.github.io/private-network-access/
367    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
368    where
369        T: Into<AllowPrivateNetwork>,
370    {
371        self.allow_private_network = allow_private_network.into();
372        self
373    }
374
375    /// Set the value(s) of the [`Vary`][mdn] header.
376    ///
377    /// In contrast to the other headers, this one has a non-empty default of
378    /// [`preflight_request_headers()`].
379    ///
380    /// You only need to set this is you want to remove some of these defaults,
381    /// or if you use a closure for one of the other headers and want to add a
382    /// vary header accordingly.
383    ///
384    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
385    pub fn vary<T>(mut self, headers: T) -> Self
386    where
387        T: Into<Vary>,
388    {
389        self.vary = headers.into();
390        self
391    }
392}
393
/// Wildcard marker standing for the `*` value accepted by several CORS
/// headers, e.g. via [`CorsLayer::allow_methods`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;
399
400/// Represents a wildcard value (`*`) used with some CORS headers such as
401/// [`CorsLayer::allow_methods`].
402#[deprecated = "Use Any as a unit struct literal instead"]
403pub fn any() -> Any {
404    Any
405}
406
407fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
408where
409    I: Iterator<Item = HeaderValue>,
410{
411    match iter.next() {
412        Some(fst) => {
413            let mut result = BytesMut::from(fst.as_bytes());
414            for val in iter {
415                result.reserve(val.len() + 1);
416                result.put_u8(b',');
417                result.extend_from_slice(val.as_bytes());
418            }
419
420            Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
421        }
422        None => None,
423    }
424}
425
426impl Default for CorsLayer {
427    fn default() -> Self {
428        Self::new()
429    }
430}
431
432impl<S> Layer<S> for CorsLayer {
433    type Service = Cors<S>;
434
435    fn layer(&self, inner: S) -> Self::Service {
436        ensure_usable_cors_rules(self);
437
438        Cors {
439            inner,
440            layer: self.clone(),
441        }
442    }
443}
444
445/// Middleware which adds headers for [CORS][mdn].
446///
447/// See the [module docs](crate::cors) for an example.
448///
449/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
450#[derive(Debug, Clone)]
451#[must_use]
452pub struct Cors<S> {
453    inner: S,
454    layer: CorsLayer,
455}
456
457impl<S> Cors<S> {
458    /// Create a new `Cors`.
459    ///
460    /// See [`CorsLayer::new`] for more details.
461    pub fn new(inner: S) -> Self {
462        Self {
463            inner,
464            layer: CorsLayer::new(),
465        }
466    }
467
468    /// A permissive configuration.
469    ///
470    /// See [`CorsLayer::permissive`] for more details.
471    pub fn permissive(inner: S) -> Self {
472        Self {
473            inner,
474            layer: CorsLayer::permissive(),
475        }
476    }
477
478    /// A very permissive configuration.
479    ///
480    /// See [`CorsLayer::very_permissive`] for more details.
481    pub fn very_permissive(inner: S) -> Self {
482        Self {
483            inner,
484            layer: CorsLayer::very_permissive(),
485        }
486    }
487
488    define_inner_service_accessors!();
489
490    /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
491    ///
492    /// [`Layer`]: tower_async_layer::Layer
493    pub fn layer() -> CorsLayer {
494        CorsLayer::new()
495    }
496
497    /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
498    ///
499    /// See [`CorsLayer::allow_credentials`] for more details.
500    ///
501    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
502    pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
503    where
504        T: Into<AllowCredentials>,
505    {
506        self.map_layer(|layer| layer.allow_credentials(allow_credentials))
507    }
508
509    /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
510    ///
511    /// See [`CorsLayer::allow_headers`] for more details.
512    ///
513    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
514    pub fn allow_headers<T>(self, headers: T) -> Self
515    where
516        T: Into<AllowHeaders>,
517    {
518        self.map_layer(|layer| layer.allow_headers(headers))
519    }
520
521    /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
522    ///
523    /// See [`CorsLayer::max_age`] for more details.
524    ///
525    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
526    pub fn max_age<T>(self, max_age: T) -> Self
527    where
528        T: Into<MaxAge>,
529    {
530        self.map_layer(|layer| layer.max_age(max_age))
531    }
532
533    /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
534    ///
535    /// See [`CorsLayer::allow_methods`] for more details.
536    ///
537    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
538    pub fn allow_methods<T>(self, methods: T) -> Self
539    where
540        T: Into<AllowMethods>,
541    {
542        self.map_layer(|layer| layer.allow_methods(methods))
543    }
544
545    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
546    ///
547    /// See [`CorsLayer::allow_origin`] for more details.
548    ///
549    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
550    pub fn allow_origin<T>(self, origin: T) -> Self
551    where
552        T: Into<AllowOrigin>,
553    {
554        self.map_layer(|layer| layer.allow_origin(origin))
555    }
556
557    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
558    ///
559    /// See [`CorsLayer::expose_headers`] for more details.
560    ///
561    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
562    pub fn expose_headers<T>(self, headers: T) -> Self
563    where
564        T: Into<ExposeHeaders>,
565    {
566        self.map_layer(|layer| layer.expose_headers(headers))
567    }
568
569    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
570    ///
571    /// See [`CorsLayer::allow_private_network`] for more details.
572    ///
573    /// [wicg]: https://wicg.github.io/private-network-access/
574    pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
575    where
576        T: Into<AllowPrivateNetwork>,
577    {
578        self.map_layer(|layer| layer.allow_private_network(allow_private_network))
579    }
580
581    fn map_layer<F>(mut self, f: F) -> Self
582    where
583        F: FnOnce(CorsLayer) -> CorsLayer,
584    {
585        self.layer = f(self.layer);
586        self
587    }
588}
589
590impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
591where
592    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
593    ResBody: Default,
594{
595    type Response = S::Response;
596    type Error = S::Error;
597
598    async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
599        let (parts, body) = req.into_parts();
600        let origin = parts.headers.get(&header::ORIGIN);
601
602        let mut headers = HeaderMap::new();
603
604        // These headers are applied to both preflight and subsequent regular CORS requests:
605        // https://fetch.spec.whatwg.org/#http-responses
606
607        headers.extend(self.layer.allow_origin.to_header(origin, &parts));
608        headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
609        headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
610
611        let mut vary_headers = self.layer.vary.values();
612        if let Some(first) = vary_headers.next() {
613            let mut header = match headers.entry(header::VARY) {
614                header::Entry::Occupied(_) => {
615                    unreachable!("no vary header inserted up to this point")
616                }
617                header::Entry::Vacant(v) => v.insert_entry(first),
618            };
619
620            for val in vary_headers {
621                header.append(val);
622            }
623        }
624
625        // Return results immediately upon preflight request
626        if parts.method == Method::OPTIONS {
627            // These headers are applied only to preflight requests
628            headers.extend(self.layer.allow_methods.to_header(&parts));
629            headers.extend(self.layer.allow_headers.to_header(&parts));
630            headers.extend(self.layer.max_age.to_header(origin, &parts));
631
632            let mut response = Response::new(ResBody::default());
633            mem::swap(response.headers_mut(), &mut headers);
634
635            Ok(response)
636        } else {
637            // This header is applied only to non-preflight requests
638            headers.extend(self.layer.expose_headers.to_header(&parts));
639
640            let req = Request::from_parts(parts, body);
641
642            let mut response: Response<ResBody> = self.inner.call(req).await?;
643            response.headers_mut().extend(headers.drain());
644
645            Ok(response)
646        }
647    }
648}
649
650fn ensure_usable_cors_rules(layer: &CorsLayer) {
651    if layer.allow_credentials.is_true() {
652        assert!(
653            !layer.allow_headers.is_wildcard(),
654            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
655             with `Access-Control-Allow-Headers: *`"
656        );
657
658        assert!(
659            !layer.allow_methods.is_wildcard(),
660            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
661             with `Access-Control-Allow-Methods: *`"
662        );
663
664        assert!(
665            !layer.allow_origin.is_wildcard(),
666            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
667             with `Access-Control-Allow-Origin: *`"
668        );
669
670        assert!(
671            !layer.expose_headers.is_wildcard(),
672            "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
673             with `Access-Control-Expose-Headers: *`"
674        );
675    }
676}
677
678/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
679///
680/// This is the default set of header names returned in the `vary` header
681pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
682    #[allow(deprecated)] // Can be changed when MSRV >= 1.53
683    array::IntoIter::new([
684        header::ORIGIN,
685        header::ACCESS_CONTROL_REQUEST_METHOD,
686        header::ACCESS_CONTROL_REQUEST_HEADERS,
687    ])
688}