viz_core/middleware/
cors.rs

1//! CORS Middleware.
2
3use std::{collections::HashSet, fmt, sync::Arc};
4
5use crate::{
6    header::{
7        HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
8        ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_REQUEST_HEADERS,
9        ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
10    },
11    headers::{
12        AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders,
13        HeaderMapExt,
14    },
15    Handler, IntoResponse, Method, Request, RequestExt, Response, Result, StatusCode, Transform,
16};
17
18/// A configuration for [`CorsMiddleware`].
19pub struct Config {
20    max_age: usize,
21    credentials: bool,
22    allow_methods: HashSet<Method>,
23    allow_headers: HashSet<HeaderName>,
24    allow_origins: HashSet<HeaderValue>,
25    expose_headers: HashSet<HeaderName>,
26    origin_verify: Option<Arc<dyn Fn(&HeaderValue) -> bool + Send + Sync>>,
27}
28
29impl Config {
30    /// Create a new [`Config`] with default values.
31    #[must_use]
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Seconds a preflight request can be cached. [MDN]
37    ///
38    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
39    #[must_use]
40    pub const fn max_age(mut self, max_age: usize) -> Self {
41        self.max_age = max_age;
42        self
43    }
44
45    /// Whether to allow credentials. [MDN]
46    ///
47    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials
48    #[must_use]
49    pub const fn credentials(mut self, credentials: bool) -> Self {
50        self.credentials = credentials;
51        self
52    }
53
54    /// Allowed HTTP methods. [MDN]
55    ///
56    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods
57    #[must_use]
58    pub fn allow_methods<H>(mut self, allow_methods: H) -> Self
59    where
60        H: IntoIterator,
61        H::Item: TryInto<Method>,
62    {
63        self.allow_methods = allow_methods
64            .into_iter()
65            .map(TryInto::try_into)
66            .filter_map(Result::ok)
67            .collect();
68        self
69    }
70
71    /// Allowed HTTP headers. [MDN]
72    ///
73    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers
74    #[must_use]
75    pub fn allow_headers<H>(mut self, allow_headers: H) -> Self
76    where
77        H: IntoIterator,
78        H::Item: TryInto<HeaderName>,
79    {
80        self.allow_headers = allow_headers
81            .into_iter()
82            .map(TryInto::try_into)
83            .filter_map(Result::ok)
84            .collect();
85        self
86    }
87
88    /// Allowed origins. [MDN]
89    ///
90    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
91    #[must_use]
92    pub fn allow_origins<H>(mut self, allow_origins: H) -> Self
93    where
94        H: IntoIterator,
95        H::Item: TryInto<HeaderValue>,
96    {
97        self.allow_origins = allow_origins
98            .into_iter()
99            .map(TryInto::try_into)
100            .filter_map(Result::ok)
101            .collect();
102        self
103    }
104
105    /// Exposed HTTP headers. [MDN]
106    ///
107    /// [MDN]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers
108    #[must_use]
109    pub fn expose_headers<H>(mut self, expose_headers: H) -> Self
110    where
111        H: IntoIterator,
112        H::Item: TryInto<HeaderName>,
113    {
114        self.expose_headers = expose_headers
115            .into_iter()
116            .map(TryInto::try_into)
117            .filter_map(Result::ok)
118            .collect();
119        self
120    }
121
122    /// A function to verify the origin. If the function returns false, the request will be rejected.
123    #[must_use]
124    pub fn origin_verify(
125        mut self,
126        origin_verify: Option<Arc<dyn Fn(&HeaderValue) -> bool + Send + Sync>>,
127    ) -> Self {
128        self.origin_verify = origin_verify;
129        self
130    }
131}
132
133impl Default for Config {
134    fn default() -> Self {
135        Self {
136            max_age: 86400,
137            credentials: false,
138            allow_methods: HashSet::from([
139                Method::GET,
140                Method::POST,
141                Method::HEAD,
142                Method::PUT,
143                Method::DELETE,
144                Method::PATCH,
145            ]),
146            allow_origins: HashSet::from([HeaderValue::from_static("*")]),
147            allow_headers: HashSet::new(),
148            expose_headers: HashSet::new(),
149            origin_verify: None,
150        }
151    }
152}
153
154impl Clone for Config {
155    fn clone(&self) -> Self {
156        Self {
157            max_age: self.max_age,
158            credentials: self.credentials,
159            allow_methods: self.allow_methods.clone(),
160            allow_headers: self.allow_headers.clone(),
161            allow_origins: self.allow_origins.clone(),
162            expose_headers: self.expose_headers.clone(),
163            origin_verify: self.origin_verify.clone(),
164        }
165    }
166}
167
168impl fmt::Debug for Config {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        f.debug_struct("CorsConfig")
171            .field("max_age", &self.max_age)
172            .field("credentials", &self.credentials)
173            .field("allow_methods", &self.allow_methods)
174            .field("allow_headers", &self.allow_headers)
175            .field("allow_origins", &self.allow_origins)
176            .field("expose_headers", &self.expose_headers)
177            .finish_non_exhaustive()
178    }
179}
180
181impl<H> Transform<H> for Config {
182    type Output = CorsMiddleware<H>;
183
184    fn transform(&self, h: H) -> Self::Output {
185        CorsMiddleware {
186            h,
187            acam: self.allow_methods.clone().into_iter().collect(),
188            acah: self.allow_headers.clone().into_iter().collect(),
189            aceh: self.expose_headers.clone().into_iter().collect(),
190            config: self.clone(),
191        }
192    }
193}
194
/// CORS middleware.
///
/// Produced by the [`Transform`] implementation on [`Config`]. It wraps the inner handler
/// `h` and adds CORS handling around it.
#[derive(Debug, Clone)]
pub struct CorsMiddleware<H> {
    /// The wrapped inner handler.
    h: H,
    /// The configuration this middleware was built from.
    config: Config,
    /// Prepared `Access-Control-Allow-Methods` value, collected from `config.allow_methods`.
    acam: AccessControlAllowMethods,
    /// Prepared `Access-Control-Allow-Headers` value, collected from `config.allow_headers`.
    acah: AccessControlAllowHeaders,
    /// Prepared `Access-Control-Expose-Headers` value, collected from `config.expose_headers`.
    aceh: AccessControlExposeHeaders,
}
204
205#[crate::async_trait]
206impl<H, O> Handler<Request> for CorsMiddleware<H>
207where
208    H: Handler<Request, Output = Result<O>>,
209    O: IntoResponse,
210{
211    type Output = Result<Response>;
212
213    async fn call(&self, req: Request) -> Self::Output {
214        let Some(origin) = req.header(ORIGIN).filter(is_not_empty) else {
215            return self.h.call(req).await.map(IntoResponse::into_response);
216        };
217
218        if !self.config.allow_origins.contains(&origin)
219            || !self
220                .config
221                .origin_verify
222                .as_ref()
223                .map_or(true, |f| (f)(&origin))
224        {
225            return Err(StatusCode::FORBIDDEN.into_error());
226        }
227
228        let mut headers = HeaderMap::new();
229        let mut resp = if req.method() == Method::OPTIONS {
230            // Preflight request
231            if req
232                .header(ACCESS_CONTROL_REQUEST_METHOD)
233                .is_some_and(|method| {
234                    self.config.allow_methods.is_empty()
235                        || self.config.allow_methods.contains(&method)
236                })
237            {
238                headers.typed_insert(self.acam.clone());
239            } else {
240                return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error());
241            }
242
243            let (allow_headers, request_headers) = req
244                .header(ACCESS_CONTROL_REQUEST_HEADERS)
245                .map_or((true, None), |hs: HeaderValue| {
246                    (
247                        hs.to_str()
248                            .map(|hs| {
249                                hs.split(',')
250                                    .map(str::as_bytes)
251                                    .map(HeaderName::from_bytes)
252                                    .filter_map(Result::ok)
253                                    .any(|header| self.config.allow_headers.contains(&header))
254                            })
255                            .unwrap_or(false),
256                        Some(hs),
257                    )
258                });
259
260            if !allow_headers {
261                return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error());
262            }
263
264            if self.config.allow_headers.is_empty() {
265                headers.insert(
266                    ACCESS_CONTROL_ALLOW_HEADERS,
267                    request_headers.unwrap_or(HeaderValue::from_static("*")),
268                );
269            } else {
270                headers.typed_insert(self.acah.clone());
271            }
272
273            // 204 - no content
274            StatusCode::NO_CONTENT.into_response()
275        } else {
276            // Simple Request
277            if !self.config.expose_headers.is_empty() {
278                headers.typed_insert(self.aceh.clone());
279            }
280
281            self.h.call(req).await.map(IntoResponse::into_response)?
282        };
283
284        // https://github.com/rs/cors/issues/10
285        headers.insert(VARY, ORIGIN.into());
286        headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
287
288        if self.config.credentials {
289            headers.insert(
290                ACCESS_CONTROL_ALLOW_CREDENTIALS,
291                HeaderValue::from_static("true"),
292            );
293        }
294
295        resp.headers_mut().extend(headers);
296
297        Ok(resp)
298    }
299}
300
301fn is_not_empty(h: &HeaderValue) -> bool {
302    !h.is_empty()
303}