salvo_cors/
lib.rs

//! Library adds [CORS] protection for the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod expose_headers;
21mod max_age;
22mod vary;
23
24pub use self::{
25    allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
26    allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
27};
28
/// Shared `*` header value.
///
/// NOTE(review): not referenced in this file; presumably used by the submodules
/// (e.g. `allow_headers`, `allow_methods`) when a setting is a wildcard — confirm.
static WILDCARD: HeaderValue = HeaderValue::from_static("*");

/// Represents a wildcard value (`*`) used with some CORS headers such as
/// [`Cors::allow_methods`].
///
/// Pass it to a builder method (e.g. `Cors::new().allow_origin(Any)`) to allow
/// every value for that setting.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct Any;
36
37fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
38where
39    I: Iterator<Item = HeaderValue>,
40{
41    match iter.next() {
42        Some(fst) => {
43            let mut result = BytesMut::from(fst.as_bytes());
44            for val in iter {
45                result.reserve(val.len() + 1);
46                result.put_u8(b',');
47                result.extend_from_slice(val.as_bytes());
48            }
49
50            HeaderValue::from_maybe_shared(result.freeze()).ok()
51        }
52        None => None,
53    }
54}
55
/// [`Cors`] middleware which adds headers for [CORS][mdn].
///
/// Start from [`Cors::new`] (or the [`Cors::permissive`] /
/// [`Cors::very_permissive`] presets), adjust it with the builder methods, and
/// turn it into a handler with [`Cors::into_handler`].
///
/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
#[derive(Clone, Debug)]
pub struct Cors {
    /// Setting for the `Access-Control-Allow-Credentials` header.
    allow_credentials: AllowCredentials,
    /// Setting for the `Access-Control-Allow-Headers` header (preflight only).
    allow_headers: AllowHeaders,
    /// Setting for the `Access-Control-Allow-Methods` header (preflight only).
    allow_methods: AllowMethods,
    /// Setting for the `Access-Control-Allow-Origin` header.
    allow_origin: AllowOrigin,
    /// Setting for the `Access-Control-Expose-Headers` header (non-preflight only).
    expose_headers: ExposeHeaders,
    /// Setting for the `Access-Control-Max-Age` header (preflight only).
    max_age: MaxAge,
    /// Values emitted in the `Vary` header.
    vary: Vary,
}
69impl Default for Cors {
70    #[inline]
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl Cors {
77    /// Create new `Cors`.
78    #[inline]
79    pub fn new() -> Self {
80        Cors {
81            allow_credentials: Default::default(),
82            allow_headers: Default::default(),
83            allow_methods: Default::default(),
84            allow_origin: Default::default(),
85            expose_headers: Default::default(),
86            max_age: Default::default(),
87            vary: Default::default(),
88        }
89    }
90
91    /// A permissive configuration:
92    ///
93    /// - All request headers allowed.
94    /// - All methods allowed.
95    /// - All origins allowed.
96    /// - All headers exposed.
97    pub fn permissive() -> Self {
98        Self::new()
99            .allow_headers(Any)
100            .allow_methods(Any)
101            .allow_origin(Any)
102            .expose_headers(Any)
103    }
104
105    /// A very permissive configuration:
106    ///
107    /// - **Credentials allowed.**
108    /// - The method received in `Access-Control-Request-Method` is sent back
109    ///   as an allowed method.
110    /// - The origin of the preflight request is sent back as an allowed origin.
111    /// - The header names received in `Access-Control-Request-Headers` are sent
112    ///   back as allowed headers.
113    /// - No headers are currently exposed, but this may change in the future.
114    pub fn very_permissive() -> Self {
115        Self::new()
116            .allow_credentials(true)
117            .allow_headers(AllowHeaders::mirror_request())
118            .allow_methods(AllowMethods::mirror_request())
119            .allow_origin(AllowOrigin::mirror_request())
120    }
121
122    /// Sets whether to add the `Access-Control-Allow-Credentials` header.
123    #[inline]
124    pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
125        self.allow_credentials = allow_credentials.into();
126        self
127    }
128
129    /// Adds multiple headers to the list of allowed request headers.
130    ///
131    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
132    ///
133    /// # Panics
134    ///
135    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
136    #[inline]
137    pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
138        self.allow_headers = headers.into();
139        self
140    }
141
142    /// Sets the `Access-Control-Max-Age` header.
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// use std::time::Duration;
148    /// use salvo_core::prelude::*;
149    /// use salvo_cors::Cors;
150    ///
151    /// let cors = Cors::new().max_age(30); // 30 seconds
152    /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
153    /// ```
154    #[inline]
155    pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
156        self.max_age = max_age.into();
157        self
158    }
159
160    /// Adds multiple methods to the existing list of allowed request methods.
161    ///
162    /// # Panics
163    ///
164    /// Panics if the provided argument is not a valid `http::Method`.
165    #[inline]
166    pub fn allow_methods<I>(mut self, methods: I) -> Self
167    where
168        I: Into<AllowMethods>,
169    {
170        self.allow_methods = methods.into();
171        self
172    }
173
174    /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
175    ///
176    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
177    #[inline]
178    pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
179        self.allow_origin = origin.into();
180        self
181    }
182
183    /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
184    ///
185    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
186    #[inline]
187    pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
188        self.expose_headers = headers.into();
189        self
190    }
191
192    /// Set the value(s) of the [`Vary`][mdn] header.
193    ///
194    /// In contrast to the other headers, this one has a non-empty default of
195    /// [`preflight_request_headers()`].
196    ///
197    /// You only need to set this is you want to remove some of these defaults,
198    /// or if you use a closure for one of the other headers and want to add a
199    /// vary header accordingly.
200    ///
201    /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
202    pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
203        self.vary = headers.into();
204        self
205    }
206
207    /// Returns a new `CorsHandler` using current cors settings.
208    pub fn into_handler(self) -> CorsHandler {
209        self.ensure_usable_cors_rules();
210        CorsHandler::new(self, CallNext::default())
211    }
212
213    fn ensure_usable_cors_rules(&self) {
214        if self.allow_credentials.is_true() {
215            assert!(
216                !self.allow_headers.is_wildcard(),
217                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
218                 with `Access-Control-Allow-Headers: *`"
219            );
220
221            assert!(
222                !self.allow_methods.is_wildcard(),
223                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
224                 with `Access-Control-Allow-Methods: *`"
225            );
226
227            assert!(
228                !self.allow_origin.is_wildcard(),
229                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
230                 with `Access-Control-Allow-Origin: *`"
231            );
232
233            assert!(
234                !self.expose_headers.is_wildcard(),
235                "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
236                 with `Access-Control-Expose-Headers: *`"
237            );
238        }
239    }
240}
241
/// Controls when [`CorsHandler`] calls the next handlers in the chain,
/// relative to writing its CORS headers into the response.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug)]
pub enum CallNext {
    /// Call the next handlers before [`CorsHandler`] writes its headers to the
    /// response. This is the default.
    #[default]
    Before,
    /// Call the next handlers after [`CorsHandler`] has written its headers to
    /// the response.
    After,
}
252
253/// CorsHandler
254#[derive(Clone, Debug)]
255pub struct CorsHandler {
256    cors: Cors,
257    call_next: CallNext,
258}
259impl CorsHandler {
260    /// Create a new `CorsHandler`.
261    pub fn new(cors: Cors, call_next: CallNext) -> Self {
262        Self { cors, call_next }
263    }
264}
265
266#[async_trait]
267impl Handler for CorsHandler {
268    async fn handle(
269        &self,
270        req: &mut Request,
271        depot: &mut Depot,
272        res: &mut Response,
273        ctrl: &mut FlowCtrl,
274    ) {
275        if self.call_next == CallNext::Before {
276            ctrl.call_next(req, depot, res).await;
277        }
278
279        let origin = req.headers().get(&header::ORIGIN);
280        let mut headers = HeaderMap::new();
281
282        // These headers are applied to both preflight and subsequent regular CORS requests:
283        // https://fetch.spec.whatwg.org/#http-responses
284        headers.extend(self.cors.allow_origin.to_header(origin, req, depot));
285        headers.extend(self.cors.allow_credentials.to_header(origin, req, depot));
286
287        let mut vary_headers = self.cors.vary.values();
288        if let Some(first) = vary_headers.next() {
289            let mut header = match headers.entry(header::VARY) {
290                header::Entry::Occupied(_) => {
291                    unreachable!("no vary header inserted up to this point")
292                }
293                header::Entry::Vacant(v) => v.insert_entry(first),
294            };
295
296            for val in vary_headers {
297                header.append(val);
298            }
299        }
300
301        // Return results immediately upon preflight request
302        if req.method() == Method::OPTIONS {
303            // These headers are applied only to preflight requests
304            headers.extend(self.cors.allow_methods.to_header(origin, req, depot));
305            headers.extend(self.cors.allow_headers.to_header(origin, req, depot));
306            headers.extend(self.cors.max_age.to_header(origin, req, depot));
307            res.status_code = Some(StatusCode::NO_CONTENT);
308        } else {
309            // This header is applied only to non-preflight requests
310            headers.extend(self.cors.expose_headers.to_header(origin, req, depot));
311        }
312        res.headers_mut().extend(headers);
313
314        if self.call_next == CallNext::After {
315            ctrl.call_next(req, depot, res).await;
316        }
317    }
318}
319
320/// Iterator over the three request headers that may be involved in a CORS preflight request.
321///
322/// This is the default set of header names returned in the `vary` header
323pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
324    [
325        header::ORIGIN,
326        header::ACCESS_CONTROL_REQUEST_METHOD,
327        header::ACCESS_CONTROL_REQUEST_HEADERS,
328    ]
329    .into_iter()
330}
331
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        // Sends a CORS preflight request for `/hello` from `origin`.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // OPTIONS to `/`, a path no route in `router` serves.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(
            res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "OPTIONS to an unrouted path must not get Access-Control-Allow-Methods"
        );

        // A real preflight request to the routed `/hello` path.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some(),
            "preflight to /hello must get Access-Control-Allow-Methods"
        );
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some(),
            "preflight to /hello must get Access-Control-Allow-Headers"
        );

        // OPTIONS to `/` on a different host. NOTE(review): no `Origin` header is
        // sent here, so this does not exercise origin filtering.
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "OPTIONS to an unrouted path must not get Access-Control-Allow-Methods"
        );
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none(),
            "OPTIONS to an unrouted path must not get Access-Control-Allow-Headers"
        );
    }
}