tower_async_http/cors/mod.rs
1//! Middleware which adds headers for [CORS][mdn].
2//!
3//! # Example
4//!
5//! ```
6//! use http::{Request, Response, Method, header};
7//! use http_body_util::Full;
8//! use bytes::Bytes;
9//! use tower_async::{ServiceBuilder, ServiceExt, Service};
10//! use tower_async_http::cors::{Any, CorsLayer};
11//! use std::convert::Infallible;
12//!
13//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
14//! Ok(Response::new(Full::default()))
15//! }
16//!
17//! # #[tokio::main]
18//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
19//! let cors = CorsLayer::new()
20//! // allow `GET` and `POST` when accessing the resource
21//! .allow_methods([Method::GET, Method::POST])
22//! // allow requests from any origin
23//! .allow_origin(Any);
24//!
25//! let mut service = ServiceBuilder::new()
26//! .layer(cors)
27//! .service_fn(handle);
28//!
29//! let request = Request::builder()
30//! .header(header::ORIGIN, "https://example.com")
31//! .body(Full::<Bytes>::default())
32//! .unwrap();
33//!
34//! let response = service
35//! .call(request)
36//! .await?;
37//!
38//! assert_eq!(
39//! response.headers().get(header::ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
40//! "*",
41//! );
42//! # Ok(())
43//! # }
44//! ```
45//!
46//! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
47
48#![allow(clippy::enum_variant_names)]
49
50use bytes::{BufMut, BytesMut};
51use http::{
52 header::{self, HeaderName},
53 HeaderMap, HeaderValue, Method, Request, Response,
54};
55use std::{array, mem};
56use tower_async_layer::Layer;
57use tower_async_service::Service;
58
59mod allow_credentials;
60mod allow_headers;
61mod allow_methods;
62mod allow_origin;
63mod allow_private_network;
64mod expose_headers;
65mod max_age;
66mod vary;
67
68pub use self::{
69 allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
70 allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
71 expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
72};
73
74/// Layer that applies the [`Cors`] middleware which adds headers for [CORS][mdn].
75///
76/// See the [module docs](crate::cors) for an example.
77///
78/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
79#[derive(Debug, Clone)]
80#[must_use]
81pub struct CorsLayer {
82 allow_credentials: AllowCredentials,
83 allow_headers: AllowHeaders,
84 allow_methods: AllowMethods,
85 allow_origin: AllowOrigin,
86 allow_private_network: AllowPrivateNetwork,
87 expose_headers: ExposeHeaders,
88 max_age: MaxAge,
89 vary: Vary,
90}
91
92#[allow(clippy::declare_interior_mutable_const)]
93const WILDCARD: HeaderValue = HeaderValue::from_static("*");
94
95impl CorsLayer {
96 /// Create a new `CorsLayer`.
97 ///
98 /// No headers are sent by default. Use the builder methods to customize
99 /// the behavior.
100 ///
101 /// You need to set at least an allowed origin for browsers to make
102 /// successful cross-origin requests to your service.
103 pub fn new() -> Self {
104 Self {
105 allow_credentials: Default::default(),
106 allow_headers: Default::default(),
107 allow_methods: Default::default(),
108 allow_origin: Default::default(),
109 allow_private_network: Default::default(),
110 expose_headers: Default::default(),
111 max_age: Default::default(),
112 vary: Default::default(),
113 }
114 }
115
116 /// A permissive configuration:
117 ///
118 /// - All request headers allowed.
119 /// - All methods allowed.
120 /// - All origins allowed.
121 /// - All headers exposed.
122 pub fn permissive() -> Self {
123 Self::new()
124 .allow_headers(Any)
125 .allow_methods(Any)
126 .allow_origin(Any)
127 .expose_headers(Any)
128 }
129
130 /// A very permissive configuration:
131 ///
132 /// - **Credentials allowed.**
133 /// - The method received in `Access-Control-Request-Method` is sent back
134 /// as an allowed method.
135 /// - The origin of the preflight request is sent back as an allowed origin.
136 /// - The header names received in `Access-Control-Request-Headers` are sent
137 /// back as allowed headers.
138 /// - No headers are currently exposed, but this may change in the future.
139 pub fn very_permissive() -> Self {
140 Self::new()
141 .allow_credentials(true)
142 .allow_headers(AllowHeaders::mirror_request())
143 .allow_methods(AllowMethods::mirror_request())
144 .allow_origin(AllowOrigin::mirror_request())
145 }
146
147 /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
148 ///
149 /// ```
150 /// use tower_async_http::cors::CorsLayer;
151 ///
152 /// let layer = CorsLayer::new().allow_credentials(true);
153 /// ```
154 ///
155 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
156 pub fn allow_credentials<T>(mut self, allow_credentials: T) -> Self
157 where
158 T: Into<AllowCredentials>,
159 {
160 self.allow_credentials = allow_credentials.into();
161 self
162 }
163
164 /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
165 ///
166 /// ```
167 /// use tower_async_http::cors::CorsLayer;
168 /// use http::header::{AUTHORIZATION, ACCEPT};
169 ///
170 /// let layer = CorsLayer::new().allow_headers([AUTHORIZATION, ACCEPT]);
171 /// ```
172 ///
173 /// All headers can be allowed with
174 ///
175 /// ```
176 /// use tower_async_http::cors::{Any, CorsLayer};
177 ///
178 /// let layer = CorsLayer::new().allow_headers(Any);
179 /// ```
180 ///
181 /// Note that multiple calls to this method will override any previous
182 /// calls.
183 ///
184 /// Also note that `Access-Control-Allow-Headers` is required for requests that have
185 /// `Access-Control-Request-Headers`.
186 ///
187 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
188 pub fn allow_headers<T>(mut self, headers: T) -> Self
189 where
190 T: Into<AllowHeaders>,
191 {
192 self.allow_headers = headers.into();
193 self
194 }
195
196 /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
197 ///
198 /// ```
199 /// use std::time::Duration;
200 /// use tower_async_http::cors::CorsLayer;
201 ///
202 /// let layer = CorsLayer::new().max_age(Duration::from_secs(60) * 10);
203 /// ```
204 ///
205 /// By default the header will not be set which disables caching and will
206 /// require a preflight call for all requests.
207 ///
208 /// Note that each browser has a maximum internal value that takes
209 /// precedence when the Access-Control-Max-Age is greater. For more details
210 /// see [mdn].
211 ///
212 /// If you need more flexibility, you can use supply a function which can
213 /// dynamically decide the max-age based on the origin and other parts of
214 /// each preflight request:
215 ///
216 /// ```
217 /// # struct MyServerConfig { cors_max_age: Duration }
218 /// use std::time::Duration;
219 ///
220 /// use http::{request::Parts as RequestParts, HeaderValue};
221 /// use tower_async_http::cors::{CorsLayer, MaxAge};
222 ///
223 /// let layer = CorsLayer::new().max_age(MaxAge::dynamic(
224 /// |_origin: &HeaderValue, parts: &RequestParts| -> Duration {
225 /// // Let's say you want to be able to reload your config at
226 /// // runtime and have another middleware that always inserts
227 /// // the current config into the request extensions
228 /// let config = parts.extensions.get::<MyServerConfig>().unwrap();
229 /// config.cors_max_age
230 /// },
231 /// ));
232 /// ```
233 ///
234 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
235 pub fn max_age<T>(mut self, max_age: T) -> Self
236 where
237 T: Into<MaxAge>,
238 {
239 self.max_age = max_age.into();
240 self
241 }
242
243 /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
244 ///
245 /// ```
246 /// use tower_async_http::cors::CorsLayer;
247 /// use http::Method;
248 ///
249 /// let layer = CorsLayer::new().allow_methods([Method::GET, Method::POST]);
250 /// ```
251 ///
252 /// All methods can be allowed with
253 ///
254 /// ```
255 /// use tower_async_http::cors::{Any, CorsLayer};
256 ///
257 /// let layer = CorsLayer::new().allow_methods(Any);
258 /// ```
259 ///
260 /// Note that multiple calls to this method will override any previous
261 /// calls.
262 ///
263 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
264 pub fn allow_methods<T>(mut self, methods: T) -> Self
265 where
266 T: Into<AllowMethods>,
267 {
268 self.allow_methods = methods.into();
269 self
270 }
271
272 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
273 ///
274 /// ```
275 /// use http::HeaderValue;
276 /// use tower_async_http::cors::CorsLayer;
277 ///
278 /// let layer = CorsLayer::new().allow_origin(
279 /// "http://example.com".parse::<HeaderValue>().unwrap(),
280 /// );
281 /// ```
282 ///
283 /// Multiple origins can be allowed with
284 ///
285 /// ```
286 /// use tower_async_http::cors::CorsLayer;
287 ///
288 /// let origins = [
289 /// "http://example.com".parse().unwrap(),
290 /// "http://api.example.com".parse().unwrap(),
291 /// ];
292 ///
293 /// let layer = CorsLayer::new().allow_origin(origins);
294 /// ```
295 ///
296 /// All origins can be allowed with
297 ///
298 /// ```
299 /// use tower_async_http::cors::{Any, CorsLayer};
300 ///
301 /// let layer = CorsLayer::new().allow_origin(Any);
302 /// ```
303 ///
304 /// You can also use a closure
305 ///
306 /// ```
307 /// use tower_async_http::cors::{CorsLayer, AllowOrigin};
308 /// use http::{request::Parts as RequestParts, HeaderValue};
309 ///
310 /// let layer = CorsLayer::new().allow_origin(AllowOrigin::predicate(
311 /// |origin: &HeaderValue, _request_parts: &RequestParts| {
312 /// origin.as_bytes().ends_with(b".rust-lang.org")
313 /// },
314 /// ));
315 /// ```
316 ///
317 /// Note that multiple calls to this method will override any previous
318 /// calls.
319 ///
320 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
321 pub fn allow_origin<T>(mut self, origin: T) -> Self
322 where
323 T: Into<AllowOrigin>,
324 {
325 self.allow_origin = origin.into();
326 self
327 }
328
329 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
330 ///
331 /// ```
332 /// use tower_async_http::cors::CorsLayer;
333 /// use http::header::CONTENT_ENCODING;
334 ///
335 /// let layer = CorsLayer::new().expose_headers([CONTENT_ENCODING]);
336 /// ```
337 ///
338 /// All headers can be allowed with
339 ///
340 /// ```
341 /// use tower_async_http::cors::{Any, CorsLayer};
342 ///
343 /// let layer = CorsLayer::new().expose_headers(Any);
344 /// ```
345 ///
346 /// Note that multiple calls to this method will override any previous
347 /// calls.
348 ///
349 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
350 pub fn expose_headers<T>(mut self, headers: T) -> Self
351 where
352 T: Into<ExposeHeaders>,
353 {
354 self.expose_headers = headers.into();
355 self
356 }
357
358 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
359 ///
360 /// ```
361 /// use tower_async_http::cors::CorsLayer;
362 ///
363 /// let layer = CorsLayer::new().allow_private_network(true);
364 /// ```
365 ///
366 /// [wicg]: https://wicg.github.io/private-network-access/
367 pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
368 where
369 T: Into<AllowPrivateNetwork>,
370 {
371 self.allow_private_network = allow_private_network.into();
372 self
373 }
374
375 /// Set the value(s) of the [`Vary`][mdn] header.
376 ///
377 /// In contrast to the other headers, this one has a non-empty default of
378 /// [`preflight_request_headers()`].
379 ///
380 /// You only need to set this is you want to remove some of these defaults,
381 /// or if you use a closure for one of the other headers and want to add a
382 /// vary header accordingly.
383 ///
384 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
385 pub fn vary<T>(mut self, headers: T) -> Self
386 where
387 T: Into<Vary>,
388 {
389 self.vary = headers.into();
390 self
391 }
392}
393
/// Wildcard marker standing for the `*` header value.
///
/// Accepted by builder methods such as [`CorsLayer::allow_methods`],
/// [`CorsLayer::allow_headers`], [`CorsLayer::allow_origin`] and
/// [`CorsLayer::expose_headers`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;
399
400/// Represents a wildcard value (`*`) used with some CORS headers such as
401/// [`CorsLayer::allow_methods`].
402#[deprecated = "Use Any as a unit struct literal instead"]
403pub fn any() -> Any {
404 Any
405}
406
407fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
408where
409 I: Iterator<Item = HeaderValue>,
410{
411 match iter.next() {
412 Some(fst) => {
413 let mut result = BytesMut::from(fst.as_bytes());
414 for val in iter {
415 result.reserve(val.len() + 1);
416 result.put_u8(b',');
417 result.extend_from_slice(val.as_bytes());
418 }
419
420 Some(HeaderValue::from_maybe_shared(result.freeze()).unwrap())
421 }
422 None => None,
423 }
424}
425
426impl Default for CorsLayer {
427 fn default() -> Self {
428 Self::new()
429 }
430}
431
432impl<S> Layer<S> for CorsLayer {
433 type Service = Cors<S>;
434
435 fn layer(&self, inner: S) -> Self::Service {
436 ensure_usable_cors_rules(self);
437
438 Cors {
439 inner,
440 layer: self.clone(),
441 }
442 }
443}
444
445/// Middleware which adds headers for [CORS][mdn].
446///
447/// See the [module docs](crate::cors) for an example.
448///
449/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
450#[derive(Debug, Clone)]
451#[must_use]
452pub struct Cors<S> {
453 inner: S,
454 layer: CorsLayer,
455}
456
457impl<S> Cors<S> {
458 /// Create a new `Cors`.
459 ///
460 /// See [`CorsLayer::new`] for more details.
461 pub fn new(inner: S) -> Self {
462 Self {
463 inner,
464 layer: CorsLayer::new(),
465 }
466 }
467
468 /// A permissive configuration.
469 ///
470 /// See [`CorsLayer::permissive`] for more details.
471 pub fn permissive(inner: S) -> Self {
472 Self {
473 inner,
474 layer: CorsLayer::permissive(),
475 }
476 }
477
478 /// A very permissive configuration.
479 ///
480 /// See [`CorsLayer::very_permissive`] for more details.
481 pub fn very_permissive(inner: S) -> Self {
482 Self {
483 inner,
484 layer: CorsLayer::very_permissive(),
485 }
486 }
487
488 define_inner_service_accessors!();
489
490 /// Returns a new [`Layer`] that wraps services with a [`Cors`] middleware.
491 ///
492 /// [`Layer`]: tower_async_layer::Layer
493 pub fn layer() -> CorsLayer {
494 CorsLayer::new()
495 }
496
497 /// Set the [`Access-Control-Allow-Credentials`][mdn] header.
498 ///
499 /// See [`CorsLayer::allow_credentials`] for more details.
500 ///
501 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
502 pub fn allow_credentials<T>(self, allow_credentials: T) -> Self
503 where
504 T: Into<AllowCredentials>,
505 {
506 self.map_layer(|layer| layer.allow_credentials(allow_credentials))
507 }
508
509 /// Set the value of the [`Access-Control-Allow-Headers`][mdn] header.
510 ///
511 /// See [`CorsLayer::allow_headers`] for more details.
512 ///
513 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
514 pub fn allow_headers<T>(self, headers: T) -> Self
515 where
516 T: Into<AllowHeaders>,
517 {
518 self.map_layer(|layer| layer.allow_headers(headers))
519 }
520
521 /// Set the value of the [`Access-Control-Max-Age`][mdn] header.
522 ///
523 /// See [`CorsLayer::max_age`] for more details.
524 ///
525 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
526 pub fn max_age<T>(self, max_age: T) -> Self
527 where
528 T: Into<MaxAge>,
529 {
530 self.map_layer(|layer| layer.max_age(max_age))
531 }
532
533 /// Set the value of the [`Access-Control-Allow-Methods`][mdn] header.
534 ///
535 /// See [`CorsLayer::allow_methods`] for more details.
536 ///
537 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
538 pub fn allow_methods<T>(self, methods: T) -> Self
539 where
540 T: Into<AllowMethods>,
541 {
542 self.map_layer(|layer| layer.allow_methods(methods))
543 }
544
545 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
546 ///
547 /// See [`CorsLayer::allow_origin`] for more details.
548 ///
549 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
550 pub fn allow_origin<T>(self, origin: T) -> Self
551 where
552 T: Into<AllowOrigin>,
553 {
554 self.map_layer(|layer| layer.allow_origin(origin))
555 }
556
557 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
558 ///
559 /// See [`CorsLayer::expose_headers`] for more details.
560 ///
561 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
562 pub fn expose_headers<T>(self, headers: T) -> Self
563 where
564 T: Into<ExposeHeaders>,
565 {
566 self.map_layer(|layer| layer.expose_headers(headers))
567 }
568
569 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
570 ///
571 /// See [`CorsLayer::allow_private_network`] for more details.
572 ///
573 /// [wicg]: https://wicg.github.io/private-network-access/
574 pub fn allow_private_network<T>(self, allow_private_network: T) -> Self
575 where
576 T: Into<AllowPrivateNetwork>,
577 {
578 self.map_layer(|layer| layer.allow_private_network(allow_private_network))
579 }
580
581 fn map_layer<F>(mut self, f: F) -> Self
582 where
583 F: FnOnce(CorsLayer) -> CorsLayer,
584 {
585 self.layer = f(self.layer);
586 self
587 }
588}
589
590impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for Cors<S>
591where
592 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
593 ResBody: Default,
594{
595 type Response = S::Response;
596 type Error = S::Error;
597
598 async fn call(&self, req: Request<ReqBody>) -> Result<Self::Response, Self::Error> {
599 let (parts, body) = req.into_parts();
600 let origin = parts.headers.get(&header::ORIGIN);
601
602 let mut headers = HeaderMap::new();
603
604 // These headers are applied to both preflight and subsequent regular CORS requests:
605 // https://fetch.spec.whatwg.org/#http-responses
606
607 headers.extend(self.layer.allow_origin.to_header(origin, &parts));
608 headers.extend(self.layer.allow_credentials.to_header(origin, &parts));
609 headers.extend(self.layer.allow_private_network.to_header(origin, &parts));
610
611 let mut vary_headers = self.layer.vary.values();
612 if let Some(first) = vary_headers.next() {
613 let mut header = match headers.entry(header::VARY) {
614 header::Entry::Occupied(_) => {
615 unreachable!("no vary header inserted up to this point")
616 }
617 header::Entry::Vacant(v) => v.insert_entry(first),
618 };
619
620 for val in vary_headers {
621 header.append(val);
622 }
623 }
624
625 // Return results immediately upon preflight request
626 if parts.method == Method::OPTIONS {
627 // These headers are applied only to preflight requests
628 headers.extend(self.layer.allow_methods.to_header(&parts));
629 headers.extend(self.layer.allow_headers.to_header(&parts));
630 headers.extend(self.layer.max_age.to_header(origin, &parts));
631
632 let mut response = Response::new(ResBody::default());
633 mem::swap(response.headers_mut(), &mut headers);
634
635 Ok(response)
636 } else {
637 // This header is applied only to non-preflight requests
638 headers.extend(self.layer.expose_headers.to_header(&parts));
639
640 let req = Request::from_parts(parts, body);
641
642 let mut response: Response<ResBody> = self.inner.call(req).await?;
643 response.headers_mut().extend(headers.drain());
644
645 Ok(response)
646 }
647 }
648}
649
650fn ensure_usable_cors_rules(layer: &CorsLayer) {
651 if layer.allow_credentials.is_true() {
652 assert!(
653 !layer.allow_headers.is_wildcard(),
654 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
655 with `Access-Control-Allow-Headers: *`"
656 );
657
658 assert!(
659 !layer.allow_methods.is_wildcard(),
660 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
661 with `Access-Control-Allow-Methods: *`"
662 );
663
664 assert!(
665 !layer.allow_origin.is_wildcard(),
666 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
667 with `Access-Control-Allow-Origin: *`"
668 );
669
670 assert!(
671 !layer.expose_headers.is_wildcard(),
672 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
673 with `Access-Control-Expose-Headers: *`"
674 );
675 }
676}
677
678/// Returns an iterator over the three request headers that may be involved in a CORS preflight request.
679///
680/// This is the default set of header names returned in the `vary` header
681pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
682 #[allow(deprecated)] // Can be changed when MSRV >= 1.53
683 array::IntoIter::new([
684 header::ORIGIN,
685 header::ACCESS_CONTROL_REQUEST_METHOD,
686 header::ACCESS_CONTROL_REQUEST_HEADERS,
687 ])
688}