1use headers::{
12 AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13 HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{error_page, handler::RequestHandlerOpts, Error};
20
21#[derive(Clone, Debug)]
23pub struct Cors {
24 allowed_headers: HashSet<HeaderName>,
25 exposed_headers: HashSet<HeaderName>,
26 max_age: Option<u64>,
27 allowed_methods: HashSet<http::Method>,
28 origins: Option<HashSet<HeaderValue>>,
29}
30
31pub fn new(
33 origins_str: &str,
34 allow_headers_str: &str,
35 expose_headers_str: &str,
36) -> Option<Configured> {
37 let cors = Cors::new();
38 let cors = if origins_str.is_empty() {
39 None
40 } else {
41 let [allow_headers_vec, expose_headers_vec] =
42 [allow_headers_str, expose_headers_str].map(|s| {
43 if s.is_empty() {
44 vec!["origin", "content-type"]
45 } else {
46 s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47 }
48 });
49 let [allow_headers_str, expose_headers_str] =
50 [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52 let cors_res = if origins_str == "*" {
53 Some(
54 cors.allow_any_origin()
55 .allow_headers(allow_headers_vec)
56 .expose_headers(expose_headers_vec)
57 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58 )
59 } else {
60 let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61 if hosts.is_empty() {
62 None
63 } else {
64 Some(
65 cors.allow_origins(hosts)
66 .allow_headers(allow_headers_vec)
67 .expose_headers(expose_headers_vec)
68 .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69 )
70 }
71 };
72
73 if cors_res.is_some() {
74 tracing::info!(
75 "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76 origins_str,
77 allow_headers_str,
78 expose_headers_str,
79 );
80 }
81 cors_res
82 };
83
84 Cors::build(cors)
85}
86
87impl Cors {
88 pub fn new() -> Self {
90 Self {
91 origins: None,
92 allowed_headers: HashSet::new(),
93 exposed_headers: HashSet::new(),
94 allowed_methods: HashSet::new(),
95 max_age: None,
96 }
97 }
98
99 pub fn allow_methods<I>(mut self, methods: I) -> Self
105 where
106 I: IntoIterator,
107 http::Method: TryFrom<I::Item>,
108 {
109 let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110 Ok(m) => m,
111 Err(_) => panic!("cors: illegal method"),
112 });
113 self.allowed_methods.extend(iter);
114 self
115 }
116
117 pub fn allow_any_origin(mut self) -> Self {
124 self.origins = None;
125 self
126 }
127
128 pub fn allow_origins<I>(mut self, origins: I) -> Self
134 where
135 I: IntoIterator,
136 I::Item: IntoOrigin,
137 {
138 let iter = origins
139 .into_iter()
140 .map(IntoOrigin::into_origin)
141 .map(|origin| {
142 origin
143 .to_string()
144 .parse()
145 .expect("cors: Origin is always a valid HeaderValue")
146 });
147
148 self.origins.get_or_insert_with(HashSet::new).extend(iter);
149 self
150 }
151
152 pub fn allow_headers<I>(mut self, headers: I) -> Self
160 where
161 I: IntoIterator,
162 HeaderName: TryFrom<I::Item>,
163 {
164 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165 Ok(h) => h,
166 Err(_) => panic!("cors: illegal Header"),
167 });
168 self.allowed_headers.extend(iter);
169 self
170 }
171
172 pub fn expose_headers<I>(mut self, headers: I) -> Self
180 where
181 I: IntoIterator,
182 HeaderName: TryFrom<I::Item>,
183 {
184 let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185 Ok(h) => h,
186 Err(_) => panic!("cors: illegal Header"),
187 });
188 self.exposed_headers.extend(iter);
189 self
190 }
191
192 pub fn build(cors: Option<Cors>) -> Option<Configured> {
194 cors.as_ref()?;
195 let cors = cors?;
196
197 let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198 let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199 let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201 Some(Configured {
202 cors,
203 allowed_headers,
204 exposed_headers,
205 methods_header,
206 })
207 }
208}
209
210impl Default for Cors {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216#[derive(Clone, Debug)]
217pub struct Configured {
219 cors: Cors,
220 allowed_headers: AccessControlAllowHeaders,
221 exposed_headers: AccessControlExposeHeaders,
222 methods_header: AccessControlAllowMethods,
223}
224
225#[derive(Debug)]
226pub enum Validated {
228 Preflight(HeaderValue),
230 Simple(HeaderValue),
232 NotCors,
234}
235
/// Reason a request was rejected by the CORS check.
#[derive(Debug)]
pub enum Forbidden {
    /// The `Origin` header is empty or not in the allowed set.
    Origin,
    /// The preflight's `Access-Control-Request-Method` is missing or not allowed.
    Method,
    /// A header named in `Access-Control-Request-Headers` is unreadable or not allowed.
    Header,
}

impl Default for Forbidden {
    /// Defaults to [`Forbidden::Origin`].
    fn default() -> Forbidden {
        Forbidden::Origin
    }
}
252
253impl Configured {
254 pub fn check_request(
256 &self,
257 method: &http::Method,
258 headers: &HeaderMap,
259 ) -> Result<(HeaderMap, Validated), Forbidden> {
260 match (headers.get(header::ORIGIN), method) {
261 (Some(origin), &http::Method::OPTIONS) => {
262 if !self.is_origin_allowed(origin) {
265 return Err(Forbidden::Origin);
266 }
267
268 if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
269 if !self.is_method_allowed(req_method) {
270 return Err(Forbidden::Method);
271 }
272 } else {
273 tracing::warn!(
274 "cors: preflight request missing `access-control-request-method` header"
275 );
276 return Err(Forbidden::Method);
277 }
278
279 if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
280 let headers = match req_headers.to_str() {
281 Ok(val) => val,
282 Err(err) => {
283 tracing::error!(
284 "cors: error parsing header `access-control-request-headers` value: {:?}",
285 err,
286 );
287 return Err(Forbidden::Header);
288 }
289 };
290
291 for header in headers.split(',') {
292 let h = header.trim();
293 if !self.is_header_allowed(h) {
294 tracing::error!(
295 "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option", h
296 );
297 return Err(Forbidden::Header);
298 }
299 }
300 }
301
302 let mut headers = HeaderMap::new();
303 self.append_preflight_headers(&mut headers);
304 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
305
306 Ok((headers, Validated::Preflight(origin.clone())))
307 }
308 (Some(origin), _) => {
309 tracing::trace!("cors origin header: {:?}", origin);
311
312 if self.is_origin_allowed(origin) {
313 let mut headers = HeaderMap::new();
314 self.append_preflight_headers(&mut headers);
315 headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
316
317 Ok((headers, Validated::Simple(origin.clone())))
318 } else {
319 Err(Forbidden::Origin)
320 }
321 }
322 _ => {
323 Ok((HeaderMap::new(), Validated::NotCors))
325 }
326 }
327 }
328
329 fn is_method_allowed(&self, header: &HeaderValue) -> bool {
330 http::Method::from_bytes(header.as_bytes())
331 .map(|method| self.cors.allowed_methods.contains(&method))
332 .unwrap_or(false)
333 }
334
335 fn is_header_allowed(&self, header: &str) -> bool {
336 if header.is_empty() {
337 return false;
338 }
339 HeaderName::from_bytes(header.as_bytes())
340 .map(|header| self.cors.allowed_headers.contains(&header))
341 .unwrap_or(false)
342 }
343
344 fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
345 if origin.is_empty() {
346 return false;
347 }
348 if let Some(ref allowed) = self.cors.origins {
349 allowed.contains(origin)
350 } else {
351 true
352 }
353 }
354
355 fn append_preflight_headers(&self, headers: &mut HeaderMap) {
356 headers.typed_insert(self.allowed_headers.clone());
357 headers.typed_insert(self.exposed_headers.clone());
358 headers.typed_insert(self.methods_header.clone());
359
360 if let Some(max_age) = self.cors.max_age {
361 headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
362 }
363 }
364}
365
366pub trait IntoOrigin {
368 fn into_origin(self) -> Origin;
370}
371
372impl IntoOrigin for &str {
373 fn into_origin(self) -> Origin {
374 let mut parts = self.splitn(2, "://");
375 let scheme = parts.next().expect("cors::into_origin: missing url scheme");
376 let rest = parts.next().expect("cors::into_origin: missing url scheme");
377
378 Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
379 }
380}
381
382pub(crate) fn init(
384 cors_allow_origins: &str,
385 cors_allow_headers: &str,
386 cors_expose_headers: &str,
387 handler_opts: &mut RequestHandlerOpts,
388) {
389 handler_opts.cors = new(
390 cors_allow_origins.trim(),
391 cors_allow_headers.trim(),
392 cors_expose_headers.trim(),
393 );
394}
395
396pub(crate) fn pre_process<T>(
398 opts: &RequestHandlerOpts,
399 req: &Request<T>,
400) -> Option<Result<Response<Body>, Error>> {
401 let cors = opts.cors.as_ref()?;
402 match cors.check_request(req.method(), req.headers()) {
403 Ok((_, state)) => {
404 tracing::debug!("cors state: {:?}", state);
405 None
406 }
407 Err(err) => {
408 tracing::error!("cors error kind: {:?}", err);
409 Some(error_page::error_response(
410 req.uri(),
411 req.method(),
412 &StatusCode::FORBIDDEN,
413 &opts.page404,
414 &opts.page50x,
415 ))
416 }
417 }
418}
419
420pub(crate) fn post_process<T>(
422 opts: &RequestHandlerOpts,
423 req: &Request<T>,
424 mut resp: Response<Body>,
425) -> Result<Response<Body>, Error> {
426 if let Some(cors) = opts.cors.as_ref() {
427 if let Ok((headers, _)) = cors.check_request(req.method(), req.headers()) {
428 if !headers.is_empty() {
429 for (k, v) in headers.iter() {
430 resp.headers_mut().insert(k, v.to_owned());
431 }
432 resp.headers_mut().insert(
433 hyper::header::VARY,
434 HeaderValue::from_name(hyper::header::ORIGIN),
435 );
436 resp.headers_mut().remove(http::header::ALLOW);
437 }
438 }
439 }
440 Ok(resp)
441}
442
#[cfg(test)]
mod tests {
    use super::{post_process, pre_process, Configured, Cors};
    use crate::{handler::RequestHandlerOpts, Error};
    use hyper::{Body, Request, Response, StatusCode};

    /// Builds a request with the given method and, if non-empty, `Origin`.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let mut builder = Request::builder();
        if !origin.is_empty() {
            builder = builder.header("Origin", origin);
        }
        builder.method(method).uri("/").body(Body::empty()).unwrap()
    }

    fn make_response() -> Response<Body> {
        Response::builder().body(Body::empty()).unwrap()
    }

    /// Policy allowing `https://example.com`, header `X-Allowed`, `GET`/`HEAD`.
    fn make_cors_config() -> Option<Configured> {
        Cors::build(Some(
            Cors::new()
                .allow_origins(vec!["https://example.com/"])
                .allow_headers(vec!["X-Allowed"])
                .allow_methods(vec!["GET", "HEAD"]),
        ))
    }

    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        resp.headers()
            .get("Access-Control-Allow-Origin")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_owned())
    }

    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        if let Some(Ok(response)) = result {
            response.status() == StatusCode::FORBIDDEN
        } else {
            false
        }
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: None,
            ..Default::default()
        };
        let req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };
        let req = make_request("GET", "");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        assert!(is_403(pre_process(
            &opts,
            &make_request("GET", "https://example.info")
        )));
        assert!(is_403(pre_process(
            &opts,
            &make_request("OPTIONS", "https://example.com")
        )));

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());
        assert!(is_403(pre_process(&opts, &req)));

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );
        assert!(is_403(pre_process(&opts, &req)));
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        let req = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        let mut req = make_request("GET", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }

    /// A real preflight (`OPTIONS` + allowed method and header) must pass
    /// and receive `Access-Control-Allow-Origin`.
    #[test]
    fn test_allowed_preflight_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "HEAD".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "x-allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }

    /// The option-driven constructor: empty origins disable CORS, `*` accepts
    /// any origin.
    #[test]
    fn test_new_from_server_options() {
        assert!(super::new("", "", "").is_none());

        let opts = RequestHandlerOpts {
            cors: super::new("*", "", ""),
            ..Default::default()
        };
        assert!(opts.cors.is_some());
        assert!(pre_process(&opts, &make_request("GET", "https://any.example")).is_none());
    }
}