salvo_cors/
lib.rs

//! This library adds [CORS] protection to the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::allow_credentials::AllowCredentials;
26pub use self::allow_headers::AllowHeaders;
27pub use self::allow_methods::AllowMethods;
28pub use self::allow_origin::AllowOrigin;
29pub use self::allow_private_network::AllowPrivateNetwork;
30pub use self::expose_headers::ExposeHeaders;
31pub use self::max_age::MaxAge;
32pub use self::vary::Vary;
33
// Shared `*` header value. It is not referenced in this file; presumably the header-specific
// submodules use it when a setting is a wildcard (e.g. built from `Any`) — verify.
static WILDCARD: HeaderValue = HeaderValue::from_static("*");
35
/// Represents a wildcard value (`*`) used with some CORS headers such as
/// [`Cors::allow_methods`].
///
/// Pass it to a builder method to allow everything for that header, for example
/// `Cors::new().allow_origin(Any)` or `Cors::new().allow_headers(Any)`.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct Any;
41
42fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
43where
44    I: Iterator<Item = HeaderValue>,
45{
46    match iter.next() {
47        Some(fst) => {
48            let mut result = BytesMut::from(fst.as_bytes());
49            for val in iter {
50                result.reserve(val.len() + 1);
51                result.put_u8(b',');
52                result.extend_from_slice(val.as_bytes());
53            }
54
55            HeaderValue::from_maybe_shared(result.freeze()).ok()
56        }
57        None => None,
58    }
59}
60
/// [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// Build a configuration with [`Cors::new`], [`Cors::permissive`] or
/// [`Cors::very_permissive`], adjust it with the builder methods, and turn it into a
/// handler with [`Cors::into_handler`].
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Clone, Debug)]
pub struct Cors {
    /// Setting for the `Access-Control-Allow-Credentials` header.
    allow_credentials: AllowCredentials,
    /// Setting for the `Access-Control-Allow-Headers` header (preflight only).
    allow_headers: AllowHeaders,
    /// Setting for the `Access-Control-Allow-Methods` header (preflight only).
    allow_methods: AllowMethods,
    /// Setting for the `Access-Control-Allow-Origin` header.
    allow_origin: AllowOrigin,
    /// Setting for the `Access-Control-Allow-Private-Network` header.
    allow_private_network: AllowPrivateNetwork,
    /// Setting for the `Access-Control-Expose-Headers` header (non-preflight only).
    expose_headers: ExposeHeaders,
    /// Setting for the `Access-Control-Max-Age` header (preflight only).
    max_age: MaxAge,
    /// Values written to the `Vary` header.
    vary: Vary,
}
75impl Default for Cors {
76    #[inline]
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl Cors {
83    /// Create new `Cors`.
84    #[inline]
85    #[must_use]
86    pub fn new() -> Self {
87        Self {
88            allow_credentials: Default::default(),
89            allow_headers: Default::default(),
90            allow_methods: Default::default(),
91            allow_origin: Default::default(),
92            allow_private_network: Default::default(),
93            expose_headers: Default::default(),
94            max_age: Default::default(),
95            vary: Default::default(),
96        }
97    }
98
99    /// A permissive configuration:
100    ///
101    /// - All request headers allowed.
102    /// - All methods allowed.
103    /// - All origins allowed.
104    /// - All headers exposed.
105    #[must_use]
106    pub fn permissive() -> Self {
107        Self::new()
108            .allow_headers(Any)
109            .allow_methods(Any)
110            .allow_origin(Any)
111            .expose_headers(Any)
112    }
113
114    /// A very permissive configuration:
115    ///
116    /// - **Credentials allowed.**
117    /// - The method received in `Access-Control-Request-Method` is sent back as an allowed method.
118    /// - The origin of the preflight request is sent back as an allowed origin.
119    /// - The header names received in `Access-Control-Request-Headers` are sent back as allowed
120    ///   headers.
121    /// - No headers are currently exposed, but this may change in the future.
122    #[must_use]
123    pub fn very_permissive() -> Self {
124        Self::new()
125            .allow_credentials(true)
126            .allow_headers(AllowHeaders::mirror_request())
127            .allow_methods(AllowMethods::mirror_request())
128            .allow_origin(AllowOrigin::mirror_request())
129    }
130
131    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
132    #[inline]
133    #[must_use]
134    pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
135        self.allow_credentials = allow_credentials.into();
136        self
137    }
138
139    /// Adds multiple headers to the list of allowed request headers.
140    ///
141    /// **Note**: These should match the values the browser sends via
142    /// `Access-Control-Request-Headers`, e.g.`content-type`.
143    ///
144    /// # Panics
145    ///
146    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
147    #[inline]
148    #[must_use]
149    pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
150        self.allow_headers = headers.into();
151        self
152    }
153
154    /// Sets the `Access-Control-Max-Age` header.
155    ///
156    /// # Example
157    ///
158    /// ```
159    /// use std::time::Duration;
160    ///
161    /// use salvo_core::prelude::*;
162    /// use salvo_cors::Cors;
163    ///
164    /// let cors = Cors::new().max_age(30); // 30 seconds
165    /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
166    /// ```
167    #[inline]
168    #[must_use]
169    pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
170        self.max_age = max_age.into();
171        self
172    }
173
174    /// Adds multiple methods to the existing list of allowed request methods.
175    ///
176    /// # Panics
177    ///
178    /// Panics if the provided argument is not a valid `http::Method`.
179    #[inline]
180    #[must_use]
181    pub fn allow_methods<I>(mut self, methods: I) -> Self
182    where
183        I: Into<AllowMethods>,
184    {
185        self.allow_methods = methods.into();
186        self
187    }
188
189    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
190    /// ```
191    /// use salvo_core::http::HeaderValue;
192    /// use salvo_cors::Cors;
193    ///
194    /// let cors = Cors::new().allow_origin("http://example.com".parse::<HeaderValue>().unwrap());
195    /// ```
196    ///
197    /// Multiple origins can be allowed with
198    ///
199    /// ```
200    /// use salvo_cors::Cors;
201    ///
202    /// let origins = ["http://example.com", "http://api.example.com"];
203    ///
204    /// let cors = Cors::new().allow_origin(origins);
205    /// ```
206    ///
207    /// All origins can be allowed with
208    ///
209    /// ```
210    /// use salvo_cors::{Any, Cors};
211    ///
212    /// let cors = Cors::new().allow_origin(Any);
213    /// ```
214    ///
215    /// You can also use a closure
216    ///
217    /// ```
218    /// use salvo_core::http::HeaderValue;
219    /// use salvo_core::{Depot, Request};
220    /// use salvo_cors::{AllowOrigin, Cors};
221    ///
222    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
223    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
224    ///         if origin?.as_bytes().ends_with(b".rust-lang.org") {
225    ///             origin.cloned()
226    ///         } else {
227    ///             None
228    ///         }
229    ///     },
230    /// ));
231    /// ```
232    ///
233    /// You can also use an async closure, make sure all the values are owned
234    /// before passing into the future:
235    ///
236    /// ```
237    /// # #[derive(Clone)]
238    /// # struct Client;
239    /// # fn get_api_client() -> Client {
240    /// #     Client
241    /// # }
242    /// # impl Client {
243    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
244    /// #         vec![HeaderValue::from_static("http://example.com")]
245    /// #     }
246    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
247    /// #         vec![HeaderValue::from_static("http://example.com")]
248    /// #     }
249    /// # }
250    /// use salvo_core::http::header::HeaderValue;
251    /// use salvo_core::{Depot, Request};
252    /// use salvo_cors::{AllowOrigin, Cors};
253    ///
254    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
255    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
256    ///         let origin = origin.cloned();
257    ///         async move {
258    ///             let client = get_api_client();
259    ///             // fetch list of origins that are allowed
260    ///             let origins = client.fetch_allowed_origins().await;
261    ///             if origins.contains(origin.as_ref()?) {
262    ///                 origin
263    ///             } else {
264    ///                 None
265    ///             }
266    ///         }
267    ///     },
268    /// ));
269    /// ```
270    ///
271    /// **Note** that multiple calls to this method will override any previous
272    /// calls.
273    ///
274    /// **Note** origin must contain http or https protocol name.
275    ///
276    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
277    #[inline]
278    #[must_use]
279    pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
280        self.allow_origin = origin.into();
281        self
282    }
283
284    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
285    ///
286    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
287    #[inline]
288    #[must_use]
289    pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
290        self.expose_headers = headers.into();
291        self
292    }
293
294    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
295    ///
296    /// ```
297    /// use salvo_cors::Cors;
298    ///
299    /// let cors = Cors::new().allow_private_network(true);
300    /// ```
301    ///
302    /// [wicg]: https://wicg.github.io/private-network-access/
303    #[must_use]
304    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
305    where
306        T: Into<AllowPrivateNetwork>,
307    {
308        self.allow_private_network = allow_private_network.into();
309        self
310    }
311
312    /// Set the value(s) of the [`Vary`][mdn] header.
313    ///
314    /// In contrast to the other headers, this one has a non-empty default of
315    /// [`preflight_request_headers()`].
316    ///
317    /// You only need to set this is you want to remove some of these defaults,
318    /// or if you use a closure for one of the other headers and want to add a
319    /// vary header accordingly.
320    ///
321    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
322    #[must_use]
323    pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
324        self.vary = headers.into();
325        self
326    }
327
328    /// Returns a new `CorsHandler` using current cors settings.
329    pub fn into_handler(self) -> CorsHandler {
330        self.ensure_usable_cors_rules();
331        CorsHandler::new(self, CallNext::default())
332    }
333
334    fn ensure_usable_cors_rules(&self) {
335        if self.allow_credentials.is_true() {
336            assert!(
337                !self.allow_headers.is_wildcard(),
338                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
339                 with `Access-Control-Allow-Headers: *`"
340            );
341
342            assert!(
343                !self.allow_methods.is_wildcard(),
344                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
345                 with `Access-Control-Allow-Methods: *`"
346            );
347
348            assert!(
349                !self.allow_origin.is_wildcard(),
350                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
351                 with `Access-Control-Allow-Origin: *`"
352            );
353
354            assert!(
355                !self.expose_headers.is_wildcard(),
356                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
357                 with `Access-Control-Expose-Headers: *`"
358            );
359        }
360    }
361}
362
/// Controls when [`CorsHandler`] invokes the remaining handlers in the chain.
#[non_exhaustive]
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
pub enum CallNext {
    /// Run the remaining handlers first, then write the CORS headers into the response.
    Before,
    /// Write the CORS headers first, then run the remaining handlers.
    After,
}

impl Default for CallNext {
    /// The default ordering is [`CallNext::Before`].
    fn default() -> Self {
        Self::Before
    }
}
373
/// Handler that applies a [`Cors`] configuration to requests and responses.
///
/// Usually obtained from [`Cors::into_handler`], which validates the configuration first;
/// [`CorsHandler::new`] builds one directly without that validation.
#[derive(Clone, Debug)]
pub struct CorsHandler {
    // The CORS settings used to compute the response headers.
    cors: Cors,
    // Whether the remaining handlers run before or after the CORS headers are written.
    call_next: CallNext,
}
380impl CorsHandler {
381    /// Create a new `CorsHandler`.
382    pub fn new(cors: Cors, call_next: CallNext) -> Self {
383        Self { cors, call_next }
384    }
385}
386
387#[async_trait]
388impl Handler for CorsHandler {
389    async fn handle(
390        &self,
391        req: &mut Request,
392        depot: &mut Depot,
393        res: &mut Response,
394        ctrl: &mut FlowCtrl,
395    ) {
396        if self.call_next == CallNext::Before {
397            ctrl.call_next(req, depot, res).await;
398        }
399
400        let origin = req.headers().get(&header::ORIGIN);
401        let mut headers = HeaderMap::new();
402
403        // These headers are applied to both preflight and subsequent regular CORS requests:
404        // https://fetch.spec.whatwg.org/#http-responses
405        headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
406        headers.extend(
407            self.cors
408                .allow_credentials
409                .to_header(origin, req, depot)
410                .await,
411        );
412        headers.extend(
413            self.cors
414                .allow_private_network
415                .to_header(origin, req, depot)
416                .await,
417        );
418
419        let mut vary_headers = self.cors.vary.values();
420        if let Some(first) = vary_headers.next() {
421            let mut header = match headers.entry(header::VARY) {
422                header::Entry::Occupied(_) => {
423                    unreachable!("no vary header inserted up to this point")
424                }
425                header::Entry::Vacant(v) => v.insert_entry(first),
426            };
427
428            for val in vary_headers {
429                header.append(val);
430            }
431        }
432
433        // Return results immediately upon preflight request
434        if req.method() == Method::OPTIONS {
435            // These headers are applied only to preflight requests
436            headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
437            headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
438            headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
439            res.status_code = Some(StatusCode::NO_CONTENT);
440        } else {
441            // This header is applied only to non-preflight requests
442            headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
443        }
444        res.headers_mut().extend(headers);
445
446        if self.call_next == CallNext::After {
447            ctrl.call_next(req, depot, res).await;
448        }
449    }
450}
451
452/// Iterator over the three request headers that may be involved in a CORS preflight request.
453///
454/// This is the default set of header names returned in the `vary` header
455pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
456    [
457        header::ORIGIN,
458        header::ACCESS_CONTROL_REQUEST_METHOD,
459        header::ACCESS_CONTROL_REQUEST_HEADERS,
460    ]
461    .into_iter()
462}
463
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        // Sends a preflight request for `/hello` from the given origin.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // `/` does not match the `hello` route, so no CORS headers are expected.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(
            res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "request to an unrouted path must not receive Access-Control-Allow-Methods"
        );

        // A routed preflight request receives the preflight-only headers and the default Vary.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some(),
            "preflight response is missing Access-Control-Allow-Methods"
        );
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some(),
            "preflight response is missing Access-Control-Allow-Headers"
        );
        assert!(
            headers.contains_key(VARY),
            "preflight response is missing the default Vary header"
        );

        // Again an unrouted path (`/`), so no CORS headers are expected.
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "request to an unrouted path must not receive Access-Control-Allow-Methods"
        );
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none(),
            "request to an unrouted path must not receive Access-Control-Allow-Headers"
        );
    }
}