1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
8#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
9#![cfg_attr(docsrs, feature(doc_cfg))]
10
11use bytes::{BufMut, BytesMut};
12use salvo_core::http::header::{self, HeaderMap, HeaderName, HeaderValue};
13use salvo_core::http::{Method, Request, Response, StatusCode};
14use salvo_core::{Depot, FlowCtrl, Handler, async_trait};
15
16mod allow_credentials;
17mod allow_headers;
18mod allow_methods;
19mod allow_origin;
20mod expose_headers;
21mod max_age;
22mod vary;
23
24pub use self::{
25 allow_credentials::AllowCredentials, allow_headers::AllowHeaders, allow_methods::AllowMethods,
26 allow_origin::AllowOrigin, expose_headers::ExposeHeaders, max_age::MaxAge, vary::Vary,
27};
28
/// The literal `*` header value.
///
/// NOTE(review): not referenced in this file; presumably shared by the option submodules
/// (`allow_origin`, `allow_headers`, ...) when they emit a wildcard header — confirm there.
static WILDCARD: HeaderValue = HeaderValue::from_static("*");
30
/// Marker value passed to [`Cors`] setters to request the "anything" variant of an option.
///
/// [`Cors::permissive`] passes it to `allow_headers`, `allow_methods`, `allow_origin` and
/// `expose_headers`. The `From<Any>` conversions are defined in the option submodules
/// (not visible here); presumably they produce the `*` wildcard — verify there.
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct Any;
36
37fn separated_by_commas<I>(mut iter: I) -> Option<HeaderValue>
38where
39 I: Iterator<Item = HeaderValue>,
40{
41 match iter.next() {
42 Some(fst) => {
43 let mut result = BytesMut::from(fst.as_bytes());
44 for val in iter {
45 result.reserve(val.len() + 1);
46 result.put_u8(b',');
47 result.extend_from_slice(val.as_bytes());
48 }
49
50 HeaderValue::from_maybe_shared(result.freeze()).ok()
51 }
52 None => None,
53 }
54}
55
/// Builder-style CORS configuration.
///
/// Configure it with the chained setter methods (`allow_origin`, `allow_methods`, ...) and
/// turn it into a middleware with [`Cors::into_handler`].
#[derive(Clone, Debug)]
pub struct Cors {
    /// Source of the `Access-Control-Allow-Credentials` header.
    allow_credentials: AllowCredentials,
    /// Source of the `Access-Control-Allow-Headers` header.
    allow_headers: AllowHeaders,
    /// Source of the `Access-Control-Allow-Methods` header.
    allow_methods: AllowMethods,
    /// Source of the `Access-Control-Allow-Origin` header.
    allow_origin: AllowOrigin,
    /// Source of the `Access-Control-Expose-Headers` header.
    expose_headers: ExposeHeaders,
    /// Source of the `Access-Control-Max-Age` header.
    max_age: MaxAge,
    /// Values emitted in the `Vary` response header.
    vary: Vary,
}
69impl Default for Cors {
70 #[inline]
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl Cors {
77 #[inline]
79 pub fn new() -> Self {
80 Cors {
81 allow_credentials: Default::default(),
82 allow_headers: Default::default(),
83 allow_methods: Default::default(),
84 allow_origin: Default::default(),
85 expose_headers: Default::default(),
86 max_age: Default::default(),
87 vary: Default::default(),
88 }
89 }
90
91 pub fn permissive() -> Self {
98 Self::new()
99 .allow_headers(Any)
100 .allow_methods(Any)
101 .allow_origin(Any)
102 .expose_headers(Any)
103 }
104
105 pub fn very_permissive() -> Self {
115 Self::new()
116 .allow_credentials(true)
117 .allow_headers(AllowHeaders::mirror_request())
118 .allow_methods(AllowMethods::mirror_request())
119 .allow_origin(AllowOrigin::mirror_request())
120 }
121
122 #[inline]
124 pub fn allow_credentials(mut self, allow_credentials: impl Into<AllowCredentials>) -> Self {
125 self.allow_credentials = allow_credentials.into();
126 self
127 }
128
129 #[inline]
137 pub fn allow_headers(mut self, headers: impl Into<AllowHeaders>) -> Self {
138 self.allow_headers = headers.into();
139 self
140 }
141
142 #[inline]
155 pub fn max_age(mut self, max_age: impl Into<MaxAge>) -> Self {
156 self.max_age = max_age.into();
157 self
158 }
159
160 #[inline]
166 pub fn allow_methods<I>(mut self, methods: I) -> Self
167 where
168 I: Into<AllowMethods>,
169 {
170 self.allow_methods = methods.into();
171 self
172 }
173
174 #[inline]
178 pub fn allow_origin(mut self, origin: impl Into<AllowOrigin>) -> Self {
179 self.allow_origin = origin.into();
180 self
181 }
182
183 #[inline]
187 pub fn expose_headers(mut self, headers: impl Into<ExposeHeaders>) -> Self {
188 self.expose_headers = headers.into();
189 self
190 }
191
192 pub fn vary<T>(mut self, headers: impl Into<Vary>) -> Self {
203 self.vary = headers.into();
204 self
205 }
206
207 pub fn into_handler(self) -> CorsHandler {
209 self.ensure_usable_cors_rules();
210 CorsHandler::new(self, CallNext::default())
211 }
212
213 fn ensure_usable_cors_rules(&self) {
214 if self.allow_credentials.is_true() {
215 assert!(
216 !self.allow_headers.is_wildcard(),
217 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
218 with `Access-Control-Allow-Headers: *`"
219 );
220
221 assert!(
222 !self.allow_methods.is_wildcard(),
223 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
224 with `Access-Control-Allow-Methods: *`"
225 );
226
227 assert!(
228 !self.allow_origin.is_wildcard(),
229 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
230 with `Access-Control-Allow-Origin: *`"
231 );
232
233 assert!(
234 !self.expose_headers.is_wildcard(),
235 "Invalid CORS configuration: Cannot combine `Access-Control-Allow-Credentials: true` \
236 with `Access-Control-Expose-Headers: *`"
237 );
238 }
239 }
240}
241
/// When [`CorsHandler`] runs the rest of the handler chain relative to adding its CORS
/// headers.
#[non_exhaustive]
#[derive(Default, Clone, Copy, Eq, PartialEq, Debug)]
pub enum CallNext {
    /// Run the remaining handlers first, then add the CORS headers (the default).
    #[default]
    Before,
    /// Add the CORS headers first, then run the remaining handlers.
    After,
}
252
/// Salvo [`Handler`] that applies a [`Cors`] configuration to requests and responses.
#[derive(Clone, Debug)]
pub struct CorsHandler {
    /// The CORS configuration to apply.
    cors: Cors,
    /// Whether the rest of the chain runs before or after the CORS headers are added.
    call_next: CallNext,
}
259impl CorsHandler {
260 pub fn new(cors: Cors, call_next: CallNext) -> Self {
262 Self { cors, call_next }
263 }
264}
265
266#[async_trait]
267impl Handler for CorsHandler {
268 async fn handle(
269 &self,
270 req: &mut Request,
271 depot: &mut Depot,
272 res: &mut Response,
273 ctrl: &mut FlowCtrl,
274 ) {
275 if self.call_next == CallNext::Before {
276 ctrl.call_next(req, depot, res).await;
277 }
278
279 let origin = req.headers().get(&header::ORIGIN);
280 let mut headers = HeaderMap::new();
281
282 headers.extend(self.cors.allow_origin.to_header(origin, req, depot));
285 headers.extend(self.cors.allow_credentials.to_header(origin, req, depot));
286
287 let mut vary_headers = self.cors.vary.values();
288 if let Some(first) = vary_headers.next() {
289 let mut header = match headers.entry(header::VARY) {
290 header::Entry::Occupied(_) => {
291 unreachable!("no vary header inserted up to this point")
292 }
293 header::Entry::Vacant(v) => v.insert_entry(first),
294 };
295
296 for val in vary_headers {
297 header.append(val);
298 }
299 }
300
301 if req.method() == Method::OPTIONS {
303 headers.extend(self.cors.allow_methods.to_header(origin, req, depot));
305 headers.extend(self.cors.allow_headers.to_header(origin, req, depot));
306 headers.extend(self.cors.max_age.to_header(origin, req, depot));
307 res.status_code = Some(StatusCode::NO_CONTENT);
308 } else {
309 headers.extend(self.cors.expose_headers.to_header(origin, req, depot));
311 }
312 res.headers_mut().extend(headers);
313
314 if self.call_next == CallNext::After {
315 ctrl.call_next(req, depot, res).await;
316 }
317 }
318}
319
320pub fn preflight_request_headers() -> impl Iterator<Item = HeaderName> {
324 [
325 header::ORIGIN,
326 header::ACCESS_CONTROL_REQUEST_METHOD,
327 header::ACCESS_CONTROL_REQUEST_HEADERS,
328 ]
329 .into_iter()
330}
331
#[cfg(test)]
mod tests {
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::TestClient;

    use super::*;

    /// Exercises `CorsHandler` as a router hoop with `OPTIONS` requests.
    #[tokio::test]
    async fn test_cors() {
        let cors_handler = Cors::new()
            .allow_origin("https://salvo.rs")
            .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS])
            .allow_headers(vec![
                "CONTENT-TYPE",
                "Access-Control-Request-Method",
                "Access-Control-Allow-Origin",
                "Access-Control-Allow-Headers",
                "Access-Control-Max-Age",
            ])
            .into_handler();

        #[handler]
        async fn hello() -> &'static str {
            "hello"
        }

        let router = Router::new()
            .hoop(cors_handler)
            .push(Router::with_path("hello").goal(hello));
        let service = Service::new(router);

        /// Sends a preflight-style `OPTIONS /hello` with the given `Origin`.
        async fn options_access(service: &Service, origin: &str) -> Response {
            TestClient::options("http://127.0.0.1:5801/hello")
                .add_header("Origin", origin, true)
                .add_header("Access-Control-Request-Method", "POST", true)
                .add_header("Access-Control-Request-Headers", "Content-Type", true)
                .send(service)
                .await
        }

        // No `Origin` header (and no `/hello` path): no CORS method header expected.
        let res = TestClient::options("https://salvo.rs").send(&service).await;
        assert!(res.headers().get(ACCESS_CONTROL_ALLOW_METHODS).is_none());

        // Preflight from the configured origin: methods and headers are advertised.
        let res = options_access(&service, "https://salvo.rs").await;
        let headers = res.headers();
        assert!(headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_some());
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_some());

        // NOTE(review): like the first case, this request sends no `Origin` header, so it
        // does not test a disallowed origin; `options_access(&service, "https://google.com")`
        // would, once the expected behavior for foreign origins is confirmed.
        let res = TestClient::options("https://google.com")
            .send(&service)
            .await;
        let headers = res.headers();
        assert!(
            headers.get(ACCESS_CONTROL_ALLOW_METHODS).is_none(),
            "OPTIONS request without an allowed Origin must not receive Access-Control-Allow-Methods"
        );
        assert!(headers.get(ACCESS_CONTROL_ALLOW_HEADERS).is_none());
    }
}