Skip to main content

salvo_cors/
lib.rs

//! Adds [CORS] protection to the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::allow_credentials::AllowCredentials;
26pub use self::allow_headers::AllowHeaders;
27pub use self::allow_methods::AllowMethods;
28pub use self::allow_origin::AllowOrigin;
29pub use self::allow_private_network::AllowPrivateNetwork;
30pub use self::expose_headers::ExposeHeaders;
31pub use self::max_age::MaxAge;
32pub use self::vary::Vary;
33
34static WILDCARD: HeaderValue = HeaderValue::from_static("*");
35
36/// Represents a wildcard value (`*`) used with some CORS headers such as
37/// [`Cors::allow_methods`].
38#[derive(Debug, Clone, Copy)]
39#[must_use]
40pub struct Any;
41
42fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
43where
44    I: Iterator<Item = HeaderValue>,
45{
46    match iter.next() {
47        Some(fst) => {
48            let mut result = BytesMut::from(fst.as_bytes());
49            for val in iter {
50                result.reserve(val.len() + 1);
51                result.put_u8(b',');
52                result.extend_from_slice(val.as_bytes());
53            }
54
55            HeaderValue::from_maybe_shared(result.freeze()).ok()
56        }
57        None => None,
58    }
59}
60
61/// [`Cors`] middleware which adds headers for [CORS][mdn].
62///
63/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
64#[derive(Clone, Debug)]
65pub struct Cors {
66    allow_credentials: AllowCredentials,
67    allow_headers: AllowHeaders,
68    allow_methods: AllowMethods,
69    allow_origin: AllowOrigin,
70    allow_private_network: AllowPrivateNetwork,
71    expose_headers: ExposeHeaders,
72    max_age: MaxAge,
73    vary: Vary,
74}
75impl Default for Cors {
76    #[inline]
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82impl Cors {
83    /// Create new `Cors`.
84    #[inline]
85    #[must_use]
86    pub fn new() -> Self {
87        Self {
88            allow_credentials: Default::default(),
89            allow_headers: Default::default(),
90            allow_methods: Default::default(),
91            allow_origin: Default::default(),
92            allow_private_network: Default::default(),
93            expose_headers: Default::default(),
94            max_age: Default::default(),
95            vary: Default::default(),
96        }
97    }
98
99    /// A permissive configuration:
100    ///
101    /// - All request headers allowed.
102    /// - All methods allowed.
103    /// - All origins allowed.
104    /// - All headers exposed.
105    ///
106    /// # Security Warning
107    ///
108    /// **This configuration allows any website to make requests to your API.**
109    /// Only use this for:
110    /// - Public APIs that don't require authentication
111    /// - Development/testing environments
112    ///
113    /// For production APIs that require authentication, configure CORS explicitly
114    /// with specific allowed origins.
115    #[must_use]
116    pub fn permissive() -> Self {
117        Self::new()
118            .allow_headers(Any)
119            .allow_methods(Any)
120            .allow_origin(Any)
121            .expose_headers(Any)
122    }
123
124    /// A very permissive configuration:
125    ///
126    /// - **Credentials allowed.**
127    /// - The method received in `Access-Control-Request-Method` is sent back as an allowed method.
128    /// - The origin of the preflight request is sent back as an allowed origin.
129    /// - The header names received in `Access-Control-Request-Headers` are sent back as allowed
130    ///   headers.
131    /// - No headers are currently exposed, but this may change in the future.
132    ///
133    /// # Security Warning
134    ///
135    /// **⚠️ DANGER: This configuration essentially disables CORS protection!**
136    ///
137    /// By enabling credentials AND mirroring the request origin, you are allowing
138    /// ANY website to:
139    /// - Make authenticated requests to your API
140    /// - Read response data including sensitive information
141    /// - Perform actions on behalf of logged-in users (CSRF attacks)
142    ///
143    /// **This should NEVER be used in production with authentication.**
144    ///
145    /// Only use this for:
146    /// - Local development where security is not a concern
147    /// - Internal tools on trusted networks
148    ///
149    /// For production, always configure explicit allowed origins:
150    /// ```ignore
151    /// Cors::new()
152    ///     .allow_origin("https://your-frontend.com")
153    ///     .allow_credentials(true)
154    /// ```
155    #[must_use]
156    pub fn very_permissive() -> Self {
157        tracing::warn!(
158            "Using Cors::very_permissive() - this disables CORS security and should not be used in production!"
159        );
160        Self::new()
161            .allow_credentials(true)
162            .allow_headers(AllowHeaders::mirror_request())
163            .allow_methods(AllowMethods::mirror_request())
164            .allow_origin(AllowOrigin::mirror_request())
165    }
166
167    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
168    #[inline]
169    #[must_use]
170    pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
171        self.allow_credentials = allow_credentials.into();
172        self
173    }
174
175    /// Adds multiple headers to the list of allowed request headers.
176    ///
177    /// **Note**: These should match the values the browser sends via
178    /// `Access-Control-Request-Headers`, e.g.`content-type`.
179    ///
180    /// # Panics
181    ///
182    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
183    #[inline]
184    #[must_use]
185    pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
186        self.allow_headers = headers.into();
187        self
188    }
189
190    /// Sets the `Access-Control-Max-Age` header.
191    ///
192    /// # Example
193    ///
194    /// ```
195    /// use std::time::Duration;
196    ///
197    /// use salvo_core::prelude::*;
198    /// use salvo_cors::Cors;
199    ///
200    /// let cors = Cors::new().max_age(30); // 30 seconds
201    /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
202    /// ```
203    #[inline]
204    #[must_use]
205    pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
206        self.max_age = max_age.into();
207        self
208    }
209
210    /// Adds multiple methods to the existing list of allowed request methods.
211    ///
212    /// # Panics
213    ///
214    /// Panics if the provided argument is not a valid `http::Method`.
215    #[inline]
216    #[must_use]
217    pub fn allow_methods<I>(mut self, methods: I) -> Self
218    where
219        I: Into<AllowMethods>,
220    {
221        self.allow_methods = methods.into();
222        self
223    }
224
225    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
226    /// ```
227    /// use salvo_core::http::HeaderValue;
228    /// use salvo_cors::Cors;
229    ///
230    /// let cors = Cors::new().allow_origin("http://example.com".parse::<HeaderValue>().unwrap());
231    /// ```
232    ///
233    /// Multiple origins can be allowed with
234    ///
235    /// ```
236    /// use salvo_cors::Cors;
237    ///
238    /// let origins = ["http://example.com", "http://api.example.com"];
239    ///
240    /// let cors = Cors::new().allow_origin(origins);
241    /// ```
242    ///
243    /// All origins can be allowed with
244    ///
245    /// ```
246    /// use salvo_cors::{Any, Cors};
247    ///
248    /// let cors = Cors::new().allow_origin(Any);
249    /// ```
250    ///
251    /// You can also use a closure
252    ///
253    /// ```
254    /// use salvo_core::http::HeaderValue;
255    /// use salvo_core::{Depot, Request};
256    /// use salvo_cors::{AllowOrigin, Cors};
257    ///
258    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
259    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
260    ///         if origin?.as_bytes().ends_with(b".rust-lang.org") {
261    ///             origin.cloned()
262    ///         } else {
263    ///             None
264    ///         }
265    ///     },
266    /// ));
267    /// ```
268    ///
269    /// You can also use an async closure, make sure all the values are owned
270    /// before passing into the future:
271    ///
272    /// ```
273    /// # #[derive(Clone)]
274    /// # struct Client;
275    /// # fn get_api_client() -> Client {
276    /// #     Client
277    /// # }
278    /// # impl Client {
279    /// #     async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
280    /// #         vec![HeaderValue::from_static("http://example.com")]
281    /// #     }
282    /// #     async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
283    /// #         vec![HeaderValue::from_static("http://example.com")]
284    /// #     }
285    /// # }
286    /// use salvo_core::http::header::HeaderValue;
287    /// use salvo_core::{Depot, Request};
288    /// use salvo_cors::{AllowOrigin, Cors};
289    ///
290    /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
291    ///     |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
292    ///         let origin = origin.cloned();
293    ///         async move {
294    ///             let client = get_api_client();
295    ///             // fetch list of origins that are allowed
296    ///             let origins = client.fetch_allowed_origins().await;
297    ///             if origins.contains(origin.as_ref()?) {
298    ///                 origin
299    ///             } else {
300    ///                 None
301    ///             }
302    ///         }
303    ///     },
304    /// ));
305    /// ```
306    ///
307    /// **Note** that multiple calls to this method will override any previous
308    /// calls.
309    ///
310    /// **Note** origin must contain http or https protocol name.
311    ///
312    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
313    #[inline]
314    #[must_use]
315    pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
316        self.allow_origin = origin.into();
317        self
318    }
319
320    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
321    ///
322    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
323    #[inline]
324    #[must_use]
325    pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
326        self.expose_headers = headers.into();
327        self
328    }
329
330    /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
331    ///
332    /// ```
333    /// use salvo_cors::Cors;
334    ///
335    /// let cors = Cors::new().allow_private_network(true);
336    /// ```
337    ///
338    /// [wicg]: https://wicg.github.io/private-network-access/
339    #[must_use]
340    pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
341    where
342        T: Into<AllowPrivateNetwork>,
343    {
344        self.allow_private_network = allow_private_network.into();
345        self
346    }
347
348    /// Set the value(s) of the [`Vary`][mdn] header.
349    ///
350    /// In contrast to the other headers, this one has a non-empty default of
351    /// [`preflight_request_headers()`].
352    ///
353    /// You only need to set this is you want to remove some of these defaults,
354    /// or if you use a closure for one of the other headers and want to add a
355    /// vary header accordingly.
356    ///
357    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
358    #[must_use]
359    pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
360        self.vary = headers.into();
361        self
362    }
363
364    /// Returns a new `CorsHandler` using current cors settings.
365    pub fn into_handler(self) -> CorsHandler {
366        self.ensure_usable_cors_rules();
367        CorsHandler::new(self, CallNext::default())
368    }
369
370    fn ensure_usable_cors_rules(&self) {
371        if self.allow_credentials.is_true() {
372            assert!(
373                !self.allow_headers.is_wildcard(),
374                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
375                 with `Access-Control-Allow-Headers: *`"
376            );
377
378            assert!(
379                !self.allow_methods.is_wildcard(),
380                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
381                 with `Access-Control-Allow-Methods: *`"
382            );
383
384            assert!(
385                !self.allow_origin.is_wildcard(),
386                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
387                 with `Access-Control-Allow-Origin: *`"
388            );
389
390            assert!(
391                !self.expose_headers.is_wildcard(),
392                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
393                 with `Access-Control-Expose-Headers: *`"
394            );
395        }
396    }
397}
398
/// Chooses when [`CorsHandler`] hands control to the rest of the handler chain,
/// relative to writing its CORS headers into the response.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum CallNext {
    /// Run the remaining handlers first, then write the CORS headers (the default).
    #[default]
    Before,
    /// Write the CORS headers first, then run the remaining handlers.
    After,
}
409
410/// CorsHandler
411#[derive(Clone, Debug)]
412pub struct CorsHandler {
413    cors: Cors,
414    call_next: CallNext,
415}
416impl CorsHandler {
417    /// Create a new `CorsHandler`.
418    pub fn new(cors: Cors, call_next: CallNext) -> Self {
419        Self { cors, call_next }
420    }
421}
422
423#[async_trait]
424impl Handler for CorsHandler {
425    async fn handle(
426        &self,
427        req: &mut Request,
428        depot: &mut Depot,
429        res: &mut Response,
430        ctrl: &mut FlowCtrl,
431    ) {
432        if self.call_next == CallNext::Before {
433            ctrl.call_next(req, depot, res).await;
434        }
435
436        let origin = req.headers().get(&header::ORIGIN);
437        let mut headers = HeaderMap::new();
438
439        // These headers are applied to both preflight and subsequent regular CORS requests:
440        // https://fetch.spec.whatwg.org/#http-responses
441        headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
442        headers.extend(
443            self.cors
444                .allow_credentials
445                .to_header(origin, req, depot)
446                .await,
447        );
448        headers.extend(
449            self.cors
450                .allow_private_network
451                .to_header(origin, req, depot)
452                .await,
453        );
454
455        let mut vary_headers = self.cors.vary.values();
456        if let Some(first) = vary_headers.next() {
457            let mut header = match headers.entry(header::VARY) {
458                header::Entry::Occupied(_) => {
459                    unreachable!("no vary header inserted up to this point")
460                }
461                header::Entry::Vacant(v) => v.insert_entry(first),
462            };
463
464            for val in vary_headers {
465                header.append(val);
466            }
467        }
468
469        // Return results immediately upon preflight request
470        if req.method() == Method::OPTIONS {
471            // These headers are applied only to preflight requests
472            headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
473            headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
474            headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
475            res.status_code = Some(StatusCode::NO_CONTENT);
476        } else {
477            // This header is applied only to non-preflight requests
478            headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
479        }
480        res.headers_mut().extend(headers);
481
482        if self.call_next == CallNext::After {
483            ctrl.call_next(req, depot, res).await;
484        }
485    }
486}
487
488/// Iterator over the three request headers that may be involved in a CORS preflight request.
489///
490/// This is the default set of header names returned in the `vary` header
491pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
492    [
493        header::ORIGIN,
494        header::ACCESS_CONTROL_REQUEST_METHOD,
495        header::ACCESS_CONTROL_REQUEST_HEADERS,
496    ]
497    .into_iter()
498}
499
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        /// Sends a preflight-style `OPTIONS /hello` request from `origin`.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // An OPTIONS request without an `Origin` header gets no allowed methods.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none());

        // A preflight from the allowed origin gets both preflight headers.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some());
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some());

        // NOTE(review): this request carries no `Origin` header either, so it does not
        // exercise a disallowed origin such as `https://google.com`.
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "`Access-Control-Allow-Methods` must not be sent without an allowed origin"
        );
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none());
    }
}