1use headers::{
12 AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13 HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{Error, error_page, handler::RequestHandlerOpts};
20
/// Builder holding the CORS settings before they are turned into a [`Configured`] value.
#[derive(Clone, Debug)]
pub struct Cors {
    /// Request header names accepted when validating a preflight request.
    allowed_headers: HashSet<HeaderName>,
    /// Header names emitted through the `Access-Control-Expose-Headers` response header.
    exposed_headers: HashSet<HeaderName>,
    /// Value for the `Access-Control-Max-Age` response header; the header is omitted when `None`.
    max_age: Option<u64>,
    /// HTTP methods accepted when validating a preflight request.
    allowed_methods: HashSet<http::Method>,
    /// Allowed origins as header values; `None` means every non-empty origin is accepted.
    origins: Option<HashSet<HeaderValue>>,
}
30
31pub fn new(
33 origins_str: &str,
34 allow_headers_str: &str,
35 expose_headers_str: &str,
36) -> Option<Configured> {
37 let cors = Cors::new();
38 let cors = if origins_str.is_empty() {
39 None
40 } else {
41 let [allow_headers_vec, expose_headers_vec] =
42 [allow_headers_str, expose_headers_str].map(|s| {
43 if s.is_empty() {
44 vec!["origin", "content-type"]
45 } else {
46 s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47 }
48 });
49 let [allow_headers_str, expose_headers_str] =
50 [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52 let cors_res = if origins_str == "*" {
53 Some(
54 cors.allow_any_origin()
55 .allow_headers(allow_headers_vec)
56 .expose_headers(expose_headers_vec)
57 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58 )
59 } else {
60 let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61 if hosts.is_empty() {
62 None
63 } else {
64 Some(
65 cors.allow_origins(hosts)
66 .allow_headers(allow_headers_vec)
67 .expose_headers(expose_headers_vec)
68 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69 )
70 }
71 };
72
73 if cors_res.is_some() {
74 tracing::info!(
75 "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76 origins_str,
77 allow_headers_str,
78 expose_headers_str,
79 );
80 }
81 cors_res
82 };
83
84 Cors::build(cors)
85}
86
87impl Cors {
88 pub fn new() -> Self {
90 Self {
91 origins: None,
92 allowed_headers: HashSet::new(),
93 exposed_headers: HashSet::new(),
94 allowed_methods: HashSet::new(),
95 max_age: None,
96 }
97 }
98
99 pub fn allow_methods<I>(mut self, methods: I) -> Self
105 where
106 I: IntoIterator,
107 http::Method: TryFrom<I::Item>,
108 {
109 let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110 Ok(m) => m,
111 Err(_) => panic!("cors: illegal method"),
112 });
113 self.allowed_methods.extend(iter);
114 self
115 }
116
117 pub fn allow_any_origin(mut self) -> Self {
124 self.origins = None;
125 self
126 }
127
128 pub fn allow_origins<I>(mut self, origins: I) -> Self
134 where
135 I: IntoIterator,
136 I::Item: IntoOrigin,
137 {
138 let iter = origins
139 .into_iter()
140 .map(IntoOrigin::into_origin)
141 .map(|origin| {
142 origin
143 .to_string()
144 .parse()
145 .expect("cors: Origin is always a valid HeaderValue")
146 });
147
148 self.origins.get_or_insert_with(HashSet::new).extend(iter);
149 self
150 }
151
152 pub fn allow_headers<I>(mut self, headers: I) -> Self
160 where
161 I: IntoIterator,
162 HeaderName: TryFrom<I::Item>,
163 {
164 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165 Ok(h) => h,
166 Err(_) => panic!("cors: illegal Header"),
167 });
168 self.allowed_headers.extend(iter);
169 self
170 }
171
172 pub fn expose_headers<I>(mut self, headers: I) -> Self
180 where
181 I: IntoIterator,
182 HeaderName: TryFrom<I::Item>,
183 {
184 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185 Ok(h) => h,
186 Err(_) => panic!("cors: illegal Header"),
187 });
188 self.exposed_headers.extend(iter);
189 self
190 }
191
192 pub fn build(cors: Option<Cors>) -> Option<Configured> {
194 cors.as_ref()?;
195 let cors = cors?;
196
197 let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198 let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199 let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201 Some(Configured {
202 cors,
203 allowed_headers,
204 exposed_headers,
205 methods_header,
206 })
207 }
208}
209
210impl Default for Cors {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
/// CORS settings together with the typed response headers derived from them.
#[derive(Clone, Debug)]
pub struct Configured {
    /// The settings used to validate incoming requests.
    cors: Cors,
    /// Typed `Access-Control-Allow-Headers` value built from the allowed headers.
    allowed_headers: AccessControlAllowHeaders,
    /// Typed `Access-Control-Expose-Headers` value built from the exposed headers.
    exposed_headers: AccessControlExposeHeaders,
    /// Typed `Access-Control-Allow-Methods` value built from the allowed methods.
    methods_header: AccessControlAllowMethods,
}
224
/// Outcome of a request that passed the CORS checks.
#[derive(Debug)]
pub enum Validated {
    /// An accepted `OPTIONS` request with an `Origin` header; carries that origin value.
    Preflight(HeaderValue),
    /// An accepted non-`OPTIONS` request with an `Origin` header; carries that origin value.
    Simple(HeaderValue),
    /// The request has no `Origin` header.
    NotCors,
}
235
/// Reason why a request was rejected by the CORS checks.
#[derive(Debug, Default)]
pub enum Forbidden {
    /// The request origin is empty or not in the allowed origins.
    #[default]
    Origin,
    /// The preflight's requested method is missing or not in the allowed methods.
    Method,
    /// A preflight's requested header is unreadable or not in the allowed headers.
    Header,
}
247
248impl Configured {
249 pub fn check_request(
251 &self,
252 method: &http::Method,
253 headers: &HeaderMap,
254 ) -> Result<(HeaderMap, Validated), Forbidden> {
255 match (headers.get(header::ORIGIN), method) {
256 (Some(origin), &http::Method::OPTIONS) => {
257 if !self.is_origin_allowed(origin) {
260 return Err(Forbidden::Origin);
261 }
262
263 if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
264 if !self.is_method_allowed(req_method) {
265 return Err(Forbidden::Method);
266 }
267 } else {
268 tracing::warn!(
269 "cors: preflight request missing `access-control-request-method` header"
270 );
271 return Err(Forbidden::Method);
272 }
273
274 if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
275 let headers = match req_headers.to_str() {
276 Ok(val) => val,
277 Err(err) => {
278 tracing::error!(
279 "cors: error parsing header `access-control-request-headers` value: {:?}",
280 err,
281 );
282 return Err(Forbidden::Header);
283 }
284 };
285
286 for header in headers.split(',') {
287 let h = header.trim();
288 if !self.is_header_allowed(h) {
289 tracing::error!(
290 "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option",
291 h
292 );
293 return Err(Forbidden::Header);
294 }
295 }
296 }
297
298 let mut headers = HeaderMap::new();
299 self.append_preflight_headers(&mut headers);
300 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
301
302 Ok((headers, Validated::Preflight(origin.clone())))
303 }
304 (Some(origin), _) => {
305 tracing::trace!("cors origin header: {:?}", origin);
307
308 if self.is_origin_allowed(origin) {
309 let mut headers = HeaderMap::new();
310 self.append_preflight_headers(&mut headers);
311 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
312
313 Ok((headers, Validated::Simple(origin.clone())))
314 } else {
315 Err(Forbidden::Origin)
316 }
317 }
318 _ => {
319 Ok((HeaderMap::new(), Validated::NotCors))
321 }
322 }
323 }
324
325 fn is_method_allowed(&self, header: &HeaderValue) -> bool {
326 http::Method::from_bytes(header.as_bytes())
327 .map(|method| self.cors.allowed_methods.contains(&method))
328 .unwrap_or(false)
329 }
330
331 fn is_header_allowed(&self, header: &str) -> bool {
332 if header.is_empty() {
333 return false;
334 }
335 HeaderName::from_bytes(header.as_bytes())
336 .map(|header| self.cors.allowed_headers.contains(&header))
337 .unwrap_or(false)
338 }
339
340 fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
341 if origin.is_empty() {
342 return false;
343 }
344 if let Some(ref allowed) = self.cors.origins {
345 allowed.contains(origin)
346 } else {
347 true
348 }
349 }
350
351 fn append_preflight_headers(&self, headers: &mut HeaderMap) {
352 headers.typed_insert(self.allowed_headers.clone());
353 headers.typed_insert(self.exposed_headers.clone());
354 headers.typed_insert(self.methods_header.clone());
355
356 if let Some(max_age) = self.cors.max_age {
357 headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
358 }
359 }
360}
361
/// Conversion of a value into an `Origin`, used by `Cors::allow_origins`.
pub trait IntoOrigin {
    /// Consumes `self` and returns the corresponding `Origin`.
    fn into_origin(self) -> Origin;
}
367
368impl IntoOrigin for &str {
369 fn into_origin(self) -> Origin {
370 let mut parts = self.splitn(2, "://");
371 let scheme = parts.next().expect("cors::into_origin: missing url scheme");
372 let rest = parts.next().expect("cors::into_origin: missing url scheme");
373
374 Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
375 }
376}
377
378pub(crate) fn init(
380 cors_allow_origins: &str,
381 cors_allow_headers: &str,
382 cors_expose_headers: &str,
383 handler_opts: &mut RequestHandlerOpts,
384) {
385 handler_opts.cors = new(
386 cors_allow_origins.trim(),
387 cors_allow_headers.trim(),
388 cors_expose_headers.trim(),
389 );
390}
391
/// Request extension holding the CORS response headers computed by `pre_process`, so that
/// `post_process` can copy them into the response.
#[derive(Clone)]
pub(crate) struct CorsHeaders(pub(crate) HeaderMap);
396
397pub(crate) fn pre_process<T>(
399 opts: &RequestHandlerOpts,
400 req: &mut Request<T>,
401) -> Option<Result<Response<Body>, Error>> {
402 let cors = opts.cors.as_ref()?;
403 match cors.check_request(req.method(), req.headers()) {
404 Ok((headers, state)) => {
405 tracing::debug!("cors state: {:?}", state);
406 if !headers.is_empty() {
408 req.extensions_mut().insert(CorsHeaders(headers));
409 }
410 None
411 }
412 Err(err) => {
413 tracing::error!("cors error kind: {:?}", err);
414 Some(error_page::error_response(
415 req.uri(),
416 req.method(),
417 &StatusCode::FORBIDDEN,
418 &opts.page404,
419 &opts.page50x,
420 ))
421 }
422 }
423}
424
425pub(crate) fn post_process<T>(
427 opts: &RequestHandlerOpts,
428 req: &Request<T>,
429 mut resp: Response<Body>,
430) -> Result<Response<Body>, Error> {
431 if opts.cors.is_some()
432 && let Some(cors_headers) = req.extensions().get::<CorsHeaders>()
433 {
434 for (k, v) in cors_headers.0.iter() {
435 resp.headers_mut().insert(k, v.to_owned());
436 }
437 resp.headers_mut().insert(
438 hyper::header::VARY,
439 HeaderValue::from_name(hyper::header::ORIGIN),
440 );
441 resp.headers_mut().remove(http::header::ALLOW);
442 }
443 Ok(resp)
444}
445
#[cfg(test)]
mod tests {
    use super::{Configured, Cors, post_process, pre_process};
    use crate::{Error, handler::RequestHandlerOpts};
    use hyper::{Body, Request, Response, StatusCode};

    /// Builds a request to `/` with the given method and, unless empty, an `Origin` header.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let mut builder = Request::builder();
        if !origin.is_empty() {
            builder = builder.header("Origin", origin);
        }
        builder.method(method).uri("/").body(Body::empty()).unwrap()
    }

    /// Builds an empty response.
    fn make_response() -> Response<Body> {
        Response::builder().body(Body::empty()).unwrap()
    }

    /// CORS settings allowing one origin, one custom header and the `GET`/`HEAD` methods.
    fn make_cors_config() -> Option<Configured> {
        Cors::build(Some(
            Cors::new()
                .allow_origins(vec!["https://example.com/"])
                .allow_headers(vec!["X-Allowed"])
                .allow_methods(vec!["GET", "HEAD"]),
        ))
    }

    /// Extracts the `Access-Control-Allow-Origin` response header, if any.
    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        resp.headers()
            .get("Access-Control-Allow-Origin")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_owned())
    }

    /// Returns `true` when pre-processing produced a `403 Forbidden` response.
    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        if let Some(Ok(response)) = result {
            response.status() == StatusCode::FORBIDDEN
        } else {
            false
        }
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: None,
            ..Default::default()
        };
        let mut req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };
        let mut req = make_request("GET", "");

        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        // Origin that is not allowed.
        assert!(is_403(pre_process(
            &opts,
            &mut make_request("GET", "https://example.info")
        )));
        // Preflight without the `Access-Control-Request-Method` header.
        assert!(is_403(pre_process(
            &opts,
            &mut make_request("OPTIONS", "https://example.com")
        )));

        // Preflight requesting a method that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());
        assert!(is_403(pre_process(&opts, &mut req)));

        // Preflight requesting a header that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );
        assert!(is_403(pre_process(&opts, &mut req)));
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        // Simple request from an allowed origin.
        let mut req = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        // Preflight request: it has to be sent as `OPTIONS`, otherwise the
        // `Access-Control-Request-*` headers are never inspected.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }
}