1use std::{collections::HashSet, time::Duration};
2
3use regex::RegexSet;
4use tower_http::cors::CorsLayer;
5
6#[derive(Debug, Clone)]
7#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
8#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
9pub enum AllowedOrigins {
10 Any,
11 Mirror,
12 #[cfg_attr(feature = "serde", serde(untagged))]
13 List(SerdeRegexSet),
14}
15
16impl From<AllowedOrigins> for tower_http::cors::AllowOrigin {
17 fn from(value: AllowedOrigins) -> Self {
18 use tower_http::cors::AllowOrigin;
19 match value {
20 AllowedOrigins::Any => AllowOrigin::any(),
21 AllowedOrigins::Mirror => AllowOrigin::mirror_request(),
22 AllowedOrigins::List(origins) => AllowOrigin::predicate(move |origin, _parts| {
23 origin.to_str().is_ok_and(|origin| origins.is_match(origin))
24 }),
25 }
26 }
27}
28
29#[derive(Debug, Clone, PartialEq, Eq)]
30#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
31#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
32pub enum AllowedHeaders {
33 Any,
34 Mirror,
35 #[cfg_attr(feature = "serde", serde(untagged))]
36 List(
37 #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] HashSet<http::HeaderName>,
38 ),
39}
40
#[cfg(feature = "serde")]
mod serde_header_name {
    //! Serde helpers that (de)serialize a `HashSet<HeaderName>` as a sequence of strings.

    use std::collections::HashSet;

    use http::HeaderName;
    use serde::{de, ser::SerializeSeq, Deserialize, Deserializer, Serializer};

    /// Writes every header name as a string element of a sequence.
    ///
    /// Set iteration order is not stable, so neither is the element order.
    pub fn serialize<S>(value: &HashSet<HeaderName>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(value.len()))?;
        for header in value {
            // Borrow the name directly instead of allocating a `String` per element.
            seq.serialize_element(header.as_str())?;
        }
        seq.end()
    }

    /// Reads a sequence of strings into a set of header names.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error for the first string that
    /// `HeaderName::try_from` rejects.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<HeaderName>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let values: Vec<String> = Deserialize::deserialize(deserializer)?;
        values
            .into_iter()
            .map(HeaderName::try_from)
            .collect::<Result<HashSet<_>, _>>()
            .map_err(de::Error::custom)
    }
}
71
72impl From<AllowedHeaders> for tower_http::cors::AllowHeaders {
73 fn from(value: AllowedHeaders) -> Self {
74 use tower_http::cors::AllowHeaders;
75 match value {
76 AllowedHeaders::Any => AllowHeaders::any(),
77 AllowedHeaders::Mirror => AllowHeaders::mirror_request(),
78 AllowedHeaders::List(allowed_headers) => AllowHeaders::list(allowed_headers),
79 }
80 }
81}
82
83#[derive(Debug, Clone, PartialEq, Eq)]
84#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
85#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
86pub enum AllowedMethods {
87 Mirror,
89 #[cfg_attr(feature = "serde", serde(untagged))]
91 List(#[cfg_attr(feature = "serde", serde(with = "serde_method"))] HashSet<http::Method>),
92}
93
94impl From<AllowedMethods> for tower_http::cors::AllowMethods {
95 fn from(value: AllowedMethods) -> Self {
96 use tower_http::cors::AllowMethods;
97 match value {
98 AllowedMethods::Mirror => AllowMethods::mirror_request(),
99 AllowedMethods::List(methods) => AllowMethods::list(methods),
100 }
101 }
102}
#[cfg(feature = "serde")]
mod serde_method {
    //! Serde helpers that (de)serialize a `HashSet<Method>` as a sequence of method strings.

    use std::collections::HashSet;

    use http::Method;
    use serde::{de, ser::SerializeSeq, Deserialize, Deserializer, Serializer};

    /// Writes every method by name as a string element of a sequence.
    ///
    /// Set iteration order is not stable, so neither is the element order.
    pub fn serialize<S>(value: &HashSet<Method>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(value.len()))?;
        for method in value {
            // Borrow the name directly instead of allocating a `String` per element.
            seq.serialize_element(method.as_str())?;
        }
        seq.end()
    }

    /// Reads a sequence of strings into a set of methods.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error for the first string that `http` rejects
    /// as a method.
    ///
    /// NOTE(review): `http` parses method names case-sensitively, so `get` is not
    /// `GET` — confirm whether configuration should normalize case.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<Method>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let values: Vec<String> = Deserialize::deserialize(deserializer)?;
        values
            .iter()
            .map(|value| value.parse::<Method>())
            .collect::<Result<HashSet<_>, _>>()
            .map_err(de::Error::custom)
    }
}
133
134#[derive(Debug, Clone, PartialEq, Eq)]
135#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
136#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
137pub enum ExposeHeaders {
138 Any,
140 #[cfg_attr(feature = "serde", serde(untagged))]
142 List(
143 #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] HashSet<http::HeaderName>,
144 ),
145}
146
147impl From<ExposeHeaders> for tower_http::cors::ExposeHeaders {
148 fn from(value: ExposeHeaders) -> Self {
149 match value {
150 ExposeHeaders::Any => tower_http::cors::ExposeHeaders::any(),
151 ExposeHeaders::List(headers) => tower_http::cors::ExposeHeaders::list(headers),
152 }
153 }
154}
155
156#[derive(Debug, Clone)]
158#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
159pub struct SerdeRegexSet(
160 #[cfg_attr(feature = "serde", serde(with = "serde_regex_set"))] pub RegexSet,
161);
162
163impl std::ops::Deref for SerdeRegexSet {
164 type Target = RegexSet;
165 fn deref(&self) -> &Self::Target {
166 &self.0
167 }
168}
169
#[cfg(feature = "serde")]
mod serde_regex_set {
    //! Serde helpers that (de)serialize a [`RegexSet`] as a sequence of pattern strings.

    use regex::RegexSet;
    use serde::{de, ser::SerializeSeq, Deserialize, Deserializer, Serializer};

    /// Writes the set's source patterns, in set order, as a sequence of strings.
    pub fn serialize<S>(value: &RegexSet, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut sequence = serializer.serialize_seq(Some(value.len()))?;
        for pattern in value.patterns() {
            sequence.serialize_element(pattern)?;
        }
        sequence.end()
    }

    /// Compiles a sequence of pattern strings into a [`RegexSet`].
    ///
    /// Patterns are compiled in input order. Previously they went through a
    /// `HashSet`, which scrambled the order, so the indices reported by
    /// `RegexSet::matches` and the re-serialized order did not follow the
    /// configuration. Duplicate patterns are kept; they do not affect `is_match`.
    ///
    /// # Errors
    ///
    /// Returns a deserialization error if any pattern fails to compile.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<RegexSet, D::Error>
    where
        D: Deserializer<'de>,
    {
        let patterns: Vec<String> = Deserialize::deserialize(deserializer)?;
        RegexSet::new(patterns).map_err(de::Error::custom)
    }
}
195
196#[derive(Debug, Clone, Default, PartialEq, Eq)]
197#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
198pub struct Vary(
199 #[cfg_attr(feature = "serde", serde(with = "serde_header_name"))] pub HashSet<http::HeaderName>,
200);
201
202impl From<Vary> for tower_http::cors::Vary {
203 fn from(value: Vary) -> Self {
204 tower_http::cors::Vary::list(value.0)
205 }
206}
207
208#[derive(Debug, Clone)]
209#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
210#[cfg_attr(feature = "serde", serde(rename_all = "kebab-case"))]
211pub struct Config {
212 #[cfg_attr(feature = "serde", serde(default))]
214 pub allow_credentials: bool,
215 pub allowed_headers: AllowedHeaders,
218 pub allowed_methods: AllowedMethods,
220 pub allowed_origins: AllowedOrigins,
222 #[cfg_attr(feature = "serde", serde(default))]
224 pub allow_private_network: bool,
225 #[cfg_attr(
227 feature = "serde",
228 serde(
229 with = "humantime_serde",
230 default,
231 skip_serializing_if = "Option::is_none"
232 )
233 )]
234 pub max_age: Option<Duration>,
235 pub expose_headers: ExposeHeaders,
238 #[cfg_attr(feature = "serde", serde(default))]
240 pub vary: Vary,
241}
242
243impl From<Config> for CorsLayer {
244 fn from(config: Config) -> Self {
245 let mut layer = CorsLayer::new()
246 .allow_credentials(config.allow_credentials)
247 .allow_headers(config.allowed_headers)
248 .allow_methods(config.allowed_methods)
249 .allow_origin(config.allowed_origins)
250 .allow_private_network(config.allow_private_network)
251 .expose_headers(config.expose_headers)
252 .vary(config.vary);
253
254 if let Some(max_age) = config.max_age {
255 layer = layer.max_age(max_age);
256 }
257
258 layer
259 }
260}
261
#[cfg(all(feature = "serde", test))]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// A full `Config` survives a YAML serialize/deserialize roundtrip.
    #[test]
    fn test_roundtrip() {
        let config = Config {
            allow_credentials: true,
            allowed_headers: AllowedHeaders::List(HashSet::from([
                http::header::CONNECTION,
                http::header::AUTHORIZATION,
            ])),
            allowed_methods: AllowedMethods::Mirror,
            allowed_origins: AllowedOrigins::Any,
            allow_private_network: true,
            max_age: Some(Duration::from_secs(3600)),
            expose_headers: ExposeHeaders::Any,
            vary: Vary(HashSet::from([http::HeaderName::from_static("origin")])),
        };
        let serialized = serde_yml::to_string(&config).unwrap();
        let deserialized: Config = serde_yml::from_str(&serialized).unwrap();
        assert_eq!(config.allow_credentials, deserialized.allow_credentials);
        assert_eq!(config.allowed_headers, deserialized.allowed_headers);
        assert_eq!(config.allowed_methods, deserialized.allowed_methods);
        // `AllowedOrigins` has no `PartialEq` (`RegexSet` lacks it), so check the variant.
        assert!(matches!(deserialized.allowed_origins, AllowedOrigins::Any));
        assert_eq!(
            config.allow_private_network,
            deserialized.allow_private_network
        );
        assert_eq!(config.max_age, deserialized.max_age);
        assert_eq!(config.expose_headers, deserialized.expose_headers);
        assert_eq!(config.vary, deserialized.vary);
    }

    /// A list of origin regexes roundtrips and still matches the same origins.
    #[test]
    fn test_origin_list_roundtrip() {
        let patterns = [r"^https://example\.com$", r"^https://[a-z]+\.example\.org$"];
        let origins = AllowedOrigins::List(SerdeRegexSet(RegexSet::new(patterns).unwrap()));
        let serialized = serde_yml::to_string(&origins).unwrap();
        let deserialized: AllowedOrigins = serde_yml::from_str(&serialized).unwrap();
        match deserialized {
            AllowedOrigins::List(set) => {
                // Compare sorted so the check does not depend on pattern order.
                let mut got = set.patterns().to_vec();
                got.sort();
                let mut want: Vec<String> = patterns.iter().map(|p| p.to_string()).collect();
                want.sort();
                assert_eq!(want, got);
                assert!(set.is_match("https://example.com"));
                assert!(set.is_match("https://api.example.org"));
                assert!(!set.is_match("https://example.net"));
            }
            other => panic!("expected AllowedOrigins::List, got {other:?}"),
        }
    }

    /// An explicit method list roundtrips through the untagged `List` variant.
    #[test]
    fn test_method_list_roundtrip() {
        let methods =
            AllowedMethods::List(HashSet::from([http::Method::GET, http::Method::POST]));
        let serialized = serde_yml::to_string(&methods).unwrap();
        let deserialized: AllowedMethods = serde_yml::from_str(&serialized).unwrap();
        assert_eq!(methods, deserialized);
    }
}