tower_http_cors_config/
lib.rs

1use std::{collections::HashSet, time::Duration};
2
3use regex::RegexSet;
4use tower_http::cors::CorsLayer;
5
6#[derive(Debug, Clone)]
7#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
8#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
9pub enum AllowedOrigins {
10    Any,
11    Mirror,
12    #[cfg_attr(feature = "serde", serde(untagged))]
13    List(SerdeRegexSet),
14}
15
16impl From<AllowedOrigins> for tower_http::cors::AllowOrigin {
17    fn from(value: AllowedOrigins) -> Self {
18        use tower_http::cors::AllowOrigin;
19        match value {
20            AllowedOrigins::Any => AllowOrigin::any(),
21            AllowedOrigins::Mirror => AllowOrigin::mirror_request(),
22            AllowedOrigins::List(origins) => AllowOrigin::predicate(move |origin, _parts| {
23                origin.to_str().is_ok_and(|origin| origins.is_match(origin))
24            }),
25        }
26    }
27}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
30#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
31#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
32pub enum AllowedHeaders {
33    Any,
34    Mirror,
35    #[cfg_attr(feature = "serde", serde(untagged))]
36    List(
37        #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] HashSet<http::HeaderName>,
38    ),
39}
40
/// Serde helpers, used through `#[serde(with = "serde_header_name")]`, that
/// represent a set of header names as a sequence of strings.
#[cfg(feature = "serde")]
mod serde_header_name {
    use std::collections::HashSet;

    use http::HeaderName;
    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes the header names as a sequence of strings.
    ///
    /// `HashSet` iteration order is unspecified, so the names are sorted first to
    /// make the output stable from one run to the next. The names are borrowed
    /// rather than copied into temporary `String`s.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying serializer reports.
    pub fn serialize<S>(value: &HashSet<HeaderName>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut names: Vec<&str> = value.iter().map(HeaderName::as_str).collect();
        names.sort_unstable();
        serializer.collect_seq(names)
    }

    /// Deserializes a sequence of strings into a set of header names.
    ///
    /// # Errors
    ///
    /// Fails when the input is not a sequence of strings, or when one of the
    /// strings is rejected by `HeaderName::try_from`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<HeaderName>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let values: Vec<String> = Deserialize::deserialize(deserializer)?;
        values
            .into_iter()
            .map(HeaderName::try_from)
            .collect::<Result<HashSet<_>, _>>()
            .map_err(de::Error::custom)
    }
}
71
72impl From<AllowedHeaders> for tower_http::cors::AllowHeaders {
73    fn from(value: AllowedHeaders) -> Self {
74        use tower_http::cors::AllowHeaders;
75        match value {
76            AllowedHeaders::Any => AllowHeaders::any(),
77            AllowedHeaders::Mirror => AllowHeaders::mirror_request(),
78            AllowedHeaders::List(allowed_headers) => AllowHeaders::list(allowed_headers),
79        }
80    }
81}
82
83#[derive(Debug, Clone, PartialEq, Eq)]
84#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
85#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
86pub enum AllowedMethods {
87    /// Mirror the request method
88    Mirror,
89    /// Allow a specific list of methods
90    #[cfg_attr(feature = "serde", serde(untagged))]
91    List(#[cfg_attr(feature = "serde", serde(with = "serde_method"))] HashSet<http::Method>),
92}
93
94impl From<AllowedMethods> for tower_http::cors::AllowMethods {
95    fn from(value: AllowedMethods) -> Self {
96        use tower_http::cors::AllowMethods;
97        match value {
98            AllowedMethods::Mirror => AllowMethods::mirror_request(),
99            AllowedMethods::List(methods) => AllowMethods::list(methods),
100        }
101    }
102}
/// Serde helpers, used through `#[serde(with = "serde_method")]`, that represent a
/// set of HTTP methods as a sequence of strings.
#[cfg(feature = "serde")]
mod serde_method {
    use std::collections::HashSet;

    use http::Method;
    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes the methods as a sequence of strings.
    ///
    /// `HashSet` iteration order is unspecified, so the names are sorted first to
    /// make the output stable from one run to the next. The names are borrowed
    /// rather than copied into temporary `String`s.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying serializer reports.
    pub fn serialize<S>(value: &HashSet<Method>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut methods: Vec<&str> = value.iter().map(Method::as_str).collect();
        methods.sort_unstable();
        serializer.collect_seq(methods)
    }

    /// Deserializes a sequence of strings into a set of HTTP methods.
    ///
    /// # Errors
    ///
    /// Fails when the input is not a sequence of strings, or when one of the
    /// strings cannot be parsed as a `Method`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<Method>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let values: Vec<String> = Deserialize::deserialize(deserializer)?;
        values
            .into_iter()
            .map(|raw| raw.parse::<Method>().map_err(de::Error::custom))
            .collect()
    }
}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
136#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
137pub enum ExposeHeaders {
138    /// Expose all headers by responding with `*`
139    Any,
140    /// Only expose a specific list of headers
141    #[cfg_attr(feature = "serde", serde(untagged))]
142    List(
143        #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] HashSet<http::HeaderName>,
144    ),
145}
146
147impl From<ExposeHeaders> for tower_http::cors::ExposeHeaders {
148    fn from(value: ExposeHeaders) -> Self {
149        match value {
150            ExposeHeaders::Any => tower_http::cors::ExposeHeaders::any(),
151            ExposeHeaders::List(headers) => tower_http::cors::ExposeHeaders::list(headers),
152        }
153    }
154}
155
156/// A wrapper around `RegexSet` that is serializable with serde
157#[derive(Debug, Clone)]
158#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
159pub struct SerdeRegexSet(
160    #[cfg_attr(feature = "serde", serde(with = "serde_regex_set"))] pub RegexSet,
161);
162
163impl std::ops::Deref for SerdeRegexSet {
164    type Target = RegexSet;
165    fn deref(&self) -> &Self::Target {
166        &self.0
167    }
168}
169
/// Serde helpers, used through `#[serde(with = "serde_regex_set")]`, that
/// represent a `RegexSet` as the sequence of its pattern strings.
#[cfg(feature = "serde")]
mod serde_regex_set {
    use regex::RegexSet;
    use serde::{de, Deserialize, Deserializer, Serializer};

    /// Serializes the set as the sequence of its patterns, in the set's own order.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying serializer reports.
    pub fn serialize<S>(value: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.collect_seq(value.patterns())
    }

    /// Deserializes a sequence of pattern strings into a `RegexSet`.
    ///
    /// The patterns are collected into a `Vec` so the set keeps the order in which
    /// they were written. Going through a `HashSet` would scramble that order and
    /// make both the pattern indices of the set and any re-serialized output
    /// differ from one run to the next.
    ///
    /// # Errors
    ///
    /// Fails when the input is not a sequence of strings, or when `RegexSet::new`
    /// rejects one of the patterns.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
    where
        D: Deserializer<'de>,
    {
        let patterns: Vec<String> = Deserialize::deserialize(deserializer)?;
        RegexSet::new(patterns).map_err(de::Error::custom)
    }
}
195
196#[derive(Debug, Clone, Default, PartialEq, Eq)]
197#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
198pub struct Vary(
199    #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] pub HashSet<http::HeaderName>,
200);
201
202impl From<Vary> for tower_http::cors::Vary {
203    fn from(value: Vary) -> Self {
204        tower_http::cors::Vary::list(value.0)
205    }
206}
207
208#[derive(Debug, Clone)]
209#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
210#[cfg_attr(feature = "serde", serde(rename_all = "kebab-case"))]
211pub struct Config {
212    /// Whether to allow credentials in CORS requests
213    #[cfg_attr(feature = "serde", serde(default))]
214    pub allow_credentials: bool,
215    /// Which request headers can be sent in the actual request.
216    /// Controls the [`Access-Control-Allow-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers) response header.
217    pub allowed_headers: AllowedHeaders,
218    /// Controls how to set the [`Access-Control-Allow-Methods`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods) response header.
219    pub allowed_methods: AllowedMethods,
220    /// Controls how to set the [`Access-Control-Allow-Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin) response header.
221    pub allowed_origins: AllowedOrigins,
222    /// If true, include the [`Access-Control-Allow-Private-Network`](https://wicg.github.io/private-network-access/) response header.
223    #[cfg_attr(feature = "serde", serde(default))]
224    pub allow_private_network: bool,
225    /// The maximum age of the CORS request in seconds
226    #[cfg_attr(
227        feature = "serde",
228        serde(
229            with = "humantime_serde",
230            default,
231            skip_serializing_if = "Option::is_none"
232        )
233    )]
234    pub max_age: Option<Duration>,
235    /// Which headers are exposed to the client.
236    /// Controls the [`Access-Control-Expose-Headers`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Headers) response header.
237    pub expose_headers: ExposeHeaders,
238    /// Which headers to set in the Vary response header
239    #[cfg_attr(feature = "serde", serde(default))]
240    pub vary: Vary,
241}
242
243impl From<Config> for CorsLayer {
244    fn from(config: Config) -> Self {
245        let mut layer = CorsLayer::new()
246            .allow_credentials(config.allow_credentials)
247            .allow_headers(config.allowed_headers)
248            .allow_methods(config.allowed_methods)
249            .allow_origin(config.allowed_origins)
250            .allow_private_network(config.allow_private_network)
251            .expose_headers(config.expose_headers)
252            .vary(config.vary);
253
254        if let Some(max_age) = config.max_age {
255            layer = layer.max_age(max_age);
256        }
257
258        layer
259    }
260}
261
#[cfg(all(feature = "serde", test))]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Serializes a fully populated `Config` to YAML, reads it back, and checks
    /// that every field survives the round trip.
    #[test]
    fn test_roundtrip() {
        let config = Config {
            allow_credentials: true,
            allowed_headers: AllowedHeaders::List(HashSet::from([
                http::header::CONNECTION,
                http::header::AUTHORIZATION,
            ])),
            allowed_methods: AllowedMethods::Mirror,
            allowed_origins: AllowedOrigins::Any,
            allow_private_network: true,
            max_age: Some(Duration::from_secs(3600)),
            expose_headers: ExposeHeaders::Any,
            vary: Vary(HashSet::from([http::HeaderName::from_static("origin")])),
        };
        let serialized = serde_yml::to_string(&config).unwrap();
        let deserialized: Config = serde_yml::from_str(&serialized).unwrap();

        assert_eq!(config.allow_credentials, deserialized.allow_credentials);
        assert_eq!(config.allowed_headers, deserialized.allowed_headers);
        assert_eq!(config.allowed_methods, deserialized.allowed_methods);
        // regex::RegexSet does not implement PartialEq, so `AllowedOrigins` cannot
        // be compared with `assert_eq!`; check the variant instead.
        assert!(matches!(deserialized.allowed_origins, AllowedOrigins::Any));
        assert_eq!(
            config.allow_private_network,
            deserialized.allow_private_network
        );
        assert_eq!(config.max_age, deserialized.max_age);
        assert_eq!(config.expose_headers, deserialized.expose_headers);
        assert_eq!(config.vary, deserialized.vary);
    }
}