Skip to main content

static_web_server/
cors.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! CORS (Cross-Origin Resource Sharing) module: validates incoming
//! cross-origin requests and adds the matching CORS headers to responses.
8
9// Part of the file is borrowed from https://github.com/seanmonstar/warp/blob/master/src/filters/cors.rs
10
11use headers::{
12    AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMap,
13    HeaderMapExt, HeaderName, HeaderValue, Origin,
14};
15use http::header;
16use hyper::{Body, Request, Response, StatusCode};
17use std::collections::HashSet;
18
19use crate::{Error, error_page, handler::RequestHandlerOpts};
20
21/// It defines CORS instance.
22#[derive(Clone, Debug)]
23pub struct Cors {
24    allowed_headers: HashSet<HeaderName>,
25    exposed_headers: HashSet<HeaderName>,
26    max_age: Option<u64>,
27    allowed_methods: HashSet<http::Method>,
28    origins: Option<HashSet<HeaderValue>>,
29}
30
31/// It builds a new CORS instance.
32pub fn new(
33    origins_str: &str,
34    allow_headers_str: &str,
35    expose_headers_str: &str,
36) -> Option<Configured> {
37    let cors = Cors::new();
38    let cors = if origins_str.is_empty() {
39        None
40    } else {
41        let [allow_headers_vec, expose_headers_vec] =
42            [allow_headers_str, expose_headers_str].map(|s| {
43                if s.is_empty() {
44                    vec!["origin", "content-type"]
45                } else {
46                    s.split(',').map(|s| s.trim()).collect::<Vec<_>>()
47                }
48            });
49        let [allow_headers_str, expose_headers_str] =
50            [&allow_headers_vec, &expose_headers_vec].map(|v| v.join(","));
51
52        let cors_res = if origins_str == "*" {
53            Some(
54                cors.allow_any_origin()
55                    .allow_headers(allow_headers_vec)
56                    .expose_headers(expose_headers_vec)
57                    .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
58            )
59        } else {
60            let hosts = origins_str.split(',').map(|s| s.trim()).collect::<Vec<_>>();
61            if hosts.is_empty() {
62                None
63            } else {
64                Some(
65                    cors.allow_origins(hosts)
66                        .allow_headers(allow_headers_vec)
67                        .expose_headers(expose_headers_vec)
68                        .allow_methods(vec!["GET", "HEAD", "OPTIONS"]),
69                )
70            }
71        };
72
73        if cors_res.is_some() {
74            tracing::info!(
75                "cors enabled=true, allow_methods=[GET,HEAD,OPTIONS], allow_origins={}, allow_headers=[{}], expose_headers=[{}]",
76                origins_str,
77                allow_headers_str,
78                expose_headers_str,
79            );
80        }
81        cors_res
82    };
83
84    Cors::build(cors)
85}
86
87impl Cors {
88    /// Creates a new Cors instance.
89    pub fn new() -> Self {
90        Self {
91            origins: None,
92            allowed_headers: HashSet::new(),
93            exposed_headers: HashSet::new(),
94            allowed_methods: HashSet::new(),
95            max_age: None,
96        }
97    }
98
99    /// Adds multiple methods to the existing list of allowed request methods.
100    ///
101    /// # Panics
102    ///
103    /// Panics if the provided argument is not a valid `http::Method`.
104    pub fn allow_methods<I>(mut self, methods: I) -> Self
105    where
106        I: IntoIterator,
107        http::Method: TryFrom<I::Item>,
108    {
109        let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
110            Ok(m) => m,
111            Err(_) => panic!("cors: illegal method"),
112        });
113        self.allowed_methods.extend(iter);
114        self
115    }
116
117    /// Sets that *any* `Origin` header is allowed.
118    ///
119    /// # Warning
120    ///
121    /// This can allow websites you didn't intend to access this resource,
122    /// it is usually better to set an explicit list.
123    pub fn allow_any_origin(mut self) -> Self {
124        self.origins = None;
125        self
126    }
127
128    /// Add multiple origins to the existing list of allowed `Origin`s.
129    ///
130    /// # Panics
131    ///
132    /// Panics if the provided argument is not a valid `Origin`.
133    pub fn allow_origins<I>(mut self, origins: I) -> Self
134    where
135        I: IntoIterator,
136        I::Item: IntoOrigin,
137    {
138        let iter = origins
139            .into_iter()
140            .map(IntoOrigin::into_origin)
141            .map(|origin| {
142                origin
143                    .to_string()
144                    .parse()
145                    .expect("cors: Origin is always a valid HeaderValue")
146            });
147
148        self.origins.get_or_insert_with(HashSet::new).extend(iter);
149        self
150    }
151
152    /// Adds multiple headers to the list of allowed request headers.
153    ///
154    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
155    ///
156    /// # Panics
157    ///
158    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
159    pub fn allow_headers<I>(mut self, headers: I) -> Self
160    where
161        I: IntoIterator,
162        HeaderName: TryFrom<I::Item>,
163    {
164        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
165            Ok(h) => h,
166            Err(_) => panic!("cors: illegal Header"),
167        });
168        self.allowed_headers.extend(iter);
169        self
170    }
171
172    /// Adds multiple headers to the list of exposed request headers.
173    ///
174    /// **Note**: These should match the values the browser sends via `Access-Control-Request-Headers`, e.g.`content-type`.
175    ///
176    /// # Panics
177    ///
178    /// Panics if any of the headers are not a valid `http::header::HeaderName`.
179    pub fn expose_headers<I>(mut self, headers: I) -> Self
180    where
181        I: IntoIterator,
182        HeaderName: TryFrom<I::Item>,
183    {
184        let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
185            Ok(h) => h,
186            Err(_) => panic!("cors: illegal Header"),
187        });
188        self.exposed_headers.extend(iter);
189        self
190    }
191
192    /// Builds the `Cors` wrapper from the configured settings.
193    pub fn build(cors: Option<Cors>) -> Option<Configured> {
194        cors.as_ref()?;
195        let cors = cors?;
196
197        let allowed_headers = cors.allowed_headers.iter().cloned().collect();
198        let exposed_headers = cors.exposed_headers.iter().cloned().collect();
199        let methods_header = cors.allowed_methods.iter().cloned().collect();
200
201        Some(Configured {
202            cors,
203            allowed_headers,
204            exposed_headers,
205            methods_header,
206        })
207    }
208}
209
210impl Default for Cors {
211    fn default() -> Self {
212        Self::new()
213    }
214}
215
216#[derive(Clone, Debug)]
217/// CORS configured.
218pub struct Configured {
219    cors: Cors,
220    allowed_headers: AccessControlAllowHeaders,
221    exposed_headers: AccessControlExposeHeaders,
222    methods_header: AccessControlAllowMethods,
223}
224
225#[derive(Debug)]
226/// Validated CORS request.
227pub enum Validated {
228    /// Validated as preflight.
229    Preflight(HeaderValue),
230    /// Validated as simple.
231    Simple(HeaderValue),
232    /// Validated as not cors.
233    NotCors,
234}
235
/// Reason a CORS request was rejected by [`Configured::check_request`].
#[derive(Debug, Default)]
pub enum Forbidden {
    /// The `Origin` header is empty or not in the allowed list.
    #[default]
    Origin,
    /// The preflight's `Access-Control-Request-Method` is missing or not allowed.
    Method,
    /// The preflight's `Access-Control-Request-Headers` is unreadable or names
    /// a header that is not allowed.
    Header,
}
247
248impl Configured {
249    /// Check for the incoming CORS request.
250    pub fn check_request(
251        &self,
252        method: &http::Method,
253        headers: &HeaderMap,
254    ) -> Result<(HeaderMap, Validated), Forbidden> {
255        match (headers.get(header::ORIGIN), method) {
256            (Some(origin), &http::Method::OPTIONS) => {
257                // OPTIONS requests are preflight CORS requests...
258
259                if !self.is_origin_allowed(origin) {
260                    return Err(Forbidden::Origin);
261                }
262
263                if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
264                    if !self.is_method_allowed(req_method) {
265                        return Err(Forbidden::Method);
266                    }
267                } else {
268                    tracing::warn!(
269                        "cors: preflight request missing `access-control-request-method` header"
270                    );
271                    return Err(Forbidden::Method);
272                }
273
274                if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
275                    let headers = match req_headers.to_str() {
276                        Ok(val) => val,
277                        Err(err) => {
278                            tracing::error!(
279                                "cors: error parsing header `access-control-request-headers` value: {:?}",
280                                err,
281                            );
282                            return Err(Forbidden::Header);
283                        }
284                    };
285
286                    for header in headers.split(',') {
287                        let h = header.trim();
288                        if !self.is_header_allowed(h) {
289                            tracing::error!(
290                                "cors: header `{}` is not allowed because is missing in `cors_allow_headers` server option",
291                                h
292                            );
293                            return Err(Forbidden::Header);
294                        }
295                    }
296                }
297
298                let mut headers = HeaderMap::new();
299                self.append_preflight_headers(&mut headers);
300                headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
301
302                Ok((headers, Validated::Preflight(origin.clone())))
303            }
304            (Some(origin), _) => {
305                // Any other method, simply check for a valid origin...
306                tracing::trace!("cors origin header: {:?}", origin);
307
308                if self.is_origin_allowed(origin) {
309                    let mut headers = HeaderMap::new();
310                    self.append_preflight_headers(&mut headers);
311                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.into());
312
313                    Ok((headers, Validated::Simple(origin.clone())))
314                } else {
315                    Err(Forbidden::Origin)
316                }
317            }
318            _ => {
319                // No `ORIGIN` header means this isn't CORS!
320                Ok((HeaderMap::new(), Validated::NotCors))
321            }
322        }
323    }
324
325    fn is_method_allowed(&self, header: &HeaderValue) -> bool {
326        http::Method::from_bytes(header.as_bytes())
327            .map(|method| self.cors.allowed_methods.contains(&method))
328            .unwrap_or(false)
329    }
330
331    fn is_header_allowed(&self, header: &str) -> bool {
332        if header.is_empty() {
333            return false;
334        }
335        HeaderName::from_bytes(header.as_bytes())
336            .map(|header| self.cors.allowed_headers.contains(&header))
337            .unwrap_or(false)
338    }
339
340    fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
341        if origin.is_empty() {
342            return false;
343        }
344        if let Some(ref allowed) = self.cors.origins {
345            allowed.contains(origin)
346        } else {
347            true
348        }
349    }
350
351    fn append_preflight_headers(&self, headers: &mut HeaderMap) {
352        headers.typed_insert(self.allowed_headers.clone());
353        headers.typed_insert(self.exposed_headers.clone());
354        headers.typed_insert(self.methods_header.clone());
355
356        if let Some(max_age) = self.cors.max_age {
357            headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
358        }
359    }
360}
361
362/// Cast values into the origin header.
363pub trait IntoOrigin {
364    /// Cast actual value into an origin header.
365    fn into_origin(self) -> Origin;
366}
367
368impl IntoOrigin for &str {
369    fn into_origin(self) -> Origin {
370        let mut parts = self.splitn(2, "://");
371        let scheme = parts.next().expect("cors::into_origin: missing url scheme");
372        let rest = parts.next().expect("cors::into_origin: missing url scheme");
373
374        Origin::try_from_parts(scheme, rest, None).expect("cors::into_origin: invalid Origin")
375    }
376}
377
378/// Initializes CORS settings
379pub(crate) fn init(
380    cors_allow_origins: &str,
381    cors_allow_headers: &str,
382    cors_expose_headers: &str,
383    handler_opts: &mut RequestHandlerOpts,
384) {
385    handler_opts.cors = new(
386        cors_allow_origins.trim(),
387        cors_allow_headers.trim(),
388        cors_expose_headers.trim(),
389    );
390}
391
392/// Cached CORS headers stored in request extensions to avoid
393/// re-validating the request in `post_process`.
394#[derive(Clone)]
395pub(crate) struct CorsHeaders(pub(crate) HeaderMap);
396
397/// Rejects requests with wrong CORS headers
398pub(crate) fn pre_process<T>(
399    opts: &RequestHandlerOpts,
400    req: &mut Request<T>,
401) -> Option<Result<Response<Body>, Error>> {
402    let cors = opts.cors.as_ref()?;
403    match cors.check_request(req.method(), req.headers()) {
404        Ok((headers, state)) => {
405            tracing::debug!("cors state: {:?}", state);
406            // Stash validated headers for post_process to reuse
407            if !headers.is_empty() {
408                req.extensions_mut().insert(CorsHeaders(headers));
409            }
410            None
411        }
412        Err(err) => {
413            tracing::error!("cors error kind: {:?}", err);
414            Some(error_page::error_response(
415                req.uri(),
416                req.method(),
417                &StatusCode::FORBIDDEN,
418                &opts.page404,
419                &opts.page50x,
420            ))
421        }
422    }
423}
424
425/// Adds CORS headers to response
426pub(crate) fn post_process<T>(
427    opts: &RequestHandlerOpts,
428    req: &Request<T>,
429    mut resp: Response<Body>,
430) -> Result<Response<Body>, Error> {
431    if opts.cors.is_some()
432        && let Some(cors_headers) = req.extensions().get::<CorsHeaders>()
433    {
434        for (k, v) in cors_headers.0.iter() {
435            resp.headers_mut().insert(k, v.to_owned());
436        }
437        resp.headers_mut().insert(
438            hyper::header::VARY,
439            HeaderValue::from_name(hyper::header::ORIGIN),
440        );
441        resp.headers_mut().remove(http::header::ALLOW);
442    }
443    Ok(resp)
444}
445
#[cfg(test)]
mod tests {
    use super::{Configured, Cors, post_process, pre_process};
    use crate::{Error, handler::RequestHandlerOpts};
    use hyper::{Body, Request, Response, StatusCode};

    /// Builds a `/` request with `method` and, unless `origin` is empty, an `Origin` header.
    fn make_request(method: &str, origin: &str) -> Request<Body> {
        let builder = Request::builder().method(method).uri("/");
        let builder = if origin.is_empty() {
            builder
        } else {
            builder.header("Origin", origin)
        };
        builder.body(Body::empty()).unwrap()
    }

    /// A plain empty response for `post_process` to decorate.
    fn make_response() -> Response<Body> {
        Response::new(Body::empty())
    }

    /// CORS config allowing `https://example.com`, header `X-Allowed`, and GET/HEAD.
    fn make_cors_config() -> Option<Configured> {
        let cors = Cors::new()
            .allow_origins(vec!["https://example.com/"])
            .allow_headers(vec!["X-Allowed"])
            .allow_methods(vec!["GET", "HEAD"]);
        Cors::build(Some(cors))
    }

    /// Handler options with the given CORS configuration and defaults otherwise.
    fn opts_with(cors: Option<Configured>) -> RequestHandlerOpts {
        RequestHandlerOpts {
            cors,
            ..Default::default()
        }
    }

    /// Extracts `Access-Control-Allow-Origin` as an owned string, if present.
    fn get_allowed_origin(resp: Response<Body>) -> Option<String> {
        let value = resp.headers().get("Access-Control-Allow-Origin")?;
        value.to_str().ok().map(str::to_owned)
    }

    /// True when `pre_process` produced a `403 Forbidden` response.
    fn is_403(result: Option<Result<Response<Body>, Error>>) -> bool {
        matches!(result, Some(Ok(ref response)) if response.status() == StatusCode::FORBIDDEN)
    }

    #[test]
    fn test_cors_disabled() -> Result<(), Error> {
        let opts = opts_with(None);
        let mut req = make_request("GET", "https://example.com/");

        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_non_cors_request() -> Result<(), Error> {
        let opts = opts_with(make_cors_config());
        let mut req = make_request("GET", "");

        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), None);

        Ok(())
    }

    #[test]
    fn test_forbidden_request() {
        let opts = opts_with(make_cors_config());

        // Origin not in the allowed list.
        let mut req = make_request("GET", "https://example.info");
        assert!(is_403(pre_process(&opts, &mut req)));

        // Preflight without `Access-Control-Request-Method`.
        let mut req = make_request("OPTIONS", "https://example.com");
        assert!(is_403(pre_process(&opts, &mut req)));

        // Preflight asking for a method that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "POST".try_into().unwrap());
        assert!(is_403(pre_process(&opts, &mut req)));

        // Preflight asking for a header that is not allowed.
        let mut req = make_request("OPTIONS", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Forbidden".try_into().unwrap(),
        );
        assert!(is_403(pre_process(&opts, &mut req)));
    }

    #[test]
    fn test_allowed_request() -> Result<(), Error> {
        let opts = opts_with(make_cors_config());

        let mut req = make_request("GET", "https://example.com");
        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        let mut req = make_request("GET", "https://example.com");
        req.headers_mut()
            .insert("Access-Control-Request-Method", "GET".try_into().unwrap());
        req.headers_mut().insert(
            "Access-Control-Request-Headers",
            "X-Allowed".try_into().unwrap(),
        );
        assert!(pre_process(&opts, &mut req).is_none());

        let resp = post_process(&opts, &req, make_response())?;
        assert_eq!(get_allowed_origin(resp), Some("https://example.com".into()));

        Ok(())
    }
}