1use headers::{
12 AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13 HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{Error, error_page, handler::RequestHandlerOpts};
20
/// Builder-style CORS settings, turned into a usable [`Configured`] via [`Cors::build`].
#[derive(Clone, Debug)]
pub struct Cors {
    /// Header names a preflight may list in `Access-Control-Request-Headers`;
    /// also sent back as `Access-Control-Allow-Headers`.
    allowed_headers: HashSet<HeaderName>,
    /// Header names sent as `Access-Control-Expose-Headers`.
    exposed_headers: HashSet<HeaderName>,
    /// When `Some`, sent as `Access-Control-Max-Age` (seconds).
    /// NOTE(review): no builder method in this view sets it, so it stays `None`.
    max_age: Option<u64>,
    /// Methods a preflight may request via `Access-Control-Request-Method`;
    /// also sent back as `Access-Control-Allow-Methods`.
    allowed_methods: HashSet<http::Method>,
    /// `None` allows any origin; `Some(set)` only allows `Origin` values contained in the set.
    origins: Option<HashSet<HeaderValue>>,
}
30
31pub fn new(
33 origins_str: &str,
34 allow_headers_str: &str,
35 expose_headers_str: &str,
36) -> Option<Configured> {
37 let cors = Cors::new();
38 let cors = if origins_str.is_empty() {
39 None
40 } else {
41 let [allow_headers_vec, expose_headers_vec] =
42 [allow_headers_str, expose_headers_str].map(|s| {
43 if s.is_empty() {
44 vec!["origin", "content-type"]
45 } else {
46 s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47 }
48 });
49 let [allow_headers_str, expose_headers_str] =
50 [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52 let cors_res = if origins_str == "*" {
53 Some(
54 cors.allow_any_origin()
55 .allow_headers(allow_headers_vec)
56 .expose_headers(expose_headers_vec)
57 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58 )
59 } else {
60 let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61 if hosts.is_empty() {
62 None
63 } else {
64 Some(
65 cors.allow_origins(hosts)
66 .allow_headers(allow_headers_vec)
67 .expose_headers(expose_headers_vec)
68 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69 )
70 }
71 };
72
73 if cors_res.is_some() {
74 tracing::info!(
75 "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76 origins_str,
77 allow_headers_str,
78 expose_headers_str,
79 );
80 }
81 cors_res
82 };
83
84 Cors::build(cors)
85}
86
87impl Cors {
88 pub fn new() -> Self {
90 Self {
91 origins: None,
92 allowed_headers: HashSet::new(),
93 exposed_headers: HashSet::new(),
94 allowed_methods: HashSet::new(),
95 max_age: None,
96 }
97 }
98
99 pub fn allow_methods<I>(mut self, methods: I) -> Self
105 where
106 I: IntoIterator,
107 http::Method: TryFrom<I::Item>,
108 {
109 let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110 Ok(m) => m,
111 Err(_) => panic!("cors: illegal method"),
112 });
113 self.allowed_methods.extend(iter);
114 self
115 }
116
117 pub fn allow_any_origin(mut self) -> Self {
124 self.origins = None;
125 self
126 }
127
128 pub fn allow_origins<I>(mut self, origins: I) -> Self
134 where
135 I: IntoIterator,
136 I::Item: IntoOrigin,
137 {
138 let iter = origins
139 .into_iter()
140 .map(IntoOrigin::into_origin)
141 .map(|origin| {
142 origin
143 .to_string()
144 .parse()
145 .expect("cors: Origin is always a valid HeaderValue")
146 });
147
148 self.origins.get_or_insert_with(HashSet::new).extend(iter);
149 self
150 }
151
152 pub fn allow_headers<I>(mut self, headers: I) -> Self
160 where
161 I: IntoIterator,
162 HeaderName: TryFrom<I::Item>,
163 {
164 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165 Ok(h) => h,
166 Err(_) => panic!("cors: illegal Header"),
167 });
168 self.allowed_headers.extend(iter);
169 self
170 }
171
172 pub fn expose_headers<I>(mut self, headers: I) -> Self
180 where
181 I: IntoIterator,
182 HeaderName: TryFrom<I::Item>,
183 {
184 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185 Ok(h) => h,
186 Err(_) => panic!("cors: illegal Header"),
187 });
188 self.exposed_headers.extend(iter);
189 self
190 }
191
192 pub fn build(cors: Option<Cors>) -> Option<Configured> {
194 cors.as_ref()?;
195 let cors = cors?;
196
197 let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198 let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199 let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201 Some(Configured {
202 cors,
203 allowed_headers,
204 exposed_headers,
205 methods_header,
206 })
207 }
208}
209
210impl Default for Cors {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
/// A finalized CORS configuration produced by [`Cors::build`].
///
/// The typed response headers are computed once from `cors` and cloned into each response.
#[derive(Clone, Debug)]
pub struct Configured {
    /// The settings this configuration was built from (origins, methods, headers, max-age).
    cors: Cors,
    /// Precomputed `Access-Control-Allow-Headers` value.
    allowed_headers: AccessControlAllowHeaders,
    /// Precomputed `Access-Control-Expose-Headers` value.
    exposed_headers: AccessControlExposeHeaders,
    /// Precomputed `Access-Control-Allow-Methods` value.
    methods_header: AccessControlAllowMethods,
}
224
/// The kind of request accepted by [`Configured::check_request`].
#[derive(Debug)]
pub enum Validated {
    /// An allowed `OPTIONS` preflight request; carries the request's `Origin` value.
    Preflight(HeaderValue),
    /// An allowed non-`OPTIONS` request with an `Origin` header; carries that `Origin` value.
    Simple(HeaderValue),
    /// The request has no `Origin` header, so CORS does not apply.
    NotCors,
}
235
/// Why [`Configured::check_request`] rejected a CORS request.
#[derive(Debug, Default)]
pub enum Forbidden {
    /// The `Origin` is empty or not in the allowed set (also the `Default` variant).
    #[default]
    Origin,
    /// A preflight's `Access-Control-Request-Method` is missing or not allowed.
    Method,
    /// A preflight's `Access-Control-Request-Headers` is unreadable or lists a header
    /// that is not allowed.
    Header,
}
247
248impl Configured {
249 pub fn check_request(
251 &self,
252 method: &http::Method,
253 headers: &HeaderMap,
254 ) -> Result<(HeaderMap, Validated), Forbidden> {
255 match (headers.get(header::ORIGIN), method) {
256 (Some(origin), &http::Method::OPTIONS) => {
257 if !self.is_origin_allowed(origin) {
260 return Err(Forbidden::Origin);
261 }
262
263 if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
264 if !self.is_method_allowed(req_method) {
265 return Err(Forbidden::Method);
266 }
267 } else {
268 tracing::warn!(
269 "cors: preflight request missing `access-control-request-method` header"
270 );
271 return Err(Forbidden::Method);
272 }
273
274 if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
275 let headers = match req_headers.to_str() {
276 Ok(val) => val,
277 Err(err) => {
278 tracing::error!(
279 "cors: error parsing header `access-control-request-headers` value: {:?}",
280 err,
281 );
282 return Err(Forbidden::Header);
283 }
284 };
285
286 for header in headers.split(',') {
287 let h = header.trim();
288 if !self.is_header_allowed(h) {
289 tracing::error!(
290 "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option",
291 h
292 );
293 return Err(Forbidden::Header);
294 }
295 }
296 }
297
298 let mut headers = HeaderMap::new();
299 self.append_preflight_headers(&mut headers);
300 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
301
302 Ok((headers, Validated::Preflight(origin.clone())))
303 }
304 (Some(origin), _) => {
305 tracing::trace!("cors origin header: {:?}", origin);
307
308 if self.is_origin_allowed(origin) {
309 let mut headers = HeaderMap::new();
310 self.append_preflight_headers(&mut headers);
311 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
312
313 Ok((headers, Validated::Simple(origin.clone())))
314 } else {
315 Err(Forbidden::Origin)
316 }
317 }
318 _ => {
319 Ok((HeaderMap::new(), Validated::NotCors))
321 }
322 }
323 }
324
325 fn is_method_allowed(&self, header: &HeaderValue) -> bool {
326 http::Method::from_bytes(header.as_bytes())
327 .map(|method| self.cors.allowed_methods.contains(&method))
328 .unwrap_or(false)
329 }
330
331 fn is_header_allowed(&self, header: &str) -> bool {
332 if header.is_empty() {
333 return false;
334 }
335 HeaderName::from_bytes(header.as_bytes())
336 .map(|header| self.cors.allowed_headers.contains(&header))
337 .unwrap_or(false)
338 }
339
340 fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
341 if origin.is_empty() {
342 return false;
343 }
344 if let Some(ref allowed) = self.cors.origins {
345 allowed.contains(origin)
346 } else {
347 true
348 }
349 }
350
351 fn append_preflight_headers(&self, headers: &mut HeaderMap) {
352 headers.typed_insert(self.allowed_headers.clone());
353 headers.typed_insert(self.exposed_headers.clone());
354 headers.typed_insert(self.methods_header.clone());
355
356 if let Some(max_age) = self.cors.max_age {
357 headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
358 }
359 }
360}
361
/// Conversion into an [`Origin`], used by [`Cors::allow_origins`].
pub trait IntoOrigin {
    /// Converts `self` into an [`Origin`].
    /// NOTE(review): the `&str` implementation panics on invalid input rather than
    /// returning an error.
    fn into_origin(self) -> Origin;
}
367
368impl IntoOrigin for &str {
369 fn into_origin(self) -> Origin {
370 let mut parts = self.splitn(2, "://");
371 let scheme = parts.next().expect("cors::into_origin: missing url scheme");
372 let rest = parts.next().expect("cors::into_origin: missing url scheme");
373
374 Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
375 }
376}
377
378pub(crate) fn init(
380 cors_allow_origins: &str,
381 cors_allow_headers: &str,
382 cors_expose_headers: &str,
383 handler_opts: &mut RequestHandlerOpts,
384) {
385 handler_opts.cors = new(
386 cors_allow_origins.trim(),
387 cors_allow_headers.trim(),
388 cors_expose_headers.trim(),
389 );
390}
391
392pub(crate) fn pre_process<T>(
394 opts: &RequestHandlerOpts,
395 req: &Request<T>,
396) -> Option<Result<Response<Body>, Error>> {
397 let cors = opts.cors.as_ref()?;
398 match cors.check_request(req.method(), req.headers()) {
399 Ok((_, state)) => {
400 tracing::debug!("cors state: {:?}", state);
401 None
402 }
403 Err(err) => {
404 tracing::error!("cors error kind: {:?}", err);
405 Some(error_page::error_response(
406 req.uri(),
407 req.method(),
408 &StatusCode::FORBIDDEN,
409 &opts.page404,
410 &opts.page50x,
411 ))
412 }
413 }
414}
415
416pub(crate) fn post_process<T>(
418 opts: &RequestHandlerOpts,
419 req: &Request<T>,
420 mut resp: Response<Body>,
421) -> Result<Response<Body>, Error> {
422 if let Some(cors) = opts.cors.as_ref() {
423 if let Ok((headers, _)) = cors.check_request(req.method(), req.headers()) {
424 if !headers.is_empty() {
425 for (k, v) in headers.iter() {
426 resp.headers_mut().insert(k, v.to_owned());
427 }
428 resp.headers_mut().insert(
429 hyper::header::VARY,
430 HeaderValue::from_name(hyper::header::ORIGIN),
431 );
432 resp.headers_mut().remove(http::header::ALLOW);
433 }
434 }
435 }
436 Ok(resp)
437}
438
#[cfg(test)]
mod tests {
    use super::{Configured, Cors, post_process, pre_process};
    use crate::{Error, handler::RequestHandlerOpts};
    use hyper::{Body, Request, Response, StatusCode};

    /// Handler options with the given CORS configuration and defaults elsewhere.
    fn opts_with(cors: Option<Configured>) -> RequestHandlerOpts {
        RequestHandlerOpts {
            cors,
            ..Default::default()
        }
    }

    /// Builds a request for `/` with `method` and, unless `origin` is empty, an `Origin` header.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let builder = Request::builder().method(method).uri("/");
        let builder = match origin {
            "" => builder,
            origin => builder.header("Origin", origin),
        };
        builder.body(Body::empty()).unwrap()
    }

    /// An empty `200 OK` response.
    fn make_response() -> Response<Body> {
        Response::new(Body::empty())
    }

    /// Configuration shared by the tests: origin `https://example.com`, header `X-Allowed`,
    /// methods `GET` and `HEAD`.
    fn make_cors_config() -> Option<Configured> {
        let cors = Cors::new()
            .allow_origins(vec!["https://example.com/"])
            .allow_headers(vec!["X-Allowed"])
            .allow_methods(vec!["GET", "HEAD"]);
        Cors::build(Some(cors))
    }

    /// The `Access-Control-Allow-Origin` value of `resp`, if present and valid UTF-8.
    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        let value = resp.headers().get("Access-Control-Allow-Origin")?;
        value.to_str().ok().map(String::from)
    }

    /// `true` if `pre_process` short-circuited with a `403 Forbidden` response.
    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        matches!(result, Some(Ok(ref response)) if response.status() == StatusCode::FORBIDDEN)
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = opts_with(None);
        let req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = opts_with(make_cors_config());
        let req = make_request("GET", "");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = opts_with(make_cors_config());

        // Origin outside the allow list.
        let unknown_origin = make_request("GET", "https://example.info");

        // Preflight without `Access-Control-Request-Method`.
        let missing_method = make_request("OPTIONS", "https://example.com");

        // Preflight requesting a method that is not allowed.
        let mut bad_method = make_request("OPTIONS", "https://example.com");
        bad_method
            .headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());

        // Preflight requesting a header that is not allowed.
        let mut bad_header = make_request("OPTIONS", "https://example.com");
        bad_header
            .headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        bad_header.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );

        for req in [&unknown_origin, &missing_method, &bad_method, &bad_header] {
            assert!(is_403(pre_process(&opts, req)));
        }
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = opts_with(make_cors_config());
        let expected = Some(String::from("https://example.com"));

        let plain = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &plain).is_none());
        let resp = post_process(&opts, &plain, make_response())?;
        assert_eq!(get_allowed_origin(resp), expected);

        let mut with_request_headers = make_request("GET", "https://example.com");
        with_request_headers
            .headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        with_request_headers.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &with_request_headers).is_none());
        let resp = post_process(&opts, &with_request_headers, make_response())?;
        assert_eq!(get_allowed_origin(resp), expected);

        Ok(())
    }
}