static_web_server/
cors.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
// This file is part of Static Web Server.
// See https://static-web-server.net/ for more information
// Copyright (C) 2019-present Jose Quintana <joseluisq.net>

//! CORS module to handle incoming requests.
//!

// Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs

11use headers::{
12    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13    HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{error_page, handler::RequestHandlerOpts, Error};
20
21/// It defines CORS instance.
22#[derive(Clone, Debug)]
23pub struct Cors {
24    allowed_headers: HashSet<HeaderName>,
25    exposed_headers: HashSet<HeaderName>,
26    max_age: Option<u64>,
27    allowed_methods: HashSet<http::Method>,
28    origins: Option<HashSet<HeaderValue>>,
29}
30
31/// It builds a new CORS instance.
32pub fn new(
33    origins_str: &str,
34    allow_headers_str: &str,
35    expose_headers_str: &str,
36) -> Option<Configured> {
37    let cors = Cors::new();
38    let cors = if origins_str.is_empty() {
39        None
40    } else {
41        let [allow_headers_vec, expose_headers_vec] =
42            [allow_headers_str, expose_headers_str].map(|s| {
43                if s.is_empty() {
44                    vec!["origin", "content-type"]
45                } else {
46                    s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47                }
48            });
49        let [allow_headers_str, expose_headers_str] =
50            [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52        let cors_res = if origins_str == "*" {
53            Some(
54                cors.allow_any_origin()
55                    .allow_headers(allow_headers_vec)
56                    .expose_headers(expose_headers_vec)
57                    .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58            )
59        } else {
60            let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61            if hosts.is_empty() {
62                None
63            } else {
64                Some(
65                    cors.allow_origins(hosts)
66                        .allow_headers(allow_headers_vec)
67                        .expose_headers(expose_headers_vec)
68                        .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69                )
70            }
71        };
72
73        if cors_res.is_some() {
74            tracing::info!(
75                    "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76                    origins_str,
77                    allow_headers_str,
78                    expose_headers_str,
79                );
80        }
81        cors_res
82    };
83
84    Cors::build(cors)
85}
86
87impl Cors {
88    /// Creates a new Cors instance.
89    pub fn new() -> Self {
90        Self {
91            origins: None,
92            allowed_headers: HashSet::new(),
93            exposed_headers: HashSet::new(),
94            allowed_methods: HashSet::new(),
95            max_age: None,
96        }
97    }
98
99    /// Adds multiple methods to the existing list of allowed request methods.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the provided argument is not a valid `http::Method`.
104    pub fn allow_methods<I>(mut self, methods: I) -> Self
105    where
106        I: IntoIterator,
107        http::Method: TryFrom<I::Item>,
108    {
109        let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110            Ok(m) => m,
111            Err(_) => panic!("cors: illegal method"),
112        });
113        self.allowed_methods.extend(iter);
114        self
115    }
116
117    /// Sets that *any* `Origin` header is allowed.
118    ///
119    /// # Warning
120    ///
121    /// This can allow websites you didn't intend to access this resource,
122    /// it is usually better to set an explicit list.
123    pub fn allow_any_origin(mut self) -> Self {
124        self.origins = None;
125        self
126    }
127
128    /// Add multiple origins to the existing list of allowed `Origin`s.
129    ///
130    /// # Panics
131    ///
132    /// Panics if the provided argument is not a valid `Origin`.
133    pub fn allow_origins<I>(mut self, origins: I) -> Self
134    where
135        I: IntoIterator,
136        I::Item: IntoOrigin,
137    {
138        let iter = origins
139            .into_iter()
140            .map(IntoOrigin::into_origin)
141            .map(|origin| {
142                origin
143                    .to_string()
144                    .parse()
145                    .expect("cors: Origin is always a valid HeaderValue")
146            });
147
148        self.origins.get_or_insert_with(HashSet::new).extend(iter);
149        self
150    }
151
152    /// Adds multiple headers to the list of allowed request headers.
153    ///
154    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
155    ///
156    /// # Panics
157    ///
158    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
159    pub fn allow_headers<I>(mut self, headers: I) -> Self
160    where
161        I: IntoIterator,
162        HeaderName: TryFrom<I::Item>,
163    {
164        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165            Ok(h) => h,
166            Err(_) => panic!("cors: illegal Header"),
167        });
168        self.allowed_headers.extend(iter);
169        self
170    }
171
172    /// Adds multiple headers to the list of exposed request headers.
173    ///
174    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
175    ///
176    /// # Panics
177    ///
178    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
179    pub fn expose_headers<I>(mut self, headers: I) -> Self
180    where
181        I: IntoIterator,
182        HeaderName: TryFrom<I::Item>,
183    {
184        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185            Ok(h) => h,
186            Err(_) => panic!("cors: illegal Header"),
187        });
188        self.exposed_headers.extend(iter);
189        self
190    }
191
192    /// Builds the `Cors` wrapper from the configured settings.
193    pub fn build(cors: Option<Cors>) -> Option<Configured> {
194        cors.as_ref()?;
195        let cors = cors?;
196
197        let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198        let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199        let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201        Some(Configured {
202            cors,
203            allowed_headers,
204            exposed_headers,
205            methods_header,
206        })
207    }
208}
209
210impl Default for Cors {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216#[derive(Clone, Debug)]
217/// CORS configured.
218pub struct Configured {
219    cors: Cors,
220    allowed_headers: AccessControlAllowHeaders,
221    exposed_headers: AccessControlExposeHeaders,
222    methods_header: AccessControlAllowMethods,
223}
224
225#[derive(Debug)]
226/// Validated CORS request.
227pub enum Validated {
228    /// Validated as preflight.
229    Preflight(HeaderValue),
230    /// Validated as simple.
231    Simple(HeaderValue),
232    /// Validated as not cors.
233    NotCors,
234}
235
/// Reasons a CORS request is rejected.
#[derive(Debug)]
pub enum Forbidden {
    /// The request `Origin` is not allowed.
    Origin,
    /// The preflight `Access-Control-Request-Method` is missing or not allowed.
    Method,
    /// A preflight `Access-Control-Request-Headers` value is unreadable or lists a
    /// header that is not allowed.
    Header,
}
246
247impl Default for Forbidden {
248    fn default() -> Self {
249        Self::Origin
250    }
251}
252
253impl Configured {
254    /// Check for the incoming CORS request.
255    pub fn check_request(
256        &self,
257        method: &http::Method,
258        headers: &HeaderMap,
259    ) -> Result<(HeaderMap, Validated), Forbidden> {
260        match (headers.get(header::ORIGIN), method) {
261            (Some(origin), &http::Method::OPTIONS) => {
262                // OPTIONS requests are preflight CORS requests...
263
264                if !self.is_origin_allowed(origin) {
265                    return Err(Forbidden::Origin);
266                }
267
268                if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
269                    if !self.is_method_allowed(req_method) {
270                        return Err(Forbidden::Method);
271                    }
272                } else {
273                    tracing::warn!(
274                        "cors: preflight request missing `access-control-request-method` header"
275                    );
276                    return Err(Forbidden::Method);
277                }
278
279                if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
280                    let headers = match req_headers.to_str() {
281                        Ok(val) => val,
282                        Err(err) => {
283                            tracing::error!(
284                                "cors: error parsing header `access-control-request-headers` value: {:?}",
285                                err,
286                            );
287                            return Err(Forbidden::Header);
288                        }
289                    };
290
291                    for header in headers.split(',') {
292                        let h = header.trim();
293                        if !self.is_header_allowed(h) {
294                            tracing::error!(
295                                "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option", h
296                            );
297                            return Err(Forbidden::Header);
298                        }
299                    }
300                }
301
302                let mut headers = HeaderMap::new();
303                self.append_preflight_headers(&mut headers);
304                headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
305
306                Ok((headers, Validated::Preflight(origin.clone())))
307            }
308            (Some(origin), _) => {
309                // Any other method, simply check for a valid origin...
310                tracing::trace!("cors origin header: {:?}", origin);
311
312                if self.is_origin_allowed(origin) {
313                    let mut headers = HeaderMap::new();
314                    self.append_preflight_headers(&mut headers);
315                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
316
317                    Ok((headers, Validated::Simple(origin.clone())))
318                } else {
319                    Err(Forbidden::Origin)
320                }
321            }
322            _ => {
323                // No `ORIGIN` header means this isn't CORS!
324                Ok((HeaderMap::new(), Validated::NotCors))
325            }
326        }
327    }
328
329    fn is_method_allowed(&self, header: &HeaderValue) -> bool {
330        http::Method::from_bytes(header.as_bytes())
331            .map(|method| self.cors.allowed_methods.contains(&method))
332            .unwrap_or(false)
333    }
334
335    fn is_header_allowed(&self, header: &str) -> bool {
336        if header.is_empty() {
337            return false;
338        }
339        HeaderName::from_bytes(header.as_bytes())
340            .map(|header| self.cors.allowed_headers.contains(&header))
341            .unwrap_or(false)
342    }
343
344    fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
345        if origin.is_empty() {
346            return false;
347        }
348        if let Some(ref allowed) = self.cors.origins {
349            allowed.contains(origin)
350        } else {
351            true
352        }
353    }
354
355    fn append_preflight_headers(&self, headers: &mut HeaderMap) {
356        headers.typed_insert(self.allowed_headers.clone());
357        headers.typed_insert(self.exposed_headers.clone());
358        headers.typed_insert(self.methods_header.clone());
359
360        if let Some(max_age) = self.cors.max_age {
361            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
362        }
363    }
364}
365
366/// Cast values into the origin header.
367pub trait IntoOrigin {
368    /// Cast actual value into an origin header.
369    fn into_origin(self) -> Origin;
370}
371
372impl IntoOrigin for &str {
373    fn into_origin(self) -> Origin {
374        let mut parts = self.splitn(2, "://");
375        let scheme = parts.next().expect("cors::into_origin: missing url scheme");
376        let rest = parts.next().expect("cors::into_origin: missing url scheme");
377
378        Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
379    }
380}
381
382/// Initializes CORS settings
383pub(crate) fn init(
384    cors_allow_origins: &str,
385    cors_allow_headers: &str,
386    cors_expose_headers: &str,
387    handler_opts: &mut RequestHandlerOpts,
388) {
389    handler_opts.cors = new(
390        cors_allow_origins.trim(),
391        cors_allow_headers.trim(),
392        cors_expose_headers.trim(),
393    );
394}
395
396/// Rejects requests with wrong CORS headers
397pub(crate) fn pre_process<T>(
398    opts: &RequestHandlerOpts,
399    req: &Request<T>,
400) -> Option<Result<Response<Body>, Error>> {
401    let cors = opts.cors.as_ref()?;
402    match cors.check_request(req.method(), req.headers()) {
403        Ok((_, state)) => {
404            tracing::debug!("cors state: {:?}", state);
405            None
406        }
407        Err(err) => {
408            tracing::error!("cors error kind: {:?}", err);
409            Some(error_page::error_response(
410                req.uri(),
411                req.method(),
412                &StatusCode::FORBIDDEN,
413                &opts.page404,
414                &opts.page50x,
415            ))
416        }
417    }
418}
419
420/// Adds CORS headers to response
421pub(crate) fn post_process<T>(
422    opts: &RequestHandlerOpts,
423    req: &Request<T>,
424    mut resp: Response<Body>,
425) -> Result<Response<Body>, Error> {
426    if let Some(cors) = opts.cors.as_ref() {
427        if let Ok((headers, _)) = cors.check_request(req.method(), req.headers()) {
428            if !headers.is_empty() {
429                for (k, v) in headers.iter() {
430                    resp.headers_mut().insert(k, v.to_owned());
431                }
432                resp.headers_mut().insert(
433                    hyper::header::VARY,
434                    HeaderValue::from_name(hyper::header::ORIGIN),
435                );
436                resp.headers_mut().remove(http::header::ALLOW);
437            }
438        }
439    }
440    Ok(resp)
441}
442
#[cfg(test)]
mod tests {
    use super::{post_process, pre_process, Configured, Cors};
    use crate::{handler::RequestHandlerOpts, Error};
    use hyper::{Body, Request, Response, StatusCode};

    /// Builds a request with the given method and, when non-empty, an `Origin` header.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let mut builder = Request::builder();
        if !origin.is_empty() {
            builder = builder.header("Origin", origin);
        }
        builder.method(method).uri("/").body(Body::empty()).unwrap()
    }

    fn make_response() -> Response<Body> {
        Response::builder().body(Body::empty()).unwrap()
    }

    /// Policy used by the tests: one origin, one extra header, `GET`/`HEAD` only.
    fn make_cors_config() -> Option<Configured> {
        Cors::build(Some(
            Cors::new()
                .allow_origins(vec!["https://example.com/"])
                .allow_headers(vec!["X-Allowed"])
                .allow_methods(vec!["GET", "HEAD"]),
        ))
    }

    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        resp.headers()
            .get("Access-Control-Allow-Origin")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_owned())
    }

    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        if let Some(Ok(response)) = result {
            response.status() == StatusCode::FORBIDDEN
        } else {
            false
        }
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: None,
            ..Default::default()
        };
        let req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };
        let req = make_request("GET", "");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        assert!(is_403(pre_process(
            &opts,
            &make_request("GET", "https://example.info")
        )));
        assert!(is_403(pre_process(
            &opts,
            &make_request("OPTIONS", "https://example.com")
        )));

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());
        assert!(is_403(pre_process(&opts, &req)));

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );
        assert!(is_403(pre_process(&opts, &req)));
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        let req = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        let mut req = make_request("GET", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }

    /// A well-formed preflight (allowed origin, method and headers) must pass and
    /// receive the CORS response headers.
    #[test]
    fn test_allowed_preflight_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert!(resp.headers().contains_key("Access-Control-Allow-Methods"));
        assert!(resp.headers().contains_key("Access-Control-Allow-Headers"));
        assert!(resp
            .headers()
            .get_all("Vary")
            .iter()
            .any(|v| v == "origin"));
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }
}