salvo_cors/lib.rs
//! Library adds [CORS] protection for the Salvo web framework.
2//!
3//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
4//!
5//! # Docs
6//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::allow_credentials::AllowCredentials;
26pub use self::allow_headers::AllowHeaders;
27pub use self::allow_methods::AllowMethods;
28pub use self::allow_origin::AllowOrigin;
29pub use self::allow_private_network::AllowPrivateNetwork;
30pub use self::expose_headers::ExposeHeaders;
31pub use self::max_age::MaxAge;
32pub use self::vary::Vary;
33
34static WILDCARD: HeaderValue = HeaderValue::from_static("*");
35
36/// Represents a wildcard value (`*`) used with some CORS headers such as
37/// [`Cors::allow_methods`].
38#[derive(Debug, Clone, Copy)]
39#[must_use]
40pub struct Any;
41
42fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
43where
44 I: Iterator<Item = HeaderValue>,
45{
46 match iter.next() {
47 Some(fst) => {
48 let mut result = BytesMut::from(fst.as_bytes());
49 for val in iter {
50 result.reserve(val.len() + 1);
51 result.put_u8(b',');
52 result.extend_from_slice(val.as_bytes());
53 }
54
55 HeaderValue::from_maybe_shared(result.freeze()).ok()
56 }
57 None => None,
58 }
59}
60
61/// [`Cors`] middleware which adds headers for [CORS][mdn].
62///
63/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
64#[derive(Clone, Debug)]
65pub struct Cors {
66 allow_credentials: AllowCredentials,
67 allow_headers: AllowHeaders,
68 allow_methods: AllowMethods,
69 allow_origin: AllowOrigin,
70 allow_private_network: AllowPrivateNetwork,
71 expose_headers: ExposeHeaders,
72 max_age: MaxAge,
73 vary: Vary,
74}
75impl Default for Cors {
76 #[inline]
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl Cors {
83 /// Create new `Cors`.
84 #[inline]
85 #[must_use]
86 pub fn new() -> Self {
87 Self {
88 allow_credentials: Default::default(),
89 allow_headers: Default::default(),
90 allow_methods: Default::default(),
91 allow_origin: Default::default(),
92 allow_private_network: Default::default(),
93 expose_headers: Default::default(),
94 max_age: Default::default(),
95 vary: Default::default(),
96 }
97 }
98
99 /// A permissive configuration:
100 ///
101 /// - All request headers allowed.
102 /// - All methods allowed.
103 /// - All origins allowed.
104 /// - All headers exposed.
105 #[must_use]
106 pub fn permissive() -> Self {
107 Self::new()
108 .allow_headers(Any)
109 .allow_methods(Any)
110 .allow_origin(Any)
111 .expose_headers(Any)
112 }
113
114 /// A very permissive configuration:
115 ///
116 /// - **Credentials allowed.**
117 /// - The method received in `Access-Control-Request-Method` is sent back as an allowed method.
118 /// - The origin of the preflight request is sent back as an allowed origin.
119 /// - The header names received in `Access-Control-Request-Headers` are sent back as allowed
120 /// headers.
121 /// - No headers are currently exposed, but this may change in the future.
122 #[must_use]
123 pub fn very_permissive() -> Self {
124 Self::new()
125 .allow_credentials(true)
126 .allow_headers(AllowHeaders::mirror_request())
127 .allow_methods(AllowMethods::mirror_request())
128 .allow_origin(AllowOrigin::mirror_request())
129 }
130
131 /// Sets whether to add the `Access-Control-Allow-Credentials` header.
132 #[inline]
133 #[must_use]
134 pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
135 self.allow_credentials = allow_credentials.into();
136 self
137 }
138
139 /// Adds multiple headers to the list of allowed request headers.
140 ///
141 /// **Note**: These should match the values the browser sends via
142 /// `Access-Control-Request-Headers`, e.g.`content-type`.
143 ///
144 /// # Panics
145 ///
146 /// Panics if any of the headers are not a valid `http::header::HeaderName`.
147 #[inline]
148 #[must_use]
149 pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
150 self.allow_headers = headers.into();
151 self
152 }
153
154 /// Sets the `Access-Control-Max-Age` header.
155 ///
156 /// # Example
157 ///
158 /// ```
159 /// use std::time::Duration;
160 ///
161 /// use salvo_core::prelude::*;
162 /// use salvo_cors::Cors;
163 ///
164 /// let cors = Cors::new().max_age(30); // 30 seconds
165 /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
166 /// ```
167 #[inline]
168 #[must_use]
169 pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
170 self.max_age = max_age.into();
171 self
172 }
173
174 /// Adds multiple methods to the existing list of allowed request methods.
175 ///
176 /// # Panics
177 ///
178 /// Panics if the provided argument is not a valid `http::Method`.
179 #[inline]
180 #[must_use]
181 pub fn allow_methods<I>(mut self, methods: I) -> Self
182 where
183 I: Into<AllowMethods>,
184 {
185 self.allow_methods = methods.into();
186 self
187 }
188
189 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
190 /// ```
191 /// use salvo_core::http::HeaderValue;
192 /// use salvo_cors::Cors;
193 ///
194 /// let cors = Cors::new().allow_origin("http://example.com".parse::<HeaderValue>().unwrap());
195 /// ```
196 ///
197 /// Multiple origins can be allowed with
198 ///
199 /// ```
200 /// use salvo_cors::Cors;
201 ///
202 /// let origins = ["http://example.com", "http://api.example.com"];
203 ///
204 /// let cors = Cors::new().allow_origin(origins);
205 /// ```
206 ///
207 /// All origins can be allowed with
208 ///
209 /// ```
210 /// use salvo_cors::{Any, Cors};
211 ///
212 /// let cors = Cors::new().allow_origin(Any);
213 /// ```
214 ///
215 /// You can also use a closure
216 ///
217 /// ```
218 /// use salvo_core::http::HeaderValue;
219 /// use salvo_core::{Depot, Request};
220 /// use salvo_cors::{AllowOrigin, Cors};
221 ///
222 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
223 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
224 /// if origin?.as_bytes().ends_with(b".rust-lang.org") {
225 /// origin.cloned()
226 /// } else {
227 /// None
228 /// }
229 /// },
230 /// ));
231 /// ```
232 ///
233 /// You can also use an async closure, make sure all the values are owned
234 /// before passing into the future:
235 ///
236 /// ```
237 /// # #[derive(Clone)]
238 /// # struct Client;
239 /// # fn get_api_client() -> Client {
240 /// # Client
241 /// # }
242 /// # impl Client {
243 /// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
244 /// # vec![HeaderValue::from_static("http://example.com")]
245 /// # }
246 /// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
247 /// # vec![HeaderValue::from_static("http://example.com")]
248 /// # }
249 /// # }
250 /// use salvo_core::http::header::HeaderValue;
251 /// use salvo_core::{Depot, Request};
252 /// use salvo_cors::{AllowOrigin, Cors};
253 ///
254 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
255 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
256 /// let origin = origin.cloned();
257 /// async move {
258 /// let client = get_api_client();
259 /// // fetch list of origins that are allowed
260 /// let origins = client.fetch_allowed_origins().await;
261 /// if origins.contains(origin.as_ref()?) {
262 /// origin
263 /// } else {
264 /// None
265 /// }
266 /// }
267 /// },
268 /// ));
269 /// ```
270 ///
271 /// **Note** that multiple calls to this method will override any previous
272 /// calls.
273 ///
274 /// **Note** origin must contain http or https protocol name.
275 ///
276 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
277 #[inline]
278 #[must_use]
279 pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
280 self.allow_origin = origin.into();
281 self
282 }
283
284 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
285 ///
286 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
287 #[inline]
288 #[must_use]
289 pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
290 self.expose_headers = headers.into();
291 self
292 }
293
294 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
295 ///
296 /// ```
297 /// use salvo_cors::Cors;
298 ///
299 /// let cors = Cors::new().allow_private_network(true);
300 /// ```
301 ///
302 /// [wicg]: https://wicg.github.io/private-network-access/
303 #[must_use]
304 pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
305 where
306 T: Into<AllowPrivateNetwork>,
307 {
308 self.allow_private_network = allow_private_network.into();
309 self
310 }
311
312 /// Set the value(s) of the [`Vary`][mdn] header.
313 ///
314 /// In contrast to the other headers, this one has a non-empty default of
315 /// [`preflight_request_headers()`].
316 ///
317 /// You only need to set this is you want to remove some of these defaults,
318 /// or if you use a closure for one of the other headers and want to add a
319 /// vary header accordingly.
320 ///
321 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
322 #[must_use]
323 pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
324 self.vary = headers.into();
325 self
326 }
327
328 /// Returns a new `CorsHandler` using current cors settings.
329 pub fn into_handler(self) -> CorsHandler {
330 self.ensure_usable_cors_rules();
331 CorsHandler::new(self, CallNext::default())
332 }
333
334 fn ensure_usable_cors_rules(&self) {
335 if self.allow_credentials.is_true() {
336 assert!(
337 !self.allow_headers.is_wildcard(),
338 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
339 with `Access-Control-Allow-Headers: *`"
340 );
341
342 assert!(
343 !self.allow_methods.is_wildcard(),
344 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
345 with `Access-Control-Allow-Methods: *`"
346 );
347
348 assert!(
349 !self.allow_origin.is_wildcard(),
350 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
351 with `Access-Control-Allow-Origin: *`"
352 );
353
354 assert!(
355 !self.expose_headers.is_wildcard(),
356 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
357 with `Access-Control-Expose-Headers: *`"
358 );
359 }
360 }
361}
362
/// Controls at which point [`CorsHandler`] invokes the next handlers.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CallNext {
    /// Call the next handlers before [`CorsHandler`] writes data to the response.
    #[default]
    Before,
    /// Call the next handlers after [`CorsHandler`] writes data to the response.
    After,
}
373
374/// CorsHandler
375#[derive(Clone, Debug)]
376pub struct CorsHandler {
377 cors: Cors,
378 call_next: CallNext,
379}
380impl CorsHandler {
381 /// Create a new `CorsHandler`.
382 pub fn new(cors: Cors, call_next: CallNext) -> Self {
383 Self { cors, call_next }
384 }
385}
386
387#[async_trait]
388impl Handler for CorsHandler {
389 async fn handle(
390 &self,
391 req: &mut Request,
392 depot: &mut Depot,
393 res: &mut Response,
394 ctrl: &mut FlowCtrl,
395 ) {
396 if self.call_next == CallNext::Before {
397 ctrl.call_next(req, depot, res).await;
398 }
399
400 let origin = req.headers().get(&header::ORIGIN);
401 let mut headers = HeaderMap::new();
402
403 // These headers are applied to both preflight and subsequent regular CORS requests:
404 // https://fetch.spec.whatwg.org/#http-responses
405 headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
406 headers.extend(
407 self.cors
408 .allow_credentials
409 .to_header(origin, req, depot)
410 .await,
411 );
412 headers.extend(
413 self.cors
414 .allow_private_network
415 .to_header(origin, req, depot)
416 .await,
417 );
418
419 let mut vary_headers = self.cors.vary.values();
420 if let Some(first) = vary_headers.next() {
421 let mut header = match headers.entry(header::VARY) {
422 header::Entry::Occupied(_) => {
423 unreachable!("no vary header inserted up to this point")
424 }
425 header::Entry::Vacant(v) => v.insert_entry(first),
426 };
427
428 for val in vary_headers {
429 header.append(val);
430 }
431 }
432
433 // Return results immediately upon preflight request
434 if req.method() == Method::OPTIONS {
435 // These headers are applied only to preflight requests
436 headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
437 headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
438 headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
439 res.status_code = Some(StatusCode::NO_CONTENT);
440 } else {
441 // This header is applied only to non-preflight requests
442 headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
443 }
444 res.headers_mut().extend(headers);
445
446 if self.call_next == CallNext::After {
447 ctrl.call_next(req, depot, res).await;
448 }
449 }
450}
451
452/// Iterator over the three request headers that may be involved in a CORS preflight request.
453///
454/// This is the default set of header names returned in the `vary` header
455pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
456 [
457 header::ORIGIN,
458 header::ACCESS_CONTROL_REQUEST_METHOD,
459 header::ACCESS_CONTROL_REQUEST_HEADERS,
460 ]
461 .into_iter()
462}
463
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    /// Exercises the handler with preflight-style `OPTIONS` requests against a
    /// router that only serves `/hello`.
    #[tokio::test]
    async fn test_cors() {
        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        /// Sends an `OPTIONS` request for `/hello` carrying the usual preflight headers.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        let allowed_headers = vec![
            "CONTENT-TYPE",
            "Access-Control-Request-Method",
            "Access-Control-Allow-Origin",
            "Access-Control-Allow-Headers",
            "Access-Control-Max-Age",
        ];
        let cors = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(allowed_headers)
            .into_handler();

        let routes = Router::new()
            .hoop(cors)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(routes);

        // NOTE(review): this request does not target `/hello`; presumably the
        // hoop is not run for unmatched routes, hence no CORS headers — confirm.
        let response = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(!response.headers().contains_key(ACCESS_CONTROL_ALLOW_METHODS));

        // A preflight request for `/hello` receives the preflight-only headers.
        let response = options_access(&service, "https://salvo.rs").await;
        let headers = response.headers();
        assert!(headers.contains_key(ACCESS_CONTROL_ALLOW_METHODS));
        assert!(headers.contains_key(ACCESS_CONTROL_ALLOW_HEADERS));

        // NOTE(review): same situation as the first request (no `/hello` path,
        // no `Origin` header) — confirm this is the intended scenario.
        let response = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = response.headers();
        assert!(
            !headers.contains_key(ACCESS_CONTROL_ALLOW_METHODS),
            "POST, GET, DELETE, OPTIONS"
        );
        assert!(!headers.contains_key(ACCESS_CONTROL_ALLOW_HEADERS));
    }
}
523}