1use std::collections::HashMap;
46use std::sync::RwLock;
47
48use serde::{Deserialize, Serialize};
49
50#[derive(Debug, thiserror::Error)]
52pub enum CorsValidationError {
53 #[error(
54 "AllowedMethod {0:?} is not a valid AWS S3 CORS verb (must be one of GET / PUT / POST / DELETE / HEAD; `*` is rejected)"
55 )]
56 UnsupportedMethod(String),
57}
58
59#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
66pub struct CorsRule {
67 pub allowed_origins: Vec<String>,
71 pub allowed_methods: Vec<String>,
75 pub allowed_headers: Vec<String>,
78 #[serde(default)]
81 pub expose_headers: Vec<String>,
82 #[serde(default)]
85 pub max_age_seconds: Option<u32>,
86 #[serde(default)]
88 pub id: Option<String>,
89}
90
91#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
93pub struct CorsConfig {
94 pub rules: Vec<CorsRule>,
96}
97
98#[derive(Debug, Default, Serialize, Deserialize)]
100struct CorsSnapshot {
101 by_bucket: HashMap<String, CorsConfig>,
102}
103
104#[derive(Debug, Default)]
110pub struct CorsManager {
111 by_bucket: RwLock<HashMap<String, CorsConfig>>,
112}
113
114impl CorsManager {
115 #[must_use]
117 pub fn new() -> Self {
118 Self::default()
119 }
120
121 pub fn put(&self, bucket: &str, config: CorsConfig) {
124 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket")
125 .insert(bucket.to_owned(), config);
126 }
127
128 pub fn validate(config: &CorsConfig) -> Result<(), CorsValidationError> {
140 const VALID_METHODS: &[&str] = &["GET", "PUT", "POST", "DELETE", "HEAD"];
141 for rule in &config.rules {
142 for m in &rule.allowed_methods {
143 if !VALID_METHODS.contains(&m.as_str()) {
144 return Err(CorsValidationError::UnsupportedMethod(m.clone()));
145 }
146 }
147 }
148 Ok(())
149 }
150
151 #[must_use]
154 pub fn get(&self, bucket: &str) -> Option<CorsConfig> {
155 crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
156 .get(bucket)
157 .cloned()
158 }
159
160 pub fn delete(&self, bucket: &str) {
162 crate::lock_recovery::recover_write(&self.by_bucket, "cors.by_bucket").remove(bucket);
163 }
164
165 pub fn to_json(&self) -> Result<String, serde_json::Error> {
168 let snap = CorsSnapshot {
169 by_bucket: crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket")
170 .clone(),
171 };
172 serde_json::to_string(&snap)
173 }
174
175 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
178 let snap: CorsSnapshot = serde_json::from_str(s)?;
179 Ok(Self {
180 by_bucket: RwLock::new(snap.by_bucket),
181 })
182 }
183
184 #[must_use]
193 pub fn match_preflight(
194 &self,
195 bucket: &str,
196 origin: &str,
197 method: &str,
198 request_headers: &[String],
199 ) -> Option<CorsRule> {
200 let map = crate::lock_recovery::recover_read(&self.by_bucket, "cors.by_bucket");
201 let cfg = map.get(bucket)?;
202 for rule in &cfg.rules {
203 if !rule_matches_origin(rule, origin) {
204 continue;
205 }
206 if !rule_matches_method(rule, method) {
207 continue;
208 }
209 if !rule_matches_headers(rule, request_headers) {
210 continue;
211 }
212 return Some(rule.clone());
213 }
214 None
215 }
216}
217
218fn rule_matches_origin(rule: &CorsRule, origin: &str) -> bool {
219 rule.allowed_origins
220 .iter()
221 .any(|pat| matches_glob(pat, origin))
222}
223
224fn rule_matches_method(rule: &CorsRule, method: &str) -> bool {
225 rule.allowed_methods
228 .iter()
229 .any(|pat| pat == "*" || pat == method)
230}
231
232fn rule_matches_headers(rule: &CorsRule, request_headers: &[String]) -> bool {
233 if request_headers.is_empty() {
234 return true;
235 }
236 request_headers.iter().all(|h| {
237 rule.allowed_headers
238 .iter()
239 .any(|pat| matches_glob_ci(pat, h))
240 })
241}
242
/// Matches `candidate` against a CORS pattern, case-sensitively.
///
/// Each `*` in `pattern` matches any run of characters, including an empty
/// run. Every other character must match exactly. So:
/// - a lone `*` matches everything;
/// - a pattern without `*` is a plain equality check;
/// - `https://*.example.com` matches `https://app.example.com` but not
///   `https://example.com`. This is the single-wildcard `AllowedOrigin` form
///   AWS S3 documents.
///
/// In this module it is used for origin matching.
#[must_use]
pub fn matches_glob(pattern: &str, candidate: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == candidate;
    }
    wildcard_match(pattern.as_bytes(), candidate.as_bytes(), |p, c| p == c)
}

/// Like [`matches_glob`], but literal characters are compared ASCII
/// case-insensitively.
///
/// In this module it is used for request-header matching. HTTP header names
/// are case-insensitive, so `Content-Type` matches `content-type` and `x-amz-*`
/// matches `X-Amz-Date`.
#[must_use]
pub fn matches_glob_ci(pattern: &str, candidate: &str) -> bool {
    if !pattern.contains('*') {
        return pattern.eq_ignore_ascii_case(candidate);
    }
    wildcard_match(pattern.as_bytes(), candidate.as_bytes(), |p, c| {
        p.eq_ignore_ascii_case(&c)
    })
}

/// Byte-wise `*`-wildcard matcher shared by [`matches_glob`] and [`matches_glob_ci`].
///
/// Greedy two-pointer scan. On a mismatch it backtracks to the most recent `*`
/// and lets that `*` absorb one more candidate byte. Worst case is
/// O(pattern.len() * candidate.len()).
///
/// Matching on bytes is sound for UTF-8 input because `*` (0x2A) never occurs
/// inside a multi-byte sequence. Non-`*` bytes are compared with `eq`.
fn wildcard_match(pattern: &[u8], candidate: &[u8], eq: impl Fn(u8, u8) -> bool) -> bool {
    let (mut p, mut c) = (0usize, 0usize);
    // (index of the last `*` seen in `pattern`, candidate index it has absorbed up to)
    let mut backtrack: Option<(usize, usize)> = None;
    while c < candidate.len() {
        match pattern.get(p) {
            Some(b'*') => {
                // Tentatively let the `*` match nothing; remember where to resume.
                backtrack = Some((p, c));
                p += 1;
            }
            Some(&pc) if eq(pc, candidate[c]) => {
                p += 1;
                c += 1;
            }
            _ => match backtrack {
                Some((star, absorbed)) => {
                    // Widen the last `*` by one candidate byte and retry after it.
                    p = star + 1;
                    c = absorbed + 1;
                    backtrack = Some((star, absorbed + 1));
                }
                None => return false,
            },
        }
    }
    // Candidate consumed: any pattern remainder must be `*`s (which match empty).
    pattern[p..].iter().all(|&b| b == b'*')
}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed literals into the owned `Vec<String>` the rule fields use.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    /// Builds a rule from origin/method/header lists, with a fixed one-hour
    /// `max_age_seconds`, no exposed headers and no id.
    fn rule(origins: &[&str], methods: &[&str], headers: &[&str]) -> CorsRule {
        CorsRule {
            allowed_origins: strings(origins),
            allowed_methods: strings(methods),
            allowed_headers: strings(headers),
            expose_headers: Vec::new(),
            max_age_seconds: Some(3600),
            id: None,
        }
    }

    #[test]
    fn matches_glob_wildcard_matches_anything() {
        assert!(matches_glob("*", "https://example.com"));
        assert!(matches_glob("*", ""));
        assert!(matches_glob("*", "GET"));
    }

    #[test]
    fn matches_glob_exact_match() {
        assert!(matches_glob("https://example.com", "https://example.com"));
        assert!(matches_glob("GET", "GET"));
    }

    #[test]
    fn matches_glob_no_match() {
        assert!(!matches_glob("https://example.com", "https://evil.com"));
        assert!(!matches_glob("GET", "PUT"));
    }

    #[test]
    fn matches_glob_origin_is_case_sensitive() {
        assert!(!matches_glob("https://Example.com", "https://example.com"));
    }

    #[test]
    fn matches_glob_ci_header_is_case_insensitive() {
        assert!(matches_glob_ci("Content-Type", "content-type"));
        assert!(matches_glob_ci("X-Amz-Date", "x-amz-date"));
        assert!(!matches_glob_ci("X-Other", "X-Different"));
    }

    #[test]
    fn match_preflight_happy_path() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(
                    &["https://app.example.com"],
                    &["GET", "PUT"],
                    &["Content-Type"],
                )],
            },
        );
        let matched = mgr
            .match_preflight(
                "b",
                "https://app.example.com",
                "PUT",
                &["Content-Type".to_owned()],
            )
            .expect("preflight should match the single rule");
        assert_eq!(matched.max_age_seconds, Some(3600));
    }

    #[test]
    fn match_preflight_no_rule_for_bucket() {
        let mgr = CorsManager::new();
        let m = mgr.match_preflight("ghost", "https://anything", "GET", &[]);
        assert!(m.is_none());
    }

    #[test]
    fn match_preflight_method_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &["*"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://x", "DELETE", &[])
            .is_none());
        assert!(mgr.match_preflight("b", "https://x", "GET", &[]).is_some());
    }

    #[test]
    fn match_preflight_origin_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://good.example.com"], &["GET"], &["*"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://evil.example.com", "GET", &[])
            .is_none());
    }

    #[test]
    fn match_preflight_wildcard_origin() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let m = mgr.match_preflight("b", "https://anywhere", "GET", &[]);
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_wildcard_header() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["*"])],
            },
        );
        let m = mgr.match_preflight(
            "b",
            "https://x",
            "PUT",
            &["X-Custom-Header".to_owned(), "Content-Type".to_owned()],
        );
        assert!(m.is_some());
    }

    #[test]
    fn match_preflight_header_not_allowed() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        assert!(mgr
            .match_preflight("b", "https://x", "PUT", &["X-Other".to_owned()])
            .is_none());

        // A rule that allows no headers rejects any requested header...
        mgr.put(
            "c",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &[])],
            },
        );
        assert!(mgr
            .match_preflight("c", "https://x", "PUT", &["Content-Type".to_owned()])
            .is_none());
        // ...but still accepts a preflight that requests none.
        assert!(mgr.match_preflight("c", "https://x", "PUT", &[]).is_some());
    }

    #[test]
    fn match_preflight_first_matching_rule_wins() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(60),
                        id: Some("first".into()),
                    },
                    CorsRule {
                        allowed_origins: vec!["*".into()],
                        allowed_methods: vec!["GET".into()],
                        allowed_headers: vec!["*".into()],
                        expose_headers: Vec::new(),
                        max_age_seconds: Some(7200),
                        id: Some("second".into()),
                    },
                ],
            },
        );
        let m = mgr
            .match_preflight("b", "https://x", "GET", &[])
            .expect("should match");
        assert_eq!(m.id.as_deref(), Some("first"));
        assert_eq!(m.max_age_seconds, Some(60));
    }

    #[test]
    fn match_preflight_header_case_insensitive() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["PUT"], &["Content-Type"])],
            },
        );
        let m = mgr.match_preflight("b", "https://x", "PUT", &["content-type".to_owned()]);
        assert!(m.is_some());
    }

    #[test]
    fn validate_accepts_all_s3_verbs() {
        let cfg = CorsConfig {
            rules: vec![rule(&["*"], &["GET", "PUT", "POST", "DELETE", "HEAD"], &[])],
        };
        assert!(CorsManager::validate(&cfg).is_ok());
    }

    #[test]
    fn validate_rejects_wildcard_lowercase_and_unknown_methods() {
        for bad in ["*", "get", "PATCH"] {
            let cfg = CorsConfig {
                rules: vec![rule(&["*"], &["GET", bad], &[])],
            };
            let err = CorsManager::validate(&cfg).expect_err("method must be rejected");
            assert!(
                matches!(&err, CorsValidationError::UnsupportedMethod(m) if m == bad),
                "unexpected error for {:?}: {:?}",
                bad,
                err
            );
        }
    }

    #[test]
    fn put_replaces_previous_config() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://a"], &["GET"], &["*"])],
            },
        );
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["https://b"], &["PUT"], &["*"])],
            },
        );
        let cfg = mgr.get("b").expect("config present");
        assert_eq!(cfg.rules.len(), 1);
        assert_eq!(cfg.rules[0].allowed_origins, vec!["https://b".to_string()]);
    }

    #[test]
    fn get_unknown_bucket_is_none() {
        assert!(CorsManager::new().get("ghost").is_none());
    }

    #[test]
    fn delete_is_idempotent() {
        let mgr = CorsManager::new();
        // Deleting a bucket that never had a config is a no-op, not an error.
        mgr.delete("never-existed");
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
        // Deleting the same bucket again is likewise a no-op.
        mgr.delete("b");
        assert!(mgr.get("b").is_none());
    }

    #[test]
    fn json_round_trip() {
        let mgr = CorsManager::new();
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![CorsRule {
                    allowed_origins: vec!["https://example.com".into()],
                    allowed_methods: vec!["GET".into(), "PUT".into()],
                    allowed_headers: vec!["Content-Type".into()],
                    expose_headers: vec!["ETag".into()],
                    max_age_seconds: Some(3600),
                    id: Some("rule-1".into()),
                }],
            },
        );
        let json = mgr.to_json().expect("to_json");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert_eq!(mgr.get("b"), mgr2.get("b"));
    }

    /// A panic while holding the write lock poisons it; `to_json` must still
    /// succeed afterwards and keep the configuration written before the panic.
    #[test]
    fn cors_to_json_after_panic_recovers_via_poison() {
        let mgr = std::sync::Arc::new(CorsManager::new());
        mgr.put(
            "b",
            CorsConfig {
                rules: vec![rule(&["*"], &["GET"], &[])],
            },
        );
        let mgr_cl = std::sync::Arc::clone(&mgr);
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut g = mgr_cl.by_bucket.write().expect("clean lock");
            g.entry("b2".into()).or_default();
            panic!("force-poison");
        }));
        assert!(
            mgr.by_bucket.is_poisoned(),
            "write panic must poison by_bucket lock"
        );
        let json = mgr.to_json().expect("to_json after poison must succeed");
        let mgr2 = CorsManager::from_json(&json).expect("from_json");
        assert!(
            mgr2.get("b").is_some(),
            "recovered snapshot keeps original config"
        );
    }
}