salvo_cors/
lib.rs

//! Library adds [CORS] protection for the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::{
26    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
27    allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
28    expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
29};
30
/// The wildcard (`*`) header value, shared as a static so it is never re-created.
///
/// NOTE(review): not referenced in this file itself; presumably the header submodules
/// (e.g. `allow_headers`, `allow_methods`, `expose_headers`) use it when configured with
/// [`Any`] — confirm.
static WILDCARD: HeaderValue = HeaderValue::from_static("*");
32
/// Marker standing for the wildcard value (`*`) that several CORS headers accept,
/// for example through [`Cors::allow_methods`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;
38
39fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
40where
41    I: Iterator<Item = HeaderValue>,
42{
43    match iter.next() {
44        Some(fst) => {
45            let mut result = BytesMut::from(fst.as_bytes());
46            for val in iter {
47                result.reserve(val.len() + 1);
48                result.put_u8(b',');
49                result.extend_from_slice(val.as_bytes());
50            }
51
52            HeaderValue::from_maybe_shared(result.freeze()).ok()
53        }
54        None => None,
55    }
56}
57
/// [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// Start from [`Cors::new`] (or [`Cors::permissive`] / [`Cors::very_permissive`]), adjust it
/// with the builder methods, then turn it into a handler with [`Cors::into_handler`].
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Clone, Debug)]
pub struct Cors {
    /// Source of `Access-Control-Allow-Credentials`; consulted on every request.
    allow_credentials: AllowCredentials,
    /// Source of `Access-Control-Allow-Headers`; consulted on preflight (`OPTIONS`) requests only.
    allow_headers: AllowHeaders,
    /// Source of `Access-Control-Allow-Methods`; consulted on preflight (`OPTIONS`) requests only.
    allow_methods: AllowMethods,
    /// Source of `Access-Control-Allow-Origin`; consulted on every request.
    allow_origin: AllowOrigin,
    /// Source of `Access-Control-Allow-Private-Network`; consulted on every request.
    allow_private_network: AllowPrivateNetwork,
    /// Source of `Access-Control-Expose-Headers`; consulted on non-`OPTIONS` requests only.
    expose_headers: ExposeHeaders,
    /// Source of `Access-Control-Max-Age`; consulted on preflight (`OPTIONS`) requests only.
    max_age: MaxAge,
    /// Values emitted in the `Vary` header on every request.
    vary: Vary,
}
72impl Default for Cors {
73    #[inline]
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl Cors {
80    /// Create new `Cors`.
81    #[inline]
82    #[must_use]
83    pub fn new() -> Self {
84        Self {
85            allow_credentials: Default::default(),
86            allow_headers: Default::default(),
87            allow_methods: Default::default(),
88            allow_origin: Default::default(),
89            allow_private_network: Default::default(),
90            expose_headers: Default::default(),
91            max_age: Default::default(),
92            vary: Default::default(),
93        }
94    }
95
96    /// A permissive configuration:
97    ///
98    /// - All request headers allowed.
99    /// - All methods allowed.
100    /// - All origins allowed.
101    /// - All headers exposed.
102    #[must_use]
103    pub fn permissive() -> Self {
104        Self::new()
105            .allow_headers(Any)
106            .allow_methods(Any)
107            .allow_origin(Any)
108            .expose_headers(Any)
109    }
110
111    /// A very permissive configuration:
112    ///
113    /// - **Credentials allowed.**
114    /// - The method received in `Access-Control-Request-Method` is sent back
115    ///   as an allowed method.
116    /// - The origin of the preflight request is sent back as an allowed origin.
117    /// - The header names received in `Access-Control-Request-Headers` are sent
118    ///   back as allowed headers.
119    /// - No headers are currently exposed, but this may change in the future.
120    #[must_use]
121    pub fn very_permissive() -> Self {
122        Self::new()
123            .allow_credentials(true)
124            .allow_headers(AllowHeaders::mirror_request())
125            .allow_methods(AllowMethods::mirror_request())
126            .allow_origin(AllowOrigin::mirror_request())
127    }
128
129    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
130    #[inline]
131    #[must_use]
132    pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
133        self.allow_credentials = allow_credentials.into();
134        self
135    }
136
137    /// Adds multiple headers to the list of allowed request headers.
138    ///
139    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
140    ///
141    /// # Panics
142    ///
143    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
144    #[inline]
145    #[must_use]
146    pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
147        self.allow_headers = headers.into();
148        self
149    }
150
151    /// Sets the `Access-Control-Max-Age` header.
152    ///
153    /// # Example
154    ///
155    /// ```
156    /// use std::time::Duration;
157    /// use salvo_core::prelude::*;
158    /// use salvo_cors::Cors;
159    ///
160    /// let cors = Cors::new().max_age(30); // 30 seconds
161    /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
162    /// ```
163    #[inline]
164    #[must_use]
165    pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
166        self.max_age = max_age.into();
167        self
168    }
169
170    /// Adds multiple methods to the existing list of allowed request methods.
171    ///
172    /// # Panics
173    ///
174    /// Panics if the provided argument is not a valid `http::Method`.
175    #[inline]
176    #[must_use]
177    pub fn allow_methods<I>(mut self, methods: I) -> Self
178    where
179        I: Into<AllowMethods>,
180    {
181        self.allow_methods = methods.into();
182        self
183    }
184
185    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
186    /// ```
187    /// use salvo_core::http::HeaderValue;
188    /// use salvo_cors::Cors;
189    ///
190    /// let cors = Cors::new().allow_origin(
191    ///     "http://example.com".parse::<HeaderValue>().unwrap(),
192    /// );
193    /// ```
194    ///
195    /// Multiple origins can be allowed with
196    ///
197    /// ```
198    /// use salvo_cors::Cors;
199    ///
200    /// let origins = ["http://example.com", "http://api.example.com"];
201    ///
202    /// let cors = Cors::new().allow_origin(origins);
203    /// ```
204    ///
205    /// All origins can be allowed with
206    ///
207    /// ```
208    /// use salvo_cors::{Any, Cors};
209    ///
210    /// let cors = Cors::new().allow_origin(Any);
211    /// ```
212    ///
213    /// You can also use a closure
214    ///
215    /// ```
216    /// use salvo_cors::{Cors, AllowOrigin};
217    /// use salvo_core::http::HeaderValue;
218    /// use salvo_core::{Depot, Request};
219    ///
220    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
221    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
222    ///         if origin?.as_bytes().ends_with(b".rust-lang.org") {
223    ///             origin.cloned()
224    ///         } else {
225    ///             None
226    ///         }
227    ///     },
228    /// ));
229    /// ```
230    ///
231    /// You can also use an async closure, make sure all the values are owned
232    /// before passing into the future:
233    ///
234    /// ```
235    /// # #[derive(Clone)]
236    /// # struct Client;
237    /// # fn get_api_client() -> Client {
238    /// #     Client
239    /// # }
240    /// # impl Client {
241    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
242    /// #         vec![HeaderValue::from_static("http://example.com")]
243    /// #     }
244    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
245    /// #         vec![HeaderValue::from_static("http://example.com")]
246    /// #     }
247    /// # }
248    /// use salvo_cors::{Cors, AllowOrigin};
249    /// use salvo_core::http::header::HeaderValue;
250    /// use salvo_core::{Depot, Request};
251    ///
252    ///
253    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
254    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
255    ///         let origin = origin.cloned();
256    ///         async move {
257    ///             let client = get_api_client();
258    ///             // fetch list of origins that are allowed
259    ///             let origins = client.fetch_allowed_origins().await;
260    ///             if origins.contains(origin.as_ref()?) {
261    ///                 origin
262    ///             } else {
263    ///                 None
264    ///             }
265    ///         }
266    ///     },
267    /// ));
268    /// ```
269    ///
270    /// **Note** that multiple calls to this method will override any previous
271    /// calls.
272    ///
273    /// **Note** origin must contain http or https protocol name.
274    ///
275    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
276    #[inline]
277    #[must_use]
278    pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
279        self.allow_origin = origin.into();
280        self
281    }
282
283    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
284    ///
285    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
286    #[inline]
287    #[must_use]
288    pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
289        self.expose_headers = headers.into();
290        self
291    }
292
293    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
294    ///
295    /// ```
296    /// use salvo_cors::Cors;
297    ///
298    /// let cors = Cors::new().allow_private_network(true);
299    /// ```
300    ///
301    /// [wicg]: https://wicg.github.io/private-network-access/
302    #[must_use]
303    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
304    where
305        T: Into<AllowPrivateNetwork>,
306    {
307        self.allow_private_network = allow_private_network.into();
308        self
309    }
310
311    /// Set the value(s) of the [`Vary`][mdn] header.
312    ///
313    /// In contrast to the other headers, this one has a non-empty default of
314    /// [`preflight_request_headers()`].
315    ///
316    /// You only need to set this is you want to remove some of these defaults,
317    /// or if you use a closure for one of the other headers and want to add a
318    /// vary header accordingly.
319    ///
320    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
321    #[must_use]
322    pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
323        self.vary = headers.into();
324        self
325    }
326
327    /// Returns a new `CorsHandler` using current cors settings.
328    pub fn into_handler(self) -> CorsHandler {
329        self.ensure_usable_cors_rules();
330        CorsHandler::new(self, CallNext::default())
331    }
332
333    fn ensure_usable_cors_rules(&self) {
334        if self.allow_credentials.is_true() {
335            assert!(
336                !self.allow_headers.is_wildcard(),
337                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
338                 with `Access-Control-Allow-Headers: *`"
339            );
340
341            assert!(
342                !self.allow_methods.is_wildcard(),
343                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
344                 with `Access-Control-Allow-Methods: *`"
345            );
346
347            assert!(
348                !self.allow_origin.is_wildcard(),
349                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
350                 with `Access-Control-Allow-Origin: *`"
351            );
352
353            assert!(
354                !self.expose_headers.is_wildcard(),
355                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
356                 with `Access-Control-Expose-Headers: *`"
357            );
358        }
359    }
360}
361
/// Chooses when [`CorsHandler`] hands control to the rest of the handler chain.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum CallNext {
    /// Run the remaining handlers first, then let [`CorsHandler`] write its headers to the
    /// response. This is the default.
    #[default]
    Before,
    /// Let [`CorsHandler`] write its headers to the response first, then run the remaining
    /// handlers.
    After,
}
372
/// Salvo [`Handler`] that writes the CORS headers described by a [`Cors`] configuration.
///
/// Create it with [`Cors::into_handler`], which validates the configuration first, or with
/// [`CorsHandler::new`], which does not.
#[derive(Clone, Debug)]
pub struct CorsHandler {
    /// The header configuration applied to each request.
    cors: Cors,
    /// Whether the rest of the handler chain runs before or after the headers are written.
    call_next: CallNext,
}
379impl CorsHandler {
380    /// Create a new `CorsHandler`.
381    pub fn new(cors: Cors, call_next: CallNext) -> Self {
382        Self { cors, call_next }
383    }
384}
385
386#[async_trait]
387impl Handler for CorsHandler {
388    async fn handle(
389        &self,
390        req: &mut Request,
391        depot: &mut Depot,
392        res: &mut Response,
393        ctrl: &mut FlowCtrl,
394    ) {
395        if self.call_next == CallNext::Before {
396            ctrl.call_next(req, depot, res).await;
397        }
398
399        let origin = req.headers().get(&header::ORIGIN);
400        let mut headers = HeaderMap::new();
401
402        // These headers are applied to both preflight and subsequent regular CORS requests:
403        // https://fetch.spec.whatwg.org/#http-responses
404        headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
405        headers.extend(
406            self.cors
407                .allow_credentials
408                .to_header(origin, req, depot)
409                .await,
410        );
411        headers.extend(
412            self.cors
413                .allow_private_network
414                .to_header(origin, req, depot)
415                .await,
416        );
417
418        let mut vary_headers = self.cors.vary.values();
419        if let Some(first) = vary_headers.next() {
420            let mut header = match headers.entry(header::VARY) {
421                header::Entry::Occupied(_) => {
422                    unreachable!("no vary header inserted up to this point")
423                }
424                header::Entry::Vacant(v) => v.insert_entry(first),
425            };
426
427            for val in vary_headers {
428                header.append(val);
429            }
430        }
431
432        // Return results immediately upon preflight request
433        if req.method() == Method::OPTIONS {
434            // These headers are applied only to preflight requests
435            headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
436            headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
437            headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
438            res.status_code = Some(StatusCode::NO_CONTENT);
439        } else {
440            // This header is applied only to non-preflight requests
441            headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
442        }
443        res.headers_mut().extend(headers);
444
445        if self.call_next == CallNext::After {
446            ctrl.call_next(req, depot, res).await;
447        }
448    }
449}
450
451/// Iterator over the three request headers that may be involved in a CORS preflight request.
452///
453/// This is the default set of header names returned in the `vary` header
454pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
455    [
456        header::ORIGIN,
457        header::ACCESS_CONTROL_REQUEST_METHOD,
458        header::ACCESS_CONTROL_REQUEST_HEADERS,
459    ]
460    .into_iter()
461}
462
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        // Sends a CORS preflight request for `POST /hello` coming from `origin`.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // A bare `OPTIONS /` request: no `Origin` or `Access-Control-Request-*` headers are
        // added, and `/` has no goal in this router.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(
            res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "bare `OPTIONS /` request must not receive `Access-Control-Allow-Methods`"
        );

        // A real preflight request for `/hello` from the allowed origin.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some());
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some());

        // Same request shape as the first check with a different URL host. No `Origin`
        // header is added here either, so this is NOT a preflight from `https://google.com`.
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "bare `OPTIONS /` request must not receive `Access-Control-Allow-Methods`"
        );
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none(),
            "bare `OPTIONS /` request must not receive `Access-Control-Allow-Headers`"
        );
    }
}