salvo_cors/lib.rs
//! This library adds [CORS] protection to the Salvo web framework.
//!
//! [CORS]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
//!
//! # Docs
//! Find the docs here: <https://salvo.rs/book/features/cors.html>
7#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod allow_private_network;
21mod expose_headers;
22mod max_age;
23mod vary;
24
25pub use self::{
26 allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
27 allow_origin::AllowOrigin, allow_private_network::AllowPrivateNetwork,
28 expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
29};
30
31static WILDCARD: HeaderValue = HeaderValue::from_static("*");
32
/// Marker for the wildcard value (`*`) accepted by some CORS headers, for
/// example through [`Cors::allow_methods`].
#[must_use]
#[derive(Clone, Copy, Debug)]
pub struct Any;
38
39fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
40where
41 I: Iterator<Item = HeaderValue>,
42{
43 match iter.next() {
44 Some(fst) => {
45 let mut result = BytesMut::from(fst.as_bytes());
46 for val in iter {
47 result.reserve(val.len() + 1);
48 result.put_u8(b',');
49 result.extend_from_slice(val.as_bytes());
50 }
51
52 HeaderValue::from_maybe_shared(result.freeze()).ok()
53 }
54 None => None,
55 }
56}
57
58/// [`Cors`] middleware which adds headers for [CORS][mdn].
59///
60/// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
61#[derive(Clone, Debug)]
62pub struct Cors {
63 allow_credentials: AllowCredentials,
64 allow_headers: AllowHeaders,
65 allow_methods: AllowMethods,
66 allow_origin: AllowOrigin,
67 allow_private_network: AllowPrivateNetwork,
68 expose_headers: ExposeHeaders,
69 max_age: MaxAge,
70 vary: Vary,
71}
72impl Default for Cors {
73 #[inline]
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl Cors {
80 /// Create new `Cors`.
81 #[inline]
82 #[must_use]
83 pub fn new() -> Self {
84 Self {
85 allow_credentials: Default::default(),
86 allow_headers: Default::default(),
87 allow_methods: Default::default(),
88 allow_origin: Default::default(),
89 allow_private_network: Default::default(),
90 expose_headers: Default::default(),
91 max_age: Default::default(),
92 vary: Default::default(),
93 }
94 }
95
96 /// A permissive configuration:
97 ///
98 /// - All request headers allowed.
99 /// - All methods allowed.
100 /// - All origins allowed.
101 /// - All headers exposed.
102 #[must_use]
103 pub fn permissive() -> Self {
104 Self::new()
105 .allow_headers(Any)
106 .allow_methods(Any)
107 .allow_origin(Any)
108 .expose_headers(Any)
109 }
110
111 /// A very permissive configuration:
112 ///
113 /// - **Credentials allowed.**
114 /// - The method received in `Access-Control-Request-Method` is sent back
115 /// as an allowed method.
116 /// - The origin of the preflight request is sent back as an allowed origin.
117 /// - The header names received in `Access-Control-Request-Headers` are sent
118 /// back as allowed headers.
119 /// - No headers are currently exposed, but this may change in the future.
120 #[must_use]
121 pub fn very_permissive() -> Self {
122 Self::new()
123 .allow_credentials(true)
124 .allow_headers(AllowHeaders::mirror_request())
125 .allow_methods(AllowMethods::mirror_request())
126 .allow_origin(AllowOrigin::mirror_request())
127 }
128
129 /// Sets whether to add the `Access-Control-Allow-Credentials` header.
130 #[inline]
131 #[must_use]
132 pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
133 self.allow_credentials = allow_credentials.into();
134 self
135 }
136
137 /// Adds multiple headers to the list of allowed request headers.
138 ///
139 /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
140 ///
141 /// # Panics
142 ///
143 /// Panics if any of the headers are not a valid `http::header::HeaderName`.
144 #[inline]
145 #[must_use]
146 pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
147 self.allow_headers = headers.into();
148 self
149 }
150
151 /// Sets the `Access-Control-Max-Age` header.
152 ///
153 /// # Example
154 ///
155 /// ```
156 /// use std::time::Duration;
157 /// use salvo_core::prelude::*;
158 /// use salvo_cors::Cors;
159 ///
160 /// let cors = Cors::new().max_age(30); // 30 seconds
161 /// let cors = Cors::new().max_age(Duration::from_secs(30)); // or a Duration
162 /// ```
163 #[inline]
164 #[must_use]
165 pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
166 self.max_age = max_age.into();
167 self
168 }
169
170 /// Adds multiple methods to the existing list of allowed request methods.
171 ///
172 /// # Panics
173 ///
174 /// Panics if the provided argument is not a valid `http::Method`.
175 #[inline]
176 #[must_use]
177 pub fn allow_methods<I>(mut self, methods: I) -> Self
178 where
179 I: Into<AllowMethods>,
180 {
181 self.allow_methods = methods.into();
182 self
183 }
184
185 /// Set the value of the [`Access-Control-Allow-Origin`][mdn] header.
186 /// ```
187 /// use salvo_core::http::HeaderValue;
188 /// use salvo_cors::Cors;
189 ///
190 /// let cors = Cors::new().allow_origin(
191 /// "http://example.com".parse::<HeaderValue>().unwrap(),
192 /// );
193 /// ```
194 ///
195 /// Multiple origins can be allowed with
196 ///
197 /// ```
198 /// use salvo_cors::Cors;
199 ///
200 /// let origins = ["http://example.com", "http://api.example.com"];
201 ///
202 /// let cors = Cors::new().allow_origin(origins);
203 /// ```
204 ///
205 /// All origins can be allowed with
206 ///
207 /// ```
208 /// use salvo_cors::{Any, Cors};
209 ///
210 /// let cors = Cors::new().allow_origin(Any);
211 /// ```
212 ///
213 /// You can also use a closure
214 ///
215 /// ```
216 /// use salvo_cors::{Cors, AllowOrigin};
217 /// use salvo_core::http::HeaderValue;
218 /// use salvo_core::{Depot, Request};
219 ///
220 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic(
221 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
222 /// if origin?.as_bytes().ends_with(b".rust-lang.org") {
223 /// origin.cloned()
224 /// } else {
225 /// None
226 /// }
227 /// },
228 /// ));
229 /// ```
230 ///
231 /// You can also use an async closure, make sure all the values are owned
232 /// before passing into the future:
233 ///
234 /// ```
235 /// # #[derive(Clone)]
236 /// # struct Client;
237 /// # fn get_api_client() -> Client {
238 /// # Client
239 /// # }
240 /// # impl Client {
241 /// # async fn fetch_allowed_origins(&self) -> Vec<HeaderValue> {
242 /// # vec![HeaderValue::from_static("http://example.com")]
243 /// # }
244 /// # async fn fetch_allowed_origins_for_path(&self, _path: String) -> Vec<HeaderValue> {
245 /// # vec![HeaderValue::from_static("http://example.com")]
246 /// # }
247 /// # }
248 /// use salvo_cors::{Cors, AllowOrigin};
249 /// use salvo_core::http::header::HeaderValue;
250 /// use salvo_core::{Depot, Request};
251 ///
252 ///
253 /// let cors = Cors::new().allow_origin(AllowOrigin::dynamic_async(
254 /// |origin: Option<&HeaderValue>, _req: &Request, _depot: &Depot| {
255 /// let origin = origin.cloned();
256 /// async move {
257 /// let client = get_api_client();
258 /// // fetch list of origins that are allowed
259 /// let origins = client.fetch_allowed_origins().await;
260 /// if origins.contains(origin.as_ref()?) {
261 /// origin
262 /// } else {
263 /// None
264 /// }
265 /// }
266 /// },
267 /// ));
268 /// ```
269 ///
270 /// **Note** that multiple calls to this method will override any previous
271 /// calls.
272 ///
273 /// **Note** origin must contain http or https protocol name.
274 ///
275 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
276 #[inline]
277 #[must_use]
278 pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
279 self.allow_origin = origin.into();
280 self
281 }
282
283 /// Set the value of the [`Access-Control-Expose-Headers`][mdn] header.
284 ///
285 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
286 #[inline]
287 #[must_use]
288 pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
289 self.expose_headers = headers.into();
290 self
291 }
292
293 /// Set the value of the [`Access-Control-Allow-Private-Network`][wicg] header.
294 ///
295 /// ```
296 /// use salvo_cors::Cors;
297 ///
298 /// let cors = Cors::new().allow_private_network(true);
299 /// ```
300 ///
301 /// [wicg]: https://wicg.github.io/private-network-access/
302 #[must_use]
303 pub fn allow_private_network<T>(mut self, allow_private_network: T) -> Self
304 where
305 T: Into<AllowPrivateNetwork>,
306 {
307 self.allow_private_network = allow_private_network.into();
308 self
309 }
310
311 /// Set the value(s) of the [`Vary`][mdn] header.
312 ///
313 /// In contrast to the other headers, this one has a non-empty default of
314 /// [`preflight_request_headers()`].
315 ///
316 /// You only need to set this is you want to remove some of these defaults,
317 /// or if you use a closure for one of the other headers and want to add a
318 /// vary header accordingly.
319 ///
320 /// [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary
321 #[must_use]
322 pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
323 self.vary = headers.into();
324 self
325 }
326
327 /// Returns a new `CorsHandler` using current cors settings.
328 pub fn into_handler(self) -> CorsHandler {
329 self.ensure_usable_cors_rules();
330 CorsHandler::new(self, CallNext::default())
331 }
332
333 fn ensure_usable_cors_rules(&self) {
334 if self.allow_credentials.is_true() {
335 assert!(
336 !self.allow_headers.is_wildcard(),
337 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
338 with `Access-Control-Allow-Headers: *`"
339 );
340
341 assert!(
342 !self.allow_methods.is_wildcard(),
343 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
344 with `Access-Control-Allow-Methods: *`"
345 );
346
347 assert!(
348 !self.allow_origin.is_wildcard(),
349 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
350 with `Access-Control-Allow-Origin: *`"
351 );
352
353 assert!(
354 !self.expose_headers.is_wildcard(),
355 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
356 with `Access-Control-Expose-Headers: *`"
357 );
358 }
359 }
360}
361
/// Selects whether the next handlers are called before or after
/// [`CorsHandler`] writes data to the response.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CallNext {
    /// Call the next handlers first; [`CorsHandler`] writes to the response
    /// afterwards.
    #[default]
    Before,
    /// [`CorsHandler`] writes to the response first; the next handlers are
    /// called afterwards.
    After,
}
372
373/// CorsHandler
374#[derive(Clone, Debug)]
375pub struct CorsHandler {
376 cors: Cors,
377 call_next: CallNext,
378}
379impl CorsHandler {
380 /// Create a new `CorsHandler`.
381 pub fn new(cors: Cors, call_next: CallNext) -> Self {
382 Self { cors, call_next }
383 }
384}
385
386#[async_trait]
387impl Handler for CorsHandler {
388 async fn handle(
389 &self,
390 req: &mut Request,
391 depot: &mut Depot,
392 res: &mut Response,
393 ctrl: &mut FlowCtrl,
394 ) {
395 if self.call_next == CallNext::Before {
396 ctrl.call_next(req, depot, res).await;
397 }
398
399 let origin = req.headers().get(&header::ORIGIN);
400 let mut headers = HeaderMap::new();
401
402 // These headers are applied to both preflight and subsequent regular CORS requests:
403 // https://fetch.spec.whatwg.org/#http-responses
404 headers.extend(self.cors.allow_origin.to_header(origin, req, depot).await);
405 headers.extend(
406 self.cors
407 .allow_credentials
408 .to_header(origin, req, depot)
409 .await,
410 );
411 headers.extend(
412 self.cors
413 .allow_private_network
414 .to_header(origin, req, depot)
415 .await,
416 );
417
418 let mut vary_headers = self.cors.vary.values();
419 if let Some(first) = vary_headers.next() {
420 let mut header = match headers.entry(header::VARY) {
421 header::Entry::Occupied(_) => {
422 unreachable!("no vary header inserted up to this point")
423 }
424 header::Entry::Vacant(v) => v.insert_entry(first),
425 };
426
427 for val in vary_headers {
428 header.append(val);
429 }
430 }
431
432 // Return results immediately upon preflight request
433 if req.method() == Method::OPTIONS {
434 // These headers are applied only to preflight requests
435 headers.extend(self.cors.allow_methods.to_header(origin, req, depot).await);
436 headers.extend(self.cors.allow_headers.to_header(origin, req, depot).await);
437 headers.extend(self.cors.max_age.to_header(origin, req, depot).await);
438 res.status_code = Some(StatusCode::NO_CONTENT);
439 } else {
440 // This header is applied only to non-preflight requests
441 headers.extend(self.cors.expose_headers.to_header(origin, req, depot).await);
442 }
443 res.headers_mut().extend(headers);
444
445 if self.call_next == CallNext::After {
446 ctrl.call_next(req, depot, res).await;
447 }
448 }
449}
450
451/// Iterator over the three request headers that may be involved in a CORS preflight request.
452///
453/// This is the default set of header names returned in the `vary` header
454pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
455 [
456 header::ORIGIN,
457 header::ACCESS_CONTROL_REQUEST_METHOD,
458 header::ACCESS_CONTROL_REQUEST_HEADERS,
459 ]
460 .into_iter()
461}
462
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    /// Sends `OPTIONS` requests through a service guarded by the CORS handler
    /// and checks which preflight headers appear on the responses.
    #[tokio::test]
    async fn test_cors() {
        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        /// Sends a complete preflight request for `/hello` with the given origin.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        let allowed_headers = vec![
            "CONTENT-TYPE",
            "Access-Control-Request-Method",
            "Access-Control-Allow-Origin",
            "Access-Control-Allow-Headers",
            "Access-Control-Max-Age",
        ];
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(allowed_headers)
            .into_handler();

        let hello_route = Router::with_path("hello").goal(hello);
        let service = Service::new(Router::new().hoop(cors_handler).push(hello_route));

        // A bare OPTIONS request to this URL gets no allow-methods header.
        let bare = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(bare.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none());

        // A full preflight request for `/hello` gets both preflight headers.
        let preflight = options_access(&service, "https://salvo.rs").await;
        let preflight_headers = preflight.headers();
        assert!(preflight_headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some());
        assert!(preflight_headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some());

        // A bare OPTIONS request to another URL gets neither preflight header.
        let other = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let other_headers = other.headers();
        assert!(
            other_headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "POST, GET, DELETE, OPTIONS"
        );
        assert!(other_headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none());
    }
}
522}