1use std::collections::HashMap;
45use std::sync::RwLock;
46
47use serde::{Deserialize, Serialize};
48use totp_rs::{Algorithm, TOTP};
49
50#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
56pub struct MfaSecret {
57 pub secret_base32: String,
60 pub serial: String,
65}
66
67#[derive(Debug, Default)]
73pub struct MfaDeleteManager {
74 default_secret: RwLock<Option<MfaSecret>>,
77 by_bucket: RwLock<HashMap<String, MfaSecret>>,
79 enabled: RwLock<HashMap<String, bool>>,
83}
84
85#[derive(Debug, Default, Serialize, Deserialize)]
88struct MfaSnapshot {
89 default_secret: Option<MfaSecret>,
90 by_bucket: HashMap<String, MfaSecret>,
91 enabled: HashMap<String, bool>,
92}
93
94impl MfaDeleteManager {
95 #[must_use]
98 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn set_default_secret(&self, secret: MfaSecret) {
106 *self
107 .default_secret
108 .write()
109 .expect("MFA default-secret RwLock poisoned") = Some(secret);
110 }
111
112 pub fn set_bucket_secret(&self, bucket: &str, secret: MfaSecret) {
114 self.by_bucket
115 .write()
116 .expect("MFA per-bucket RwLock poisoned")
117 .insert(bucket.to_owned(), secret);
118 }
119
120 pub fn set_bucket_state(&self, bucket: &str, enabled: bool) {
125 self.enabled
126 .write()
127 .expect("MFA enabled-state RwLock poisoned")
128 .insert(bucket.to_owned(), enabled);
129 }
130
131 #[must_use]
134 pub fn is_enabled(&self, bucket: &str) -> bool {
135 self.enabled
136 .read()
137 .expect("MFA enabled-state RwLock poisoned")
138 .get(bucket)
139 .copied()
140 .unwrap_or(false)
141 }
142
143 #[must_use]
147 pub fn lookup_secret(&self, bucket: &str) -> Option<MfaSecret> {
148 if let Some(s) = self
149 .by_bucket
150 .read()
151 .expect("MFA per-bucket RwLock poisoned")
152 .get(bucket)
153 .cloned()
154 {
155 return Some(s);
156 }
157 self.default_secret
158 .read()
159 .expect("MFA default-secret RwLock poisoned")
160 .clone()
161 }
162
163 pub fn to_json(&self) -> Result<String, serde_json::Error> {
166 let snap = MfaSnapshot {
167 default_secret: self
168 .default_secret
169 .read()
170 .expect("MFA default-secret RwLock poisoned")
171 .clone(),
172 by_bucket: self
173 .by_bucket
174 .read()
175 .expect("MFA per-bucket RwLock poisoned")
176 .clone(),
177 enabled: self
178 .enabled
179 .read()
180 .expect("MFA enabled-state RwLock poisoned")
181 .clone(),
182 };
183 serde_json::to_string(&snap)
184 }
185
186 pub fn from_json(s: &str) -> Result<Self, serde_json::Error> {
188 let snap: MfaSnapshot = serde_json::from_str(s)?;
189 Ok(Self {
190 default_secret: RwLock::new(snap.default_secret),
191 by_bucket: RwLock::new(snap.by_bucket),
192 enabled: RwLock::new(snap.enabled),
193 })
194 }
195}
196
197#[derive(Debug, thiserror::Error)]
199pub enum MfaError {
200 #[error("missing x-amz-mfa header (MFA Delete is Enabled on this bucket)")]
201 Missing,
202 #[error("malformed x-amz-mfa header")]
203 Malformed,
204 #[error("MFA serial does not match configured device")]
205 SerialMismatch,
206 #[error("invalid MFA code")]
207 InvalidCode,
208}
209
210pub fn parse_mfa_header(value: &str) -> Result<(String, String), MfaError> {
217 let mut parts = value.splitn(2, ' ');
218 let serial = parts.next().ok_or(MfaError::Malformed)?;
219 let code = parts.next().ok_or(MfaError::Malformed)?;
220 if serial.is_empty() || code.is_empty() {
221 return Err(MfaError::Malformed);
222 }
223 if value.split(' ').count() != 2 {
225 return Err(MfaError::Malformed);
226 }
227 if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
228 return Err(MfaError::Malformed);
229 }
230 Ok((serial.to_owned(), code.to_owned()))
231}
232
233#[must_use]
240pub fn verify_totp(secret_base32: &str, code: &str, now_unix_secs: u64) -> bool {
241 let Some(raw) = base32::decode(
242 base32::Alphabet::Rfc4648 { padding: false },
243 secret_base32,
244 ) else {
245 return false;
246 };
247 let Ok(totp) = TOTP::new(Algorithm::SHA1, 6, 1, 30, raw) else {
248 return false;
249 };
250 totp.check(code, now_unix_secs)
251}
252
253pub fn check_mfa(
259 bucket: &str,
260 header_value: Option<&str>,
261 manager: &MfaDeleteManager,
262 now_unix_secs: u64,
263) -> Result<(), MfaError> {
264 if !manager.is_enabled(bucket) {
265 return Ok(());
266 }
267 let header = header_value.ok_or(MfaError::Missing)?;
268 let (serial, code) = parse_mfa_header(header)?;
269 let secret = manager.lookup_secret(bucket).ok_or(MfaError::InvalidCode)?;
270 if serial != secret.serial {
271 return Err(MfaError::SerialMismatch);
272 }
273 if !verify_totp(&secret.secret_base32, &code, now_unix_secs) {
274 return Err(MfaError::InvalidCode);
275 }
276 Ok(())
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Base32 shared secret used by every TOTP test in this module.
    const TEST_SECRET_B32: &str = "JBSWY3DPEHPK3PXPJBSWY3DPEHPK3PXP";

    /// Fixed "current time" so that generated codes are reproducible.
    const NOW: u64 = 1_700_000_000;

    /// Decoded bytes of [`TEST_SECRET_B32`].
    fn raw_secret() -> Vec<u8> {
        let alphabet = base32::Alphabet::Rfc4648 { padding: false };
        base32::decode(alphabet, TEST_SECRET_B32).expect("decode test secret")
    }

    /// The code an authenticator holding the test secret would show at `time`.
    fn totp_at(time: u64) -> String {
        TOTP::new(Algorithm::SHA1, 6, 1, 30, raw_secret())
            .expect("totp")
            .generate(time)
    }

    /// The test secret registered under `serial`.
    fn test_secret(serial: &str) -> MfaSecret {
        MfaSecret {
            secret_base32: TEST_SECRET_B32.to_owned(),
            serial: serial.to_owned(),
        }
    }

    /// A manager whose default secret carries `serial` and whose bucket `"b"`
    /// has MFA Delete enabled.
    fn manager_with_default_serial(serial: &str) -> MfaDeleteManager {
        let manager = MfaDeleteManager::new();
        manager.set_default_secret(test_secret(serial));
        manager.set_bucket_state("b", true);
        manager
    }

    /// Asserts that `header` is rejected as malformed; `why` is the panic
    /// message used if parsing unexpectedly succeeds.
    fn expect_malformed(header: &str, why: &str) {
        let err = parse_mfa_header(header).expect_err(why);
        assert!(matches!(err, MfaError::Malformed));
    }

    #[test]
    fn parse_mfa_header_happy_path() {
        let parsed = parse_mfa_header("SERIAL 123456").expect("parse");
        assert_eq!(parsed.0, "SERIAL");
        assert_eq!(parsed.1, "123456");
    }

    #[test]
    fn parse_mfa_header_rejects_no_space() {
        expect_malformed("SERIAL123456", "must fail");
    }

    #[test]
    fn parse_mfa_header_rejects_extra_token() {
        expect_malformed("SERIAL 123456 trailing", "must fail");
    }

    #[test]
    fn parse_mfa_header_rejects_non_digit_code() {
        expect_malformed("SERIAL 12345A", "must fail");
    }

    #[test]
    fn parse_mfa_header_rejects_wrong_length_code() {
        // One digit short, then one digit long.
        for bad in ["SERIAL 12345", "SERIAL 1234567"] {
            expect_malformed(bad, "must fail");
        }
    }

    #[test]
    fn parse_mfa_header_rejects_empty_serial_or_code() {
        expect_malformed(" 123456", "empty serial");
        expect_malformed("SERIAL ", "empty code");
    }

    #[test]
    fn verify_totp_happy_path() {
        let current = totp_at(NOW);
        assert!(verify_totp(TEST_SECRET_B32, &current, NOW));
    }

    #[test]
    fn verify_totp_clock_skew_within_one_step_ok() {
        let previous = totp_at(NOW - 30);
        assert!(
            verify_totp(TEST_SECRET_B32, &previous, NOW),
            "previous 30s window must validate"
        );

        let next = totp_at(NOW + 30);
        assert!(
            verify_totp(TEST_SECRET_B32, &next, NOW),
            "next 30s window must validate"
        );
    }

    #[test]
    fn verify_totp_clock_skew_beyond_window_fails() {
        // Three steps in the past is outside the accepted skew.
        let stale = totp_at(NOW - 90);
        assert!(!verify_totp(TEST_SECRET_B32, &stale, NOW));
    }

    #[test]
    fn verify_totp_wrong_code_fails() {
        assert!(!verify_totp(TEST_SECRET_B32, "000000", NOW));
    }

    #[test]
    fn verify_totp_short_secret_rejected() {
        // Valid base32, but only a handful of bytes once decoded.
        let short_b32 = "JBSWY3DP";
        assert!(!verify_totp(short_b32, "000000", NOW));
    }

    #[test]
    fn check_mfa_disabled_bucket_is_noop() {
        let m = MfaDeleteManager::new();
        // With the gate off, neither a missing nor a nonsense header matters.
        assert!(check_mfa("b", None, &m, 0).is_ok());
        assert!(check_mfa("b", Some("garbage"), &m, 0).is_ok());
    }

    #[test]
    fn check_mfa_enabled_correct_code_ok() {
        let m = manager_with_default_serial("SERIAL-A");
        let code = totp_at(NOW);
        let header = format!("SERIAL-A {code}");
        assert!(check_mfa("b", Some(header.as_str()), &m, NOW).is_ok());
    }

    #[test]
    fn check_mfa_enabled_wrong_code_fails() {
        let m = manager_with_default_serial("SERIAL-A");
        let err = check_mfa("b", Some("SERIAL-A 000000"), &m, NOW).expect_err("must fail");
        assert!(matches!(err, MfaError::InvalidCode), "got {err:?}");
    }

    #[test]
    fn check_mfa_enabled_missing_header_fails() {
        let m = manager_with_default_serial("SERIAL-A");
        let err = check_mfa("b", None, &m, 0).expect_err("must fail");
        assert!(matches!(err, MfaError::Missing), "got {err:?}");
    }

    #[test]
    fn check_mfa_enabled_serial_mismatch_fails() {
        let m = manager_with_default_serial("SERIAL-A");
        let code = totp_at(NOW);
        let header = format!("SERIAL-OTHER {code}");
        let err = check_mfa("b", Some(header.as_str()), &m, NOW).expect_err("must fail");
        assert!(matches!(err, MfaError::SerialMismatch), "got {err:?}");
    }

    #[test]
    fn check_mfa_per_bucket_override_takes_precedence() {
        let m = manager_with_default_serial("DEFAULT");
        m.set_bucket_secret("b", test_secret("BUCKET-OVERRIDE"));
        let code = totp_at(NOW);

        // The default serial must no longer be accepted for this bucket.
        let header_default = format!("DEFAULT {code}");
        let rejected = check_mfa("b", Some(header_default.as_str()), &m, NOW);
        assert!(matches!(
            rejected.expect_err("must fail"),
            MfaError::SerialMismatch
        ));

        let header_override = format!("BUCKET-OVERRIDE {code}");
        assert!(check_mfa("b", Some(header_override.as_str()), &m, NOW).is_ok());
    }

    #[test]
    fn snapshot_roundtrip() {
        let original = MfaDeleteManager::new();
        original.set_default_secret(test_secret("DEFAULT"));
        original.set_bucket_secret("b1", test_secret("B1-OVR"));
        original.set_bucket_state("b1", true);
        original.set_bucket_state("b2", false);

        let json = original.to_json().expect("to_json");
        let restored = MfaDeleteManager::from_json(&json).expect("from_json");

        assert!(restored.is_enabled("b1"));
        assert!(!restored.is_enabled("b2"));

        let b1 = restored.lookup_secret("b1").expect("override survives");
        assert_eq!(b1.serial, "B1-OVR");

        // A bucket without its own secret falls back to the default.
        let fallback = restored.lookup_secret("other").expect("default survives");
        assert_eq!(fallback.serial, "DEFAULT");
    }
}