1use std::collections::HashMap;
46use std::sync::RwLock;
47
48use serde::{Deserialize, Serialize};
49
/// One CORS rule attached to a bucket.
///
/// A preflight request is accepted by this rule only when its origin, its
/// method and every requested header are each allowed; see
/// `CorsManager::match_preflight`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorsRule {
    /// Origins this rule applies to. Each entry is compared against the
    /// request origin with `matches_glob` (case-sensitive; `"*"` accepts any
    /// origin).
    pub allowed_origins: Vec<String>,
    /// HTTP methods this rule permits. An entry accepts the request when it is
    /// `"*"` or equals the request method exactly (case-sensitive).
    pub allowed_methods: Vec<String>,
    /// Request headers this rule permits. Each requested header must be
    /// accepted by at least one entry via `matches_glob_ci` (ASCII
    /// case-insensitive; `"*"` accepts any header).
    pub allowed_headers: Vec<String>,
    /// Headers to expose to the client. Not consulted by matching in this
    /// module — presumably emitted on the response by the caller; confirm.
    /// Empty when absent from serialized input.
    #[serde(default)]
    pub expose_headers: Vec<String>,
    /// Cache lifetime in seconds for the preflight result, if set. Not
    /// consulted by matching in this module — presumably emitted as
    /// `Access-Control-Max-Age` by the caller; confirm. `None` when absent
    /// from serialized input.
    #[serde(default)]
    pub max_age_seconds: Option<u32>,
    /// Optional identifier for the rule. `None` when absent from serialized
    /// input.
    #[serde(default)]
    pub id: Option<String>,
}
81
/// The complete CORS configuration for one bucket: an ordered list of rules.
///
/// Order is significant: `CorsManager::match_preflight` returns the first
/// rule that accepts the request.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Rules evaluated in order; the first matching rule wins.
    pub rules: Vec<CorsRule>,
}
88
/// Lock-free serialized form of `CorsManager`, produced by `to_json` and
/// consumed by `from_json`.
#[derive(Debug, Default, Serialize, Deserialize)]
struct CorsSnapshot {
    /// Bucket name -> that bucket's CORS configuration.
    by_bucket: HashMap<String, CorsConfig>,
}
94
/// In-memory store of per-bucket CORS configuration.
///
/// The map sits behind an `RwLock`. Every accessor goes through
/// `crate::lock_recovery`, so a lock poisoned by a panicking writer is
/// recovered instead of making later calls panic.
#[derive(Debug, Default)]
pub struct CorsManager {
    /// Bucket name -> that bucket's CORS configuration.
    by_bucket: RwLock<HashMap<String, CorsConfig>>,
}
104
105impl CorsManager {
106 #[must_use]
108 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn put(&self, bucket: &str, config: CorsConfig) {
115 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket")
116 .insert(bucket.to_owned(), config);
117 }
118
119 #[must_use]
122 pub fn get(&self, bucket: &str) -> Option<CorsConfig> {
123 crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
124 .get(bucket)
125 .cloned()
126 }
127
128 pub fn delete(&self, bucket: &str) {
130 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket").remove(bucket);
131 }
132
133 pub fn to_json(&self) -> Result<String, serde_json::Error> {
136 let snap = CorsSnapshot {
137 by_bucket: crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
138 .clone(),
139 };
140 serde_json::to_string(&snap)
141 }
142
143 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
146 let snap: CorsSnapshot = serde_json::from_str(s)?;
147 Ok(Self {
148 by_bucket: RwLock::new(snap.by_bucket),
149 })
150 }
151
152 #[must_use]
161 pub fn match_preflight(
162 &self,
163 bucket: &str,
164 origin: &str,
165 method: &str,
166 request_headers: &[String],
167 ) -> Option<CorsRule> {
168 let map = crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket");
169 let cfg = map.get(bucket)?;
170 for rule in &cfg.rules {
171 if !rule_matches_origin(rule, origin) {
172 continue;
173 }
174 if !rule_matches_method(rule, method) {
175 continue;
176 }
177 if !rule_matches_headers(rule, request_headers) {
178 continue;
179 }
180 return Some(rule.clone());
181 }
182 None
183 }
184}
185
186fn rule_matches_origin(rule: &CorsRule, origin: &str) -> bool {
187 rule.allowed_origins
188 .iter()
189 .any(|pat| matches_glob(pat, origin))
190}
191
192fn rule_matches_method(rule: &CorsRule, method: &str) -> bool {
193 rule.allowed_methods
196 .iter()
197 .any(|pat| pat == "*" || pat == method)
198}
199
200fn rule_matches_headers(rule: &CorsRule, request_headers: &[String]) -> bool {
201 if request_headers.is_empty() {
202 return true;
203 }
204 request_headers.iter().all(|h| {
205 rule.allowed_headers
206 .iter()
207 .any(|pat| matches_glob_ci(pat, h))
208 })
209}
210
/// Matches `candidate` against a CORS `pattern`, case-sensitively.
///
/// Pattern semantics:
/// - Each `*` matches any run of characters, including an empty one.
/// - Every other character must match literally.
///
/// This covers both the bare `"*"` match-all pattern and S3-style partial
/// wildcards such as `https://*.example.com`. A pattern without `*` requires
/// an exact match.
#[must_use]
pub fn matches_glob(pattern: &str, candidate: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    wildcard_match(pattern, candidate)
}

/// ASCII case-insensitive variant of [`matches_glob`], used for header names.
///
/// Partial wildcards such as `x-amz-*` are supported here too.
#[must_use]
pub fn matches_glob_ci(pattern: &str, candidate: &str) -> bool {
    if pattern == "*" {
        return true;
    }
    // `to_ascii_lowercase` rewrites only ASCII bytes. For literal patterns
    // this is equivalent to `eq_ignore_ascii_case`, and non-ASCII UTF-8 is
    // left intact.
    wildcard_match(&pattern.to_ascii_lowercase(), &candidate.to_ascii_lowercase())
}

/// Matches `candidate` against `pattern`. Each `*` matches any (possibly
/// empty) substring; all other characters match literally.
fn wildcard_match(pattern: &str, candidate: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item: the literal prefix.
    let prefix = segments.next().unwrap_or("");
    let mut rest = match candidate.strip_prefix(prefix) {
        Some(rest) => rest,
        None => return false,
    };
    let suffix = match segments.next_back() {
        Some(suffix) => suffix,
        // No `*` in the pattern: the prefix must have consumed everything.
        None => return rest.is_empty(),
    };
    // Interior segments must appear in order. Taking the leftmost occurrence
    // of each leaves the most room for the segments that follow.
    for segment in segments {
        match rest.find(segment) {
            Some(idx) => rest = &rest[idx + segment.len()..],
            None => return false,
        }
    }
    // Test the suffix against what is left, not the whole candidate, so it
    // cannot overlap text already consumed (e.g. `ab*bc` must not match
    // `abc`).
    rest.ends_with(suffix)
}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule from the given origins, methods and headers. The rule has
    /// no exposed headers, `max_age_seconds = Some(3600)` and no id.
    fn rule(origins: &[&str], methods: &[&str], headers: &[&str]) -> CorsRule {
        CorsRule {
            allowed_origins: origins.iter().map(|s| (*s).to_owned()).collect(),
            allowed_methods: methods.iter().map(|s| (*s).to_owned()).collect(),
            allowed_headers: headers.iter().map(|s| (*s).to_owned()).collect(),
            expose_headers: Vec::new(),
            max_age_seconds: Some(3600),
            id: None,
        }
    }

    #[test]
    fn matches_glob_wildcard_matches_anything() {
        assert!(matches_glob("*", "https://example.com"));
        assert!(matches_glob("*", ""));
        assert!(matches_glob("*", "GET"));
    }

    #[test]
    fn matches_glob_exact_match() {
        assert!(matches_glob("https://example.com", "https://example.com"));
        assert!(matches_glob("GET", "GET"));
    }

    #[test]
    fn matches_glob_no_match() {
        assert!(!matches_glob("https://example.com", "https://evil.com"));
        assert!(!matches_glob("GET", "PUT"));
    }

    #[test]
    fn matches_glob_origin_is_case_sensitive() {
        // Origins are compared case-sensitively; a differently-cased host does not match.
        assert!(!matches_glob("https://Example.com", "https://example.com"));
    }

    #[test]
    fn matches_glob_ci_header_is_case_insensitive() {
        assert!(matches_glob_ci("Content-Type", "content-type"));
        assert!(matches_glob_ci("X-Amz-Date", "x-amz-date"));
        assert!(!matches_glob_ci("X-Other", "X-Different"));
    }

    #[test]
    fn match_preflight_happy_path() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(
                    &["https://app.example.com"],
                    &["GET", "PUT"],
                    &["Content-Type"],
                )],
            },
        );
        let matched = mgr
            .match_preflight(
                "b",
                "https://app.example.com",
                "PUT",
                &["Content-Type".to_owned()],
            )
            .expect("PUT with an allowed header from an allowed origin must match");
        assert_eq!(matched.max_age_seconds, Some(3600));
    }

    #[test]
    fn match_preflight_no_rule_for_bucket() {
        let mgr = CorsManager::new();
        let m = mgr.match_preflight("ghost", "https://anything", "GET", &[]);
        assert!(m.is_none());
    }

    #[test]
    fn match_preflight_method_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &["*"])],
            },
        );
        // DELETE is not listed, so no rule matches; GET is listed and does.
        assert!(mgr.match_preflight("b", "https://x", "DELETE", &[]).is_none());
        assert!(mgr.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_origin_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://good.example.com"], &["GET"], &["*"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://evil.example.com", "GET", &[])
            .is_none());
    }

    #[test]
    fn match_preflight_wildcard_origin() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let m = mgr.match_preflight("b", "https://anywhere", "GET", &[]);
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_wildcard_header() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["*"])],
            },
        );
        let m = mgr.match_preflight(
            "b",
            "https://x",
            "PUT",
            &["X-Custom-Header".to_owned(), "Content-Type".to_owned()],
        );
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_header_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        // Every requested header must be allowed; a single unlisted one rejects the rule.
        let m = mgr.match_preflight(
            "b",
            "https://x",
            "PUT",
            &["Content-Type".to_owned(), "X-Custom-Header".to_owned()],
        );
        assert!(m.is_none());
    }

    #[test]
    fn match_preflight_first_matching_rule_wins() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(60),
                        id: Some("first".into()),
                    },
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(7200),
                        id: Some("second".into()),
                    },
                ],
            },
        );
        let m = mgr
            .match_preflight("b", "https://x", "GET", &[])
            .expect("should match");
        // Both rules accept the request; declaration order decides.
        assert_eq!(m.id.as_deref(), Some("first"));
        assert_eq!(m.max_age_seconds, Some(60));
    }

    #[test]
    fn match_preflight_header_case_insensitive() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        let m = mgr.match_preflight("b", "https://x", "PUT", &["content-type".to_owned()]);
        assert!(m.is_some());
    }

    #[test]
    fn put_replaces_previous_config() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://a"], &["GET"], &["*"])],
            },
        );
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://b"], &["PUT"], &["*"])],
            },
        );
        let cfg = mgr.get("b").expect("config present");
        assert_eq!(cfg.rules.len(), 1);
        assert_eq!(cfg.rules[0].allowed_origins, vec!["https://b".to_string()]);
    }

    #[test]
    fn delete_is_idempotent() {
        let mgr = CorsManager::new();
        // Deleting a bucket that was never configured must be a harmless no-op.
        mgr.delete("never-existed");
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
        // A repeated delete of the same bucket is also a no-op.
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
    }

    #[test]
    fn json_round_trip() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![CorsRule {
                    allowed_origins: vec!["https://example.com".into()],
                    allowed_methods: vec!["GET".into(), "PUT".into()],
                    allowed_headers: vec!["Content-Type".into()],
                    expose_headers: vec!["ETag".into()],
                    max_age_seconds: Some(3600),
                    id: Some("rule-1".into()),
                }],
            },
        );
        let json = mgr.to_json().expect("to_json");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert_eq!(mgr.get("b"), mgr2.get("b"));
    }

    #[test]
    fn from_json_rejects_malformed_input() {
        assert!(CorsManager::from_json("not json").is_err());
    }

    /// A panic while holding the write lock poisons it. Later snapshots must
    /// still succeed via lock recovery and keep the previously stored config.
    #[test]
    fn cors_to_json_after_panic_recovers_via_poison() {
        let mgr = std::sync::Arc::new(CorsManager::new());
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let mgr_cl = std::sync::Arc::clone(&mgr);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = mgr_cl.by_bucket.write().expect("clean lock");
            g.entry("b2".into()).or_default();
            panic!("force-poison");
        }));
        assert!(
            mgr.by_bucket.is_poisoned(),
            "write panic must poison by_bucket lock"
        );
        let json = mgr.to_json().expect("to_json after poison must succeed");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert!(
            mgr2.get("b").is_some(),
            "recovered snapshot keeps original config"
        );
    }
}