1use std::collections::HashMap;
46use std::sync::RwLock;
47
48use serde::{Deserialize, Serialize};
49
50#[derive(Debug, thiserror::Error)]
52pub enum CorsValidationError {
53 #[error(
54 "AllowedMethod {0:?} is not a valid AWS S3 CORS verb (must be one of GET / PUT / POST / DELETE / HEAD; `*` is rejected)"
55 )]
56 UnsupportedMethod(String),
57}
58
59#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
66pub struct CorsRule {
67 pub allowed_origins: Vec<String>,
71 pub allowed_methods: Vec<String>,
75 pub allowed_headers: Vec<String>,
78 #[serde(default)]
81 pub expose_headers: Vec<String>,
82 #[serde(default)]
85 pub max_age_seconds: Option<u32>,
86 #[serde(default)]
88 pub id: Option<String>,
89}
90
91#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
93pub struct CorsConfig {
94 pub rules: Vec<CorsRule>,
96}
97
98#[derive(Debug, Default, Serialize, Deserialize)]
100struct CorsSnapshot {
101 by_bucket: HashMap<String, CorsConfig>,
102}
103
104#[derive(Debug, Default)]
110pub struct CorsManager {
111 by_bucket: RwLock<HashMap<String, CorsConfig>>,
112}
113
114impl CorsManager {
115 #[must_use]
117 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn put(&self, bucket: &str, config: CorsConfig) {
124 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket")
125 .insert(bucket.to_owned(), config);
126 }
127
128 pub fn validate(config: &CorsConfig) -> Result<(), CorsValidationError> {
140 const VALID_METHODS: &[&str] = &["GET", "PUT", "POST", "DELETE", "HEAD"];
141 for rule in &config.rules {
142 for m in &rule.allowed_methods {
143 if !VALID_METHODS.contains(&m.as_str()) {
144 return Err(CorsValidationError::UnsupportedMethod(m.clone()));
145 }
146 }
147 }
148 Ok(())
149 }
150
151 #[must_use]
154 pub fn get(&self, bucket: &str) -> Option<CorsConfig> {
155 crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
156 .get(bucket)
157 .cloned()
158 }
159
160 pub fn delete(&self, bucket: &str) {
162 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket").remove(bucket);
163 }
164
165 pub fn to_json(&self) -> Result<String, serde_json::Error> {
168 let snap = CorsSnapshot {
169 by_bucket: crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
170 .clone(),
171 };
172 serde_json::to_string(&snap)
173 }
174
175 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
188 let snap: CorsSnapshot = serde_json::from_str(s)?;
189 for cfg in snap.by_bucket.values() {
190 if let Err(e) = Self::validate(cfg) {
191 return Err(serde::de::Error::custom(format!(
192 "CORS snapshot fails AWS S3 validation: {e}"
193 )));
194 }
195 }
196 Ok(Self {
197 by_bucket: RwLock::new(snap.by_bucket),
198 })
199 }
200
201 #[must_use]
210 pub fn match_preflight(
211 &self,
212 bucket: &str,
213 origin: &str,
214 method: &str,
215 request_headers: &[String],
216 ) -> Option<CorsRule> {
217 let map = crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket");
218 let cfg = map.get(bucket)?;
219 for rule in &cfg.rules {
220 if !rule_matches_origin(rule, origin) {
221 continue;
222 }
223 if !rule_matches_method(rule, method) {
224 continue;
225 }
226 if !rule_matches_headers(rule, request_headers) {
227 continue;
228 }
229 return Some(rule.clone());
230 }
231 None
232 }
233}
234
235fn rule_matches_origin(rule: &CorsRule, origin: &str) -> bool {
236 rule.allowed_origins
237 .iter()
238 .any(|pat| matches_glob(pat, origin))
239}
240
241fn rule_matches_method(rule: &CorsRule, method: &str) -> bool {
242 rule.allowed_methods.iter().any(|pat| pat == method)
249}
250
251fn rule_matches_headers(rule: &CorsRule, request_headers: &[String]) -> bool {
252 if request_headers.is_empty() {
253 return true;
254 }
255 request_headers.iter().all(|h| {
256 rule.allowed_headers
257 .iter()
258 .any(|pat| matches_glob_ci(pat, h))
259 })
260}
261
/// Matches `candidate` against an S3-style CORS wildcard `pattern`,
/// case-sensitively.
///
/// Each `*` in `pattern` matches any run of characters, including an empty
/// run. Every other character must match exactly. So:
/// - a bare `"*"` matches anything, including the empty string;
/// - a pattern without `*` is a plain equality test;
/// - `https://*.example.com` matches `https://app.example.com` but not
///   `https://example.com`.
///
/// This is the wildcard form S3 documents for `AllowedOrigin`. Origins are
/// compared case-sensitively.
#[must_use]
pub fn matches_glob(pattern: &str, candidate: &str) -> bool {
    wildcard_match(pattern.as_bytes(), candidate.as_bytes(), |p, c| p == c)
}

/// ASCII-case-insensitive counterpart of [`matches_glob`], used for header
/// names. For example, `x-amz-*` matches `X-Amz-Date`, and `Content-Type`
/// matches `content-type`.
#[must_use]
pub fn matches_glob_ci(pattern: &str, candidate: &str) -> bool {
    wildcard_match(pattern.as_bytes(), candidate.as_bytes(), |p, c| {
        p.eq_ignore_ascii_case(&c)
    })
}

/// Shared `*`-wildcard matcher over raw bytes. `same` decides whether a
/// literal pattern byte matches a candidate byte.
///
/// Working on bytes is sound for UTF-8 input. `*` is ASCII, and every literal
/// run in the pattern is valid UTF-8 that starts on a lead byte, so a `*`
/// can only ever absorb whole characters of the candidate.
///
/// This is the classic greedy matcher that backtracks only to the most recent
/// `*`. That suffices for `*`-only patterns, and runs in O(pattern × candidate)
/// time in the worst case.
fn wildcard_match(pattern: &[u8], candidate: &[u8], same: impl Fn(u8, u8) -> bool) -> bool {
    let (mut p, mut c) = (0usize, 0usize);
    // (index of the last `*` seen, candidate index that `*` currently stops at)
    let mut backtrack: Option<(usize, usize)> = None;
    while c < candidate.len() {
        match pattern.get(p).copied() {
            Some(b'*') => {
                // Let the star match nothing first; widen it on later mismatch.
                backtrack = Some((p, c));
                p += 1;
            }
            Some(lit) if same(lit, candidate[c]) => {
                p += 1;
                c += 1;
            }
            _ => match backtrack {
                Some((star, stop)) => {
                    // Let the last star swallow one more candidate byte and retry.
                    p = star + 1;
                    c = stop + 1;
                    backtrack = Some((star, stop + 1));
                }
                None => return false,
            },
        }
    }
    // Candidate exhausted: only trailing stars (matching empty) may remain.
    pattern[p..].iter().all(|&b| b == b'*')
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a rule from the given origins, methods and headers, with a
    /// `max_age_seconds` of 3600 and no `id`.
    fn rule(origins: &[&str], methods: &[&str], headers: &[&str]) -> CorsRule {
        CorsRule {
            allowed_origins: origins.iter().map(|s| (*s).to_owned()).collect(),
            allowed_methods: methods.iter().map(|s| (*s).to_owned()).collect(),
            allowed_headers: headers.iter().map(|s| (*s).to_owned()).collect(),
            expose_headers: Vec::new(),
            max_age_seconds: Some(3600),
            id: None,
        }
    }

    /// Returns a fresh manager with `rules` stored under bucket `"b"`.
    fn manager_with(rules: Vec<CorsRule>) -> CorsManager {
        let mgr = CorsManager::new();
        mgr.put("b", CorsConfig { rules });
        mgr
    }

    #[test]
    fn matches_glob_wildcard_matches_anything() {
        assert!(matches_glob("*", "https://example.com"));
        assert!(matches_glob("*", ""));
        assert!(matches_glob("*", "GET"));
    }

    #[test]
    fn matches_glob_exact_match() {
        assert!(matches_glob("https://example.com", "https://example.com"));
        assert!(matches_glob("GET", "GET"));
    }

    #[test]
    fn matches_glob_no_match() {
        assert!(!matches_glob("https://example.com", "https://evil.com"));
        assert!(!matches_glob("GET", "PUT"));
    }

    #[test]
    fn matches_glob_origin_is_case_sensitive() {
        assert!(!matches_glob("https://Example.com", "https://example.com"));
    }

    #[test]
    fn matches_glob_ci_header_is_case_insensitive() {
        assert!(matches_glob_ci("Content-Type", "content-type"));
        assert!(matches_glob_ci("X-Amz-Date", "x-amz-date"));
        assert!(!matches_glob_ci("X-Other", "X-Different"));
    }

    #[test]
    fn match_preflight_happy_path() {
        let mgr = manager_with(vec![rule(
            &["https://app.example.com"],
            &["GET", "PUT"],
            &["Content-Type"],
        )]);
        let matched = mgr
            .match_preflight(
                "b",
                "https://app.example.com",
                "PUT",
                &["Content-Type".to_owned()],
            )
            .expect("PUT with Content-Type from the allowed origin must match");
        assert_eq!(matched.max_age_seconds, Some(3600));
    }

    #[test]
    fn match_preflight_no_rule_for_bucket() {
        let mgr = CorsManager::new();
        assert!(mgr
            .match_preflight("ghost", "https://anything", "GET", &[])
            .is_none());
    }

    #[test]
    fn match_preflight_method_not_allowed() {
        let mgr = manager_with(vec![rule(&["*"], &["GET"], &["*"])]);
        assert!(mgr.match_preflight("b", "https://x", "DELETE", &[]).is_none());
        assert!(mgr.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_origin_not_allowed() {
        let mgr = manager_with(vec![rule(&["https://good.example.com"], &["GET"], &["*"])]);
        assert!(mgr
            .match_preflight("b", "https://evil.example.com", "GET", &[])
            .is_none());
    }

    #[test]
    fn match_preflight_wildcard_origin() {
        let mgr = manager_with(vec![rule(&["*"], &["GET"], &[])]);
        assert!(mgr
            .match_preflight("b", "https://anywhere", "GET", &[])
            .is_some());
    }

    #[test]
    fn match_preflight_wildcard_header() {
        let mgr = manager_with(vec![rule(&["*"], &["PUT"], &["*"])]);
        let requested = ["X-Custom-Header".to_owned(), "Content-Type".to_owned()];
        assert!(mgr
            .match_preflight("b", "https://x", "PUT", &requested)
            .is_some());
    }

    #[test]
    fn match_preflight_header_not_allowed() {
        let mgr = manager_with(vec![rule(&["*"], &["PUT"], &["Content-Type"])]);
        assert!(mgr
            .match_preflight("b", "https://x", "PUT", &["X-Custom".to_owned()])
            .is_none());

        // An empty allowed-header list admits only header-less preflights.
        let bare = manager_with(vec![rule(&["*"], &["GET"], &[])]);
        assert!(bare
            .match_preflight("b", "https://x", "GET", &["Content-Type".to_owned()])
            .is_none());
        assert!(bare.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_first_matching_rule_wins() {
        let mgr = manager_with(vec![
            CorsRule {
                max_age_seconds: Some(60),
                id: Some("first".into()),
                ..rule(&["*"], &["GET"], &["*"])
            },
            CorsRule {
                max_age_seconds: Some(7200),
                id: Some("second".into()),
                ..rule(&["*"], &["GET"], &["*"])
            },
        ]);
        let matched = mgr
            .match_preflight("b", "https://x", "GET", &[])
            .expect("should match");
        assert_eq!(matched.id.as_deref(), Some("first"));
        assert_eq!(matched.max_age_seconds, Some(60));
    }

    #[test]
    fn match_preflight_header_case_insensitive() {
        let mgr = manager_with(vec![rule(&["*"], &["PUT"], &["Content-Type"])]);
        assert!(mgr
            .match_preflight("b", "https://x", "PUT", &["content-type".to_owned()])
            .is_some());
    }

    #[test]
    fn put_replaces_previous_config() {
        let mgr = manager_with(vec![rule(&["https://a"], &["GET"], &["*"])]);
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://b"], &["PUT"], &["*"])],
            },
        );
        let cfg = mgr.get("b").expect("config present");
        assert_eq!(cfg.rules.len(), 1);
        assert_eq!(cfg.rules[0].allowed_origins, vec!["https://b".to_string()]);
    }

    #[test]
    fn delete_is_idempotent() {
        let mgr = CorsManager::new();
        mgr.delete("never-existed");
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        mgr.delete("b");
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
    }

    #[test]
    fn validate_accepts_every_s3_verb_and_empty_config() {
        let cfg = CorsConfig {
            rules: vec![rule(&["*"], &["GET", "PUT", "POST", "DELETE", "HEAD"], &["*"])],
        };
        assert!(CorsManager::validate(&cfg).is_ok());
        assert!(CorsManager::validate(&CorsConfig::default()).is_ok());
    }

    #[test]
    fn validate_rejects_wildcard_and_lowercase_methods() {
        let star = CorsConfig {
            rules: vec![rule(&["*"], &["GET", "*"], &[])],
        };
        let err = CorsManager::validate(&star).expect_err("`*` must be rejected");
        assert!(matches!(err, CorsValidationError::UnsupportedMethod(ref m) if m == "*"));

        let lower = CorsConfig {
            rules: vec![rule(&["*"], &["get"], &[])],
        };
        assert!(CorsManager::validate(&lower).is_err());
    }

    #[test]
    fn json_round_trip() {
        let mgr = manager_with(vec![CorsRule {
            allowed_origins: vec!["https://example.com".into()],
            allowed_methods: vec!["GET".into(), "PUT".into()],
            allowed_headers: vec!["Content-Type".into()],
            expose_headers: vec!["ETag".into()],
            max_age_seconds: Some(3600),
            id: Some("rule-1".into()),
        }]);
        let json = mgr.to_json().expect("to_json");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert_eq!(mgr.get("b"), mgr2.get("b"));
    }

    #[test]
    fn from_json_rejects_snapshot_failing_validation() {
        // `put` stores configs unvalidated, so an invalid one can be snapshotted.
        let mgr = manager_with(vec![rule(&["*"], &["PATCH"], &[])]);
        let json = mgr.to_json().expect("to_json");
        let err = match CorsManager::from_json(&json) {
            Ok(_) => panic!("snapshot with PATCH must be rejected"),
            Err(e) => e,
        };
        assert!(err
            .to_string()
            .contains("CORS snapshot fails AWS S3 validation"));
    }

    #[test]
    fn from_json_rejects_malformed_input() {
        assert!(CorsManager::from_json("not json").is_err());
    }

    #[test]
    fn cors_to_json_after_panic_recovers_via_poison() {
        let mgr = std::sync::Arc::new(CorsManager::new());
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let mgr_cl = std::sync::Arc::clone(&mgr);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = mgr_cl.by_bucket.write().expect("clean lock");
            g.entry("b2".into()).or_default();
            panic!("force-poison");
        }));
        assert!(
            mgr.by_bucket.is_poisoned(),
            "write panic must poison by_bucket lock"
        );
        let json = mgr.to_json().expect("to_json after poison must succeed");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert!(
            mgr2.get("b").is_some(),
            "recovered snapshot keeps original config"
        );
    }
}