static_web_server/
cors.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
6//! CORS module to handle incoming requests.
7//!
8
9// Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs
10
11use headers::{
12    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13    HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{Error, error_page, handler::RequestHandlerOpts};
20
21/// It defines CORS instance.
22#[derive(Clone, Debug)]
23pub struct Cors {
24    allowed_headers: HashSet<HeaderName>,
25    exposed_headers: HashSet<HeaderName>,
26    max_age: Option<u64>,
27    allowed_methods: HashSet<http::Method>,
28    origins: Option<HashSet<HeaderValue>>,
29}
30
31/// It builds a new CORS instance.
32pub fn new(
33    origins_str: &str,
34    allow_headers_str: &str,
35    expose_headers_str: &str,
36) -> Option<Configured> {
37    let cors = Cors::new();
38    let cors = if origins_str.is_empty() {
39        None
40    } else {
41        let [allow_headers_vec, expose_headers_vec] =
42            [allow_headers_str, expose_headers_str].map(|s| {
43                if s.is_empty() {
44                    vec!["origin", "content-type"]
45                } else {
46                    s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47                }
48            });
49        let [allow_headers_str, expose_headers_str] =
50            [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52        let cors_res = if origins_str == "*" {
53            Some(
54                cors.allow_any_origin()
55                    .allow_headers(allow_headers_vec)
56                    .expose_headers(expose_headers_vec)
57                    .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58            )
59        } else {
60            let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61            if hosts.is_empty() {
62                None
63            } else {
64                Some(
65                    cors.allow_origins(hosts)
66                        .allow_headers(allow_headers_vec)
67                        .expose_headers(expose_headers_vec)
68                        .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69                )
70            }
71        };
72
73        if cors_res.is_some() {
74            tracing::info!(
75                "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76                origins_str,
77                allow_headers_str,
78                expose_headers_str,
79            );
80        }
81        cors_res
82    };
83
84    Cors::build(cors)
85}
86
87impl Cors {
88    /// Creates a new Cors instance.
89    pub fn new() -> Self {
90        Self {
91            origins: None,
92            allowed_headers: HashSet::new(),
93            exposed_headers: HashSet::new(),
94            allowed_methods: HashSet::new(),
95            max_age: None,
96        }
97    }
98
99    /// Adds multiple methods to the existing list of allowed request methods.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the provided argument is not a valid `http::Method`.
104    pub fn allow_methods<I>(mut self, methods: I) -> Self
105    where
106        I: IntoIterator,
107        http::Method: TryFrom<I::Item>,
108    {
109        let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110            Ok(m) => m,
111            Err(_) => panic!("cors: illegal method"),
112        });
113        self.allowed_methods.extend(iter);
114        self
115    }
116
117    /// Sets that *any* `Origin` header is allowed.
118    ///
119    /// # Warning
120    ///
121    /// This can allow websites you didn't intend to access this resource,
122    /// it is usually better to set an explicit list.
123    pub fn allow_any_origin(mut self) -> Self {
124        self.origins = None;
125        self
126    }
127
128    /// Add multiple origins to the existing list of allowed `Origin`s.
129    ///
130    /// # Panics
131    ///
132    /// Panics if the provided argument is not a valid `Origin`.
133    pub fn allow_origins<I>(mut self, origins: I) -> Self
134    where
135        I: IntoIterator,
136        I::Item: IntoOrigin,
137    {
138        let iter = origins
139            .into_iter()
140            .map(IntoOrigin::into_origin)
141            .map(|origin| {
142                origin
143                    .to_string()
144                    .parse()
145                    .expect("cors: Origin is always a valid HeaderValue")
146            });
147
148        self.origins.get_or_insert_with(HashSet::new).extend(iter);
149        self
150    }
151
152    /// Adds multiple headers to the list of allowed request headers.
153    ///
154    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
155    ///
156    /// # Panics
157    ///
158    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
159    pub fn allow_headers<I>(mut self, headers: I) -> Self
160    where
161        I: IntoIterator,
162        HeaderName: TryFrom<I::Item>,
163    {
164        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165            Ok(h) => h,
166            Err(_) => panic!("cors: illegal Header"),
167        });
168        self.allowed_headers.extend(iter);
169        self
170    }
171
172    /// Adds multiple headers to the list of exposed request headers.
173    ///
174    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
175    ///
176    /// # Panics
177    ///
178    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
179    pub fn expose_headers<I>(mut self, headers: I) -> Self
180    where
181        I: IntoIterator,
182        HeaderName: TryFrom<I::Item>,
183    {
184        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185            Ok(h) => h,
186            Err(_) => panic!("cors: illegal Header"),
187        });
188        self.exposed_headers.extend(iter);
189        self
190    }
191
192    /// Builds the `Cors` wrapper from the configured settings.
193    pub fn build(cors: Option<Cors>) -> Option<Configured> {
194        cors.as_ref()?;
195        let cors = cors?;
196
197        let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198        let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199        let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201        Some(Configured {
202            cors,
203            allowed_headers,
204            exposed_headers,
205            methods_header,
206        })
207    }
208}
209
210impl Default for Cors {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216#[derive(Clone, Debug)]
217/// CORS configured.
218pub struct Configured {
219    cors: Cors,
220    allowed_headers: AccessControlAllowHeaders,
221    exposed_headers: AccessControlExposeHeaders,
222    methods_header: AccessControlAllowMethods,
223}
224
225#[derive(Debug)]
226/// Validated CORS request.
227pub enum Validated {
228    /// Validated as preflight.
229    Preflight(HeaderValue),
230    /// Validated as simple.
231    Simple(HeaderValue),
232    /// Validated as not cors.
233    NotCors,
234}
235
/// Reasons for which a CORS request is rejected.
#[derive(Default, Debug)]
pub enum Forbidden {
    /// The request origin is not allowed.
    #[default]
    Origin,
    /// The requested method is missing or not allowed.
    Method,
    /// A requested header is unreadable or not allowed.
    Header,
}
247
248impl Configured {
249    /// Check for the incoming CORS request.
250    pub fn check_request(
251        &self,
252        method: &http::Method,
253        headers: &HeaderMap,
254    ) -> Result<(HeaderMap, Validated), Forbidden> {
255        match (headers.get(header::ORIGIN), method) {
256            (Some(origin), &http::Method::OPTIONS) => {
257                // OPTIONS requests are preflight CORS requests...
258
259                if !self.is_origin_allowed(origin) {
260                    return Err(Forbidden::Origin);
261                }
262
263                if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
264                    if !self.is_method_allowed(req_method) {
265                        return Err(Forbidden::Method);
266                    }
267                } else {
268                    tracing::warn!(
269                        "cors: preflight request missing `access-control-request-method` header"
270                    );
271                    return Err(Forbidden::Method);
272                }
273
274                if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
275                    let headers = match req_headers.to_str() {
276                        Ok(val) => val,
277                        Err(err) => {
278                            tracing::error!(
279                                "cors: error parsing header `access-control-request-headers` value: {:?}",
280                                err,
281                            );
282                            return Err(Forbidden::Header);
283                        }
284                    };
285
286                    for header in headers.split(',') {
287                        let h = header.trim();
288                        if !self.is_header_allowed(h) {
289                            tracing::error!(
290                                "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option",
291                                h
292                            );
293                            return Err(Forbidden::Header);
294                        }
295                    }
296                }
297
298                let mut headers = HeaderMap::new();
299                self.append_preflight_headers(&mut headers);
300                headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
301
302                Ok((headers, Validated::Preflight(origin.clone())))
303            }
304            (Some(origin), _) => {
305                // Any other method, simply check for a valid origin...
306                tracing::trace!("cors origin header: {:?}", origin);
307
308                if self.is_origin_allowed(origin) {
309                    let mut headers = HeaderMap::new();
310                    self.append_preflight_headers(&mut headers);
311                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
312
313                    Ok((headers, Validated::Simple(origin.clone())))
314                } else {
315                    Err(Forbidden::Origin)
316                }
317            }
318            _ => {
319                // No `ORIGIN` header means this isn't CORS!
320                Ok((HeaderMap::new(), Validated::NotCors))
321            }
322        }
323    }
324
325    fn is_method_allowed(&self, header: &HeaderValue) -> bool {
326        http::Method::from_bytes(header.as_bytes())
327            .map(|method| self.cors.allowed_methods.contains(&method))
328            .unwrap_or(false)
329    }
330
331    fn is_header_allowed(&self, header: &str) -> bool {
332        if header.is_empty() {
333            return false;
334        }
335        HeaderName::from_bytes(header.as_bytes())
336            .map(|header| self.cors.allowed_headers.contains(&header))
337            .unwrap_or(false)
338    }
339
340    fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
341        if origin.is_empty() {
342            return false;
343        }
344        if let Some(ref allowed) = self.cors.origins {
345            allowed.contains(origin)
346        } else {
347            true
348        }
349    }
350
351    fn append_preflight_headers(&self, headers: &mut HeaderMap) {
352        headers.typed_insert(self.allowed_headers.clone());
353        headers.typed_insert(self.exposed_headers.clone());
354        headers.typed_insert(self.methods_header.clone());
355
356        if let Some(max_age) = self.cors.max_age {
357            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
358        }
359    }
360}
361
362/// Cast values into the origin header.
363pub trait IntoOrigin {
364    /// Cast actual value into an origin header.
365    fn into_origin(self) -> Origin;
366}
367
368impl IntoOrigin for &str {
369    fn into_origin(self) -> Origin {
370        let mut parts = self.splitn(2, "://");
371        let scheme = parts.next().expect("cors::into_origin: missing url scheme");
372        let rest = parts.next().expect("cors::into_origin: missing url scheme");
373
374        Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
375    }
376}
377
378/// Initializes CORS settings
379pub(crate) fn init(
380    cors_allow_origins: &str,
381    cors_allow_headers: &str,
382    cors_expose_headers: &str,
383    handler_opts: &mut RequestHandlerOpts,
384) {
385    handler_opts.cors = new(
386        cors_allow_origins.trim(),
387        cors_allow_headers.trim(),
388        cors_expose_headers.trim(),
389    );
390}
391
392/// Rejects requests with wrong CORS headers
393pub(crate) fn pre_process<T>(
394    opts: &RequestHandlerOpts,
395    req: &Request<T>,
396) -> Option<Result<Response<Body>, Error>> {
397    let cors = opts.cors.as_ref()?;
398    match cors.check_request(req.method(), req.headers()) {
399        Ok((_, state)) => {
400            tracing::debug!("cors state: {:?}", state);
401            None
402        }
403        Err(err) => {
404            tracing::error!("cors error kind: {:?}", err);
405            Some(error_page::error_response(
406                req.uri(),
407                req.method(),
408                &StatusCode::FORBIDDEN,
409                &opts.page404,
410                &opts.page50x,
411            ))
412        }
413    }
414}
415
416/// Adds CORS headers to response
417pub(crate) fn post_process<T>(
418    opts: &RequestHandlerOpts,
419    req: &Request<T>,
420    mut resp: Response<Body>,
421) -> Result<Response<Body>, Error> {
422    if let Some(cors) = opts.cors.as_ref() {
423        if let Ok((headers, _)) = cors.check_request(req.method(), req.headers()) {
424            if !headers.is_empty() {
425                for (k, v) in headers.iter() {
426                    resp.headers_mut().insert(k, v.to_owned());
427                }
428                resp.headers_mut().insert(
429                    hyper::header::VARY,
430                    HeaderValue::from_name(hyper::header::ORIGIN),
431                );
432                resp.headers_mut().remove(http::header::ALLOW);
433            }
434        }
435    }
436    Ok(resp)
437}
438
#[cfg(test)]
mod tests {
    use super::{Configured, Cors, post_process, pre_process};
    use crate::{Error, handler::RequestHandlerOpts};
    use hyper::{Body, Request, Response, StatusCode};

    /// Builds a request for `/` with the given method; the `Origin` header is
    /// only set when `origin` is not empty.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let mut builder = Request::builder();
        if !origin.is_empty() {
            builder = builder.header("Origin", origin);
        }
        builder.method(method).uri("/").body(Body::empty()).unwrap()
    }

    /// Builds an empty response without any headers.
    fn make_response() -> Response<Body> {
        Response::builder().body(Body::empty()).unwrap()
    }

    /// CORS settings allowing one origin, one request header and `GET`/`HEAD`.
    fn make_cors_config() -> Option<Configured> {
        Cors::build(Some(
            Cors::new()
                .allow_origins(vec!["https://example.com/"])
                .allow_headers(vec!["X-Allowed"])
                .allow_methods(vec!["GET", "HEAD"]),
        ))
    }

    /// Extracts the `Access-Control-Allow-Origin` value of a response, if any.
    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        resp.headers()
            .get("Access-Control-Allow-Origin")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_owned())
    }

    /// Tells whether the pre-processing result is a `403 Forbidden` response.
    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        if let Some(Ok(response)) = result {
            response.status() == StatusCode::FORBIDDEN
        } else {
            false
        }
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: None,
            ..Default::default()
        };
        let req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };
        let req = make_request("GET", "");

        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        // Origin that is not part of the allowed list.
        assert!(is_403(pre_process(
            &opts,
            &make_request("GET", "https://example.info")
        )));
        // Preflight without the `Access-Control-Request-Method` header.
        assert!(is_403(pre_process(
            &opts,
            &make_request("OPTIONS", "https://example.com")
        )));

        // Preflight asking for a method that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());
        assert!(is_403(pre_process(&opts, &req)));

        // Preflight asking for a header that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );
        assert!(is_403(pre_process(&opts, &req)));
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = RequestHandlerOpts {
            cors: make_cors_config(),
            ..Default::default()
        };

        let req = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        // Preflight-style headers on a non-`OPTIONS` request only trigger the origin check.
        let mut req = make_request("GET", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        // A real preflight request asking for an allowed method and header.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }
}