1use std::collections::HashMap;
122use std::path::Path;
123use std::sync::Arc;
124
125use aes_gcm::aead::{Aead, KeyInit, Payload};
126use aes_gcm::{Aes256Gcm, Key, Nonce};
127use bytes::Bytes;
128use md5::{Digest as Md5Digest, Md5};
129use rand::RngCore;
130use thiserror::Error;
131
132use crate::kms::{KmsBackend, KmsError, WrappedDek};
133
/// Frame magic of the original single-key format (no key id in the header).
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1";
/// Frame magic of keyring frames that carry a big-endian `u16` key id in the header.
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2";
/// Frame magic of SSE-C frames (the header carries the MD5 of the customer key).
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3";
/// Frame magic of SSE-KMS frames (the header carries the KMS key id and the wrapped DEK).
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4";
/// Frame magic of chunked keyring frames (one AES-GCM tag per chunk).
pub const SSE_MAGIC_V5: &[u8; 4] = b"S4E5";
/// Legacy alias; equal to [`SSE_MAGIC_V1`].
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;

/// Header size of S4E1/S4E2 frames: magic(4) + algo(1) + 3 bytes (reserved zeros in S4E1;
/// key_id(2, BE) + reserved(1) in S4E2) + nonce(12) + tag(16).
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + 12 + 16;
/// Header size of S4E3 frames: magic(4) + algo(1) + key MD5(16) + nonce(12) + tag(16).
pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + 12 + 16;
/// Algorithm byte written at offset 4 of every frame; the only value this build accepts.
pub const ALGO_AES_256_GCM: u8 = 1;
const NONCE_LEN: usize = 12;
const TAG_LEN: usize = 16;
const KEY_LEN: usize = 32;
const KEY_MD5_LEN: usize = 16;
/// The only SSE-C algorithm string accepted by `parse_customer_key_headers`.
pub const SSE_C_ALGORITHM: &str = "AES256";
167
/// Errors from loading SSE keys and from encrypting/decrypting SSE frames (S4E1–S4E5).
#[derive(Debug, Error)]
pub enum SseError {
    /// The key file could not be read (`SseKey::from_path`).
    #[error("SSE key file {path:?}: {source}")]
    KeyFileIo {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// Key material was not 32 raw bytes, 64 hex digits, or base64 of 32 bytes
    /// (`SseKey::from_bytes`). NOTE(review): `got` is the length of the raw input, not of
    /// any decoded form, despite the "after parse" wording.
    #[error(
        "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
    )]
    BadKeyLength { got: usize },
    /// The body is shorter than the fixed header of its frame version (`decrypt`). The message
    /// always quotes `SSE_HEADER_BYTES`, even when the longer S4E3 minimum
    /// (`SSE_HEADER_BYTES_V3`) is the one that was violated.
    #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
    TooShort { got: usize },
    /// The first four bytes are not a frame magic valid for the attempted operation.
    #[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4/S4E5, got {got:?}")]
    BadMagic { got: [u8; 4] },
    /// The algorithm byte (offset 4) is not `ALGO_AES_256_GCM`.
    #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
    UnsupportedAlgo { tag: u8 },
    /// The key id stored in an S4E2/S4E5 header is not present in the supplied keyring.
    #[error(
        "SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
    )]
    KeyNotInKeyring { id: u16 },
    /// AES-GCM authentication failed for a whole-object frame (S4E1–S4E4).
    #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
    DecryptFailed,
    /// SSE-C: the supplied key's MD5 differs from the one stored in the S4E3 header.
    #[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
    WrongCustomerKey,
    /// SSE-C parameters were malformed; `reason` names the failing check
    /// (`parse_customer_key_headers`).
    #[error("SSE-C customer-key headers invalid: {reason}")]
    InvalidCustomerKey { reason: &'static str },
    /// The SSE-C algorithm string was not `SSE_C_ALGORITHM`.
    #[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
    CustomerKeyAlgorithmUnsupported { algo: String },
    /// An S4E3 frame was decrypted without a `CustomerKey` source. `decrypt` also returns this
    /// for a `Kms` source, although the message says "got Keyring".
    #[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
    CustomerKeyRequired,
    /// An S4E1/S4E2/S4E5 frame was decrypted with a non-keyring source. `decrypt` also returns
    /// this for a `Kms` source, although the message names `CustomerKey`.
    #[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
    CustomerKeyUnexpected,
    /// `decrypt` was handed an S4E4 frame; those must go through `decrypt_with_kms`.
    #[error(
        "S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
    )]
    KmsAsyncRequired,
    /// The S4E4 body is shorter than the minimum fixed-size header (`parse_s4e4_header`).
    #[error("S4E4 frame too short ({got} bytes; need at least {min})")]
    KmsFrameTooShort { got: usize, min: usize },
    /// An S4E4 length field points outside the body; `what` names the field.
    #[error("S4E4 frame field length out of bounds: {what}")]
    KmsFrameFieldOob { what: &'static str },
    /// The S4E4 key id bytes are not valid UTF-8.
    #[error("S4E4 key_id is not valid UTF-8")]
    KmsKeyIdNotUtf8,
    /// Not constructed by the code visible in this part of the file. Presumably returned
    /// elsewhere when a supplied wrapped DEK's key id disagrees with the frame's (TODO confirm).
    #[error(
        "S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
    )]
    KmsWrappedDekMismatch {
        supplied: String,
        stored: String,
    },
    /// Not constructed by the code visible in this part of the file. Presumably used by
    /// callers that insist on a KMS source for S4E4 (TODO confirm).
    #[error("S4E4 frame requires SseSource::Kms")]
    KmsRequired,
    /// An error from the KMS backend (via `From<KmsError>`). `decrypt_with_kms` also wraps a
    /// wrong-length data key from the backend in this variant.
    #[error("KMS unwrap: {0}")]
    KmsBackend(#[from] KmsError),
    /// The tag of one S4E5 chunk did not verify.
    #[error("S4E5 chunk {chunk_index} auth tag verify failed (key mismatch or chunk tampered with)")]
    ChunkAuthFailed { chunk_index: u32 },
    /// An S4E5 chunk size of zero, either requested on encrypt or read from a header.
    #[error("S4E5 chunk_size must be > 0 (got 0)")]
    ChunkSizeInvalid,
    /// An S4E5 framing error: truncated header or chunk, zero chunk count, or trailing bytes.
    #[error("S4E5 frame truncated: {what}")]
    ChunkFrameTruncated { what: &'static str },
}
299
/// A 32-byte AES-256-GCM key. Its `Debug` impl prints a placeholder instead of the bytes.
pub struct SseKey {
    /// Raw key bytes.
    pub bytes: [u8; 32],
}
307
308impl SseKey {
309 pub fn from_path(path: &Path) -> Result<Self, SseError> {
313 let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
314 path: path.to_path_buf(),
315 source,
316 })?;
317 Self::from_bytes(&raw)
318 }
319
320 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
321 if bytes.len() == KEY_LEN {
323 let mut k = [0u8; KEY_LEN];
324 k.copy_from_slice(bytes);
325 return Ok(Self { bytes: k });
326 }
327 let s = std::str::from_utf8(bytes).unwrap_or("").trim();
329 if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
330 let mut k = [0u8; KEY_LEN];
331 for (i, k_byte) in k.iter_mut().enumerate() {
332 *k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
333 .map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
334 }
335 return Ok(Self { bytes: k });
336 }
337 if let Ok(decoded) =
338 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
339 && decoded.len() == KEY_LEN
340 {
341 let mut k = [0u8; KEY_LEN];
342 k.copy_from_slice(&decoded);
343 return Ok(Self { bytes: k });
344 }
345 Err(SseError::BadKeyLength { got: bytes.len() })
346 }
347
348 fn as_aes_key(&self) -> &Key<Aes256Gcm> {
349 Key::<Aes256Gcm>::from_slice(&self.bytes)
350 }
351}
352
353impl std::fmt::Debug for SseKey {
354 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355 f.debug_struct("SseKey")
356 .field("len", &KEY_LEN)
357 .field("key", &"<redacted>")
358 .finish()
359 }
360}
361
/// A set of SSE keys indexed by `u16` id, one of which is designated active.
///
/// New keyring encryptions use the active key. Decryption looks keys up by the id stored in
/// the frame (S4E2/S4E5) or tries every key (S4E1).
#[derive(Clone)]
pub struct SseKeyring {
    /// Id of the key used for new encryptions; inserted by `new` and never removed.
    active: u16,
    /// All known keys by id. Held in `Arc`s, so cloning the keyring does not copy key bytes.
    keys: HashMap<u16, Arc<SseKey>>,
}
371
372impl SseKeyring {
373 pub fn new(active: u16, key: Arc<SseKey>) -> Self {
377 let mut keys = HashMap::new();
378 keys.insert(active, key);
379 Self { active, keys }
380 }
381
382 pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
386 self.keys.insert(id, key);
387 }
388
389 pub fn active(&self) -> (u16, &SseKey) {
392 let id = self.active;
393 let key = self
394 .keys
395 .get(&id)
396 .expect("active key id must be present in keyring (constructor invariant)");
397 (id, key.as_ref())
398 }
399
400 pub fn get(&self, id: u16) -> Option<&SseKey> {
403 self.keys.get(&id).map(Arc::as_ref)
404 }
405}
406
407impl std::fmt::Debug for SseKeyring {
408 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 f.debug_struct("SseKeyring")
410 .field("active", &self.active)
411 .field("key_count", &self.keys.len())
412 .field("key_ids", &self.keys.keys().collect::<Vec<_>>())
413 .finish()
414 }
415}
416
/// Reference-counted handle to a keyring, so one keyring can be shared without cloning its
/// key map.
pub type SharedSseKeyring = Arc<SseKeyring>;
418
419pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
426 let cipher = Aes256Gcm::new(key.as_aes_key());
427 let mut nonce_bytes = [0u8; NONCE_LEN];
428 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
429 let nonce = Nonce::from_slice(&nonce_bytes);
430 let mut aad = [0u8; 8];
432 aad[..4].copy_from_slice(SSE_MAGIC_V1);
433 aad[4] = ALGO_AES_256_GCM;
434 let ct_with_tag = cipher
435 .encrypt(
436 nonce,
437 Payload {
438 msg: plaintext,
439 aad: &aad,
440 },
441 )
442 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
443 debug_assert!(ct_with_tag.len() >= TAG_LEN);
444 let split = ct_with_tag.len() - TAG_LEN;
445 let (ct, tag) = ct_with_tag.split_at(split);
446
447 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
448 out.extend_from_slice(SSE_MAGIC_V1);
449 out.push(ALGO_AES_256_GCM);
450 out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
452 out.extend_from_slice(tag);
453 out.extend_from_slice(ct);
454 Bytes::from(out)
455}
456
457pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
462 let (key_id, key) = keyring.active();
463 let cipher = Aes256Gcm::new(key.as_aes_key());
464 let mut nonce_bytes = [0u8; NONCE_LEN];
465 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
466 let nonce = Nonce::from_slice(&nonce_bytes);
467 let aad = aad_v2(key_id);
468 let ct_with_tag = cipher
469 .encrypt(
470 nonce,
471 Payload {
472 msg: plaintext,
473 aad: &aad,
474 },
475 )
476 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
477 debug_assert!(ct_with_tag.len() >= TAG_LEN);
478 let split = ct_with_tag.len() - TAG_LEN;
479 let (ct, tag) = ct_with_tag.split_at(split);
480
481 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
482 out.extend_from_slice(SSE_MAGIC_V2);
483 out.push(ALGO_AES_256_GCM);
484 out.extend_from_slice(&key_id.to_be_bytes()); out.push(0u8); out.extend_from_slice(&nonce_bytes);
487 out.extend_from_slice(tag);
488 out.extend_from_slice(ct);
489 Bytes::from(out)
490}
491
492fn aad_v1() -> [u8; 8] {
493 let mut aad = [0u8; 8];
494 aad[..4].copy_from_slice(SSE_MAGIC_V1);
495 aad[4] = ALGO_AES_256_GCM;
496 aad
497}
498
499fn aad_v2(key_id: u16) -> [u8; 8] {
500 let mut aad = [0u8; 8];
501 aad[..4].copy_from_slice(SSE_MAGIC_V2);
502 aad[4] = ALGO_AES_256_GCM;
503 aad[5..7].copy_from_slice(&key_id.to_be_bytes());
504 aad[7] = 0u8;
505 aad
506}
507
508fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
514 let mut aad = [0u8; 4 + 1 + KEY_MD5_LEN];
515 aad[..4].copy_from_slice(SSE_MAGIC_V3);
516 aad[4] = ALGO_AES_256_GCM;
517 aad[5..5 + KEY_MD5_LEN].copy_from_slice(key_md5);
518 aad
519}
520
/// An SSE-C customer key plus its MD5 fingerprint, as returned by
/// `parse_customer_key_headers`. Its `Debug` impl prints only the fingerprint.
#[derive(Clone)]
pub struct CustomerKeyMaterial {
    /// The 32-byte AES-256 key.
    pub key: [u8; KEY_LEN],
    /// MD5 of `key`. It is written into S4E3 headers and compared against on decrypt.
    pub key_md5: [u8; KEY_MD5_LEN],
}
531
532impl std::fmt::Debug for CustomerKeyMaterial {
533 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
534 f.debug_struct("CustomerKeyMaterial")
537 .field("key", &"<redacted>")
538 .field("key_md5_hex", &hex_lower(&self.key_md5))
539 .finish()
540 }
541}
542
/// Formats `bytes` as lowercase hex, two digits per byte (e.g. `[0x0a, 0xff]` -> `"0aff"`).
fn hex_lower(bytes: &[u8]) -> String {
    // Table lookup writes straight into one pre-sized buffer. The previous
    // `push_str(&format!(..))` allocated a temporary String for every byte.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        s.push(char::from(DIGITS[usize::from(b >> 4)]));
        s.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    s
}
550
551#[derive(Debug, Clone, Copy)]
559pub enum SseSource<'a> {
560 Keyring(&'a SseKeyring),
563 CustomerKey {
567 key: &'a [u8; KEY_LEN],
568 key_md5: &'a [u8; KEY_MD5_LEN],
569 },
570 Kms {
576 dek: &'a [u8; KEY_LEN],
578 wrapped: &'a WrappedDek,
581 },
582}
583
584impl<'a> From<&'a SseKeyring> for SseSource<'a> {
591 fn from(kr: &'a SseKeyring) -> Self {
592 SseSource::Keyring(kr)
593 }
594}
595
596impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
600 fn from(kr: &'a Arc<SseKeyring>) -> Self {
601 SseSource::Keyring(kr.as_ref())
602 }
603}
604
605impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
606 fn from(m: &'a CustomerKeyMaterial) -> Self {
607 SseSource::CustomerKey {
608 key: &m.key,
609 key_md5: &m.key_md5,
610 }
611 }
612}
613
614pub fn parse_customer_key_headers(
626 algorithm: &str,
627 key_base64: &str,
628 key_md5_base64: &str,
629) -> Result<CustomerKeyMaterial, SseError> {
630 use base64::Engine as _;
631 if algorithm != SSE_C_ALGORITHM {
632 return Err(SseError::CustomerKeyAlgorithmUnsupported {
633 algo: algorithm.to_string(),
634 });
635 }
636 let key_bytes = base64::engine::general_purpose::STANDARD
637 .decode(key_base64.trim().as_bytes())
638 .map_err(|_| SseError::InvalidCustomerKey {
639 reason: "base64 decode of key",
640 })?;
641 if key_bytes.len() != KEY_LEN {
642 return Err(SseError::InvalidCustomerKey {
643 reason: "key length (must be 32 bytes after base64 decode)",
644 });
645 }
646 let supplied_md5 = base64::engine::general_purpose::STANDARD
647 .decode(key_md5_base64.trim().as_bytes())
648 .map_err(|_| SseError::InvalidCustomerKey {
649 reason: "base64 decode of key MD5",
650 })?;
651 if supplied_md5.len() != KEY_MD5_LEN {
652 return Err(SseError::InvalidCustomerKey {
653 reason: "key MD5 length (must be 16 bytes after base64 decode)",
654 });
655 }
656 let actual_md5 = compute_key_md5(&key_bytes);
657 if !constant_time_eq(&actual_md5, &supplied_md5) {
660 return Err(SseError::InvalidCustomerKey {
661 reason: "supplied MD5 does not match MD5 of supplied key",
662 });
663 }
664 let mut key = [0u8; KEY_LEN];
665 key.copy_from_slice(&key_bytes);
666 let mut key_md5 = [0u8; KEY_MD5_LEN];
667 key_md5.copy_from_slice(&actual_md5);
668 Ok(CustomerKeyMaterial { key, key_md5 })
669}
670
671pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
676 let mut h = Md5::new();
677 h.update(key);
678 let out = h.finalize();
679 let mut md5 = [0u8; KEY_MD5_LEN];
680 md5.copy_from_slice(&out);
681 md5
682}
683
/// Compares two byte slices without stopping at the first differing byte.
///
/// A length mismatch returns `false` immediately, since lengths are not secret here. For
/// equal lengths, every byte pair is folded into one difference accumulator before answering.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
696
697pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
707 match source {
708 SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
709 SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
710 SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
711 }
712}
713
714fn encrypt_v3(
715 plaintext: &[u8],
716 key: &[u8; KEY_LEN],
717 key_md5: &[u8; KEY_MD5_LEN],
718) -> Bytes {
719 let aes_key = Key::<Aes256Gcm>::from_slice(key);
720 let cipher = Aes256Gcm::new(aes_key);
721 let mut nonce_bytes = [0u8; NONCE_LEN];
722 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
723 let nonce = Nonce::from_slice(&nonce_bytes);
724 let aad = aad_v3(key_md5);
725 let ct_with_tag = cipher
726 .encrypt(
727 nonce,
728 Payload {
729 msg: plaintext,
730 aad: &aad,
731 },
732 )
733 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
734 debug_assert!(ct_with_tag.len() >= TAG_LEN);
735 let split = ct_with_tag.len() - TAG_LEN;
736 let (ct, tag) = ct_with_tag.split_at(split);
737
738 let mut out = Vec::with_capacity(SSE_HEADER_BYTES_V3 + ct.len());
739 out.extend_from_slice(SSE_MAGIC_V3);
740 out.push(ALGO_AES_256_GCM);
741 out.extend_from_slice(key_md5);
742 out.extend_from_slice(&nonce_bytes);
743 out.extend_from_slice(tag);
744 out.extend_from_slice(ct);
745 Bytes::from(out)
746}
747
748pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
767 let source = source.into();
768 if body.len() < SSE_HEADER_BYTES {
774 return Err(SseError::TooShort { got: body.len() });
775 }
776 let mut magic = [0u8; 4];
777 magic.copy_from_slice(&body[..4]);
778 match &magic {
779 m if m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 => {
780 let keyring = match source {
781 SseSource::Keyring(kr) => kr,
782 SseSource::CustomerKey { .. } => return Err(SseError::CustomerKeyUnexpected),
783 SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
789 };
790 if m == SSE_MAGIC_V1 {
791 decrypt_v1_with_keyring(body, keyring)
792 } else {
793 decrypt_v2_with_keyring(body, keyring)
794 }
795 }
796 m if m == SSE_MAGIC_V3 => {
797 if body.len() < SSE_HEADER_BYTES_V3 {
799 return Err(SseError::TooShort { got: body.len() });
800 }
801 let (key, key_md5) = match source {
802 SseSource::CustomerKey { key, key_md5 } => (key, key_md5),
803 SseSource::Keyring(_) => return Err(SseError::CustomerKeyRequired),
804 SseSource::Kms { .. } => return Err(SseError::CustomerKeyRequired),
805 };
806 decrypt_v3(body, key, key_md5)
807 }
808 m if m == SSE_MAGIC_V4 => {
809 Err(SseError::KmsAsyncRequired)
814 }
815 m if m == SSE_MAGIC_V5 => {
816 let keyring = match source {
823 SseSource::Keyring(kr) => kr,
824 SseSource::CustomerKey { .. } => {
825 return Err(SseError::CustomerKeyUnexpected);
826 }
827 SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
828 };
829 decrypt_v5_buffered(body, keyring)
830 }
831 _ => Err(SseError::BadMagic { got: magic }),
832 }
833}
834
835fn decrypt_v3(
836 body: &[u8],
837 key: &[u8; KEY_LEN],
838 supplied_md5: &[u8; KEY_MD5_LEN],
839) -> Result<Bytes, SseError> {
840 let algo = body[4];
841 if algo != ALGO_AES_256_GCM {
842 return Err(SseError::UnsupportedAlgo { tag: algo });
843 }
844 let mut stored_md5 = [0u8; KEY_MD5_LEN];
845 stored_md5.copy_from_slice(&body[5..5 + KEY_MD5_LEN]);
846 if !constant_time_eq(supplied_md5, &stored_md5) {
852 return Err(SseError::WrongCustomerKey);
853 }
854 let nonce_off = 5 + KEY_MD5_LEN;
855 let tag_off = nonce_off + NONCE_LEN;
856 let mut nonce_bytes = [0u8; NONCE_LEN];
857 nonce_bytes.copy_from_slice(&body[nonce_off..nonce_off + NONCE_LEN]);
858 let mut tag_bytes = [0u8; TAG_LEN];
859 tag_bytes.copy_from_slice(&body[tag_off..tag_off + TAG_LEN]);
860 let ct = &body[SSE_HEADER_BYTES_V3..];
861
862 let aad = aad_v3(&stored_md5);
863 let nonce = Nonce::from_slice(&nonce_bytes);
864 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
865 ct_with_tag.extend_from_slice(ct);
866 ct_with_tag.extend_from_slice(&tag_bytes);
867
868 let aes_key = Key::<Aes256Gcm>::from_slice(key);
869 let cipher = Aes256Gcm::new(aes_key);
870 let plain = cipher
871 .decrypt(
872 nonce,
873 Payload {
874 msg: &ct_with_tag,
875 aad: &aad,
876 },
877 )
878 .map_err(|_| SseError::DecryptFailed)?;
879 Ok(Bytes::from(plain))
880}
881
882fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
893 let mut aad = Vec::with_capacity(4 + 1 + 1 + key_id.len() + 4 + wrapped_dek.len());
894 aad.extend_from_slice(SSE_MAGIC_V4);
895 aad.push(ALGO_AES_256_GCM);
896 aad.push(key_id.len() as u8);
897 aad.extend_from_slice(key_id);
898 aad.extend_from_slice(&(wrapped_dek.len() as u32).to_be_bytes());
899 aad.extend_from_slice(wrapped_dek);
900 aad
901}
902
903fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
904 assert!(
912 !wrapped.key_id.is_empty() && wrapped.key_id.len() <= u8::MAX as usize,
913 "S4E4 key_id must be 1..=255 bytes (got {})",
914 wrapped.key_id.len()
915 );
916 assert!(
917 wrapped.ciphertext.len() <= u32::MAX as usize,
918 "S4E4 wrapped_dek longer than u32::MAX",
919 );
920
921 let aes_key = Key::<Aes256Gcm>::from_slice(dek);
922 let cipher = Aes256Gcm::new(aes_key);
923 let mut nonce_bytes = [0u8; NONCE_LEN];
924 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
925 let nonce = Nonce::from_slice(&nonce_bytes);
926 let aad = aad_v4(wrapped.key_id.as_bytes(), &wrapped.ciphertext);
927 let ct_with_tag = cipher
928 .encrypt(
929 nonce,
930 Payload {
931 msg: plaintext,
932 aad: &aad,
933 },
934 )
935 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
936 debug_assert!(ct_with_tag.len() >= TAG_LEN);
937 let split = ct_with_tag.len() - TAG_LEN;
938 let (ct, tag) = ct_with_tag.split_at(split);
939
940 let key_id_bytes = wrapped.key_id.as_bytes();
941 let mut out = Vec::with_capacity(
942 4 + 1 + 1 + key_id_bytes.len() + 4 + wrapped.ciphertext.len() + NONCE_LEN + TAG_LEN + ct.len(),
943 );
944 out.extend_from_slice(SSE_MAGIC_V4);
945 out.push(ALGO_AES_256_GCM);
946 out.push(key_id_bytes.len() as u8);
947 out.extend_from_slice(key_id_bytes);
948 out.extend_from_slice(&(wrapped.ciphertext.len() as u32).to_be_bytes());
949 out.extend_from_slice(&wrapped.ciphertext);
950 out.extend_from_slice(&nonce_bytes);
951 out.extend_from_slice(tag);
952 out.extend_from_slice(ct);
953 Bytes::from(out)
954}
955
/// Borrowed view of a parsed S4E4 (SSE-KMS) frame, as returned by `parse_s4e4_header`.
/// All slices point into the original body.
#[derive(Debug)]
pub struct S4E4Header<'a> {
    /// KMS key id, validated as UTF-8. The parser does not reject a zero-length id.
    pub key_id: &'a str,
    /// KMS-wrapped data key; `decrypt_with_kms` passes it to the KMS backend for unwrapping.
    pub wrapped_dek: &'a [u8],
    /// 12-byte AES-GCM nonce.
    pub nonce: &'a [u8],
    /// 16-byte AES-GCM tag.
    pub tag: &'a [u8],
    /// Every byte after the tag.
    pub ciphertext: &'a [u8],
}
969
970pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
974 const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN; if body.len() < S4E4_MIN {
981 return Err(SseError::KmsFrameTooShort {
982 got: body.len(),
983 min: S4E4_MIN,
984 });
985 }
986 let magic = &body[..4];
987 if magic != SSE_MAGIC_V4 {
988 let mut got = [0u8; 4];
989 got.copy_from_slice(magic);
990 return Err(SseError::BadMagic { got });
991 }
992 let algo = body[4];
993 if algo != ALGO_AES_256_GCM {
994 return Err(SseError::UnsupportedAlgo { tag: algo });
995 }
996 let key_id_len = body[5] as usize;
997 let key_id_off: usize = 6;
998 let key_id_end = key_id_off
999 .checked_add(key_id_len)
1000 .ok_or(SseError::KmsFrameFieldOob { what: "key_id_len" })?;
1001 if key_id_end + 4 > body.len() {
1002 return Err(SseError::KmsFrameFieldOob { what: "key_id" });
1003 }
1004 let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
1005 .map_err(|_| SseError::KmsKeyIdNotUtf8)?;
1006 let wrapped_len_off = key_id_end;
1007 let wrapped_dek_len = u32::from_be_bytes([
1008 body[wrapped_len_off],
1009 body[wrapped_len_off + 1],
1010 body[wrapped_len_off + 2],
1011 body[wrapped_len_off + 3],
1012 ]) as usize;
1013 let wrapped_off = wrapped_len_off + 4;
1014 let wrapped_end = wrapped_off
1015 .checked_add(wrapped_dek_len)
1016 .ok_or(SseError::KmsFrameFieldOob { what: "wrapped_dek_len" })?;
1017 if wrapped_end + NONCE_LEN + TAG_LEN > body.len() {
1018 return Err(SseError::KmsFrameFieldOob { what: "wrapped_dek" });
1019 }
1020 let wrapped_dek = &body[wrapped_off..wrapped_end];
1021 let nonce_off = wrapped_end;
1022 let tag_off = nonce_off + NONCE_LEN;
1023 let ct_off = tag_off + TAG_LEN;
1024 let nonce = &body[nonce_off..nonce_off + NONCE_LEN];
1025 let tag = &body[tag_off..tag_off + TAG_LEN];
1026 let ciphertext = &body[ct_off..];
1027 Ok(S4E4Header {
1028 key_id,
1029 wrapped_dek,
1030 nonce,
1031 tag,
1032 ciphertext,
1033 })
1034}
1035
1036pub async fn decrypt_with_kms(
1052 body: &[u8],
1053 kms: &dyn KmsBackend,
1054) -> Result<Bytes, SseError> {
1055 let hdr = parse_s4e4_header(body)?;
1056 let wrapped = WrappedDek {
1057 key_id: hdr.key_id.to_string(),
1058 ciphertext: hdr.wrapped_dek.to_vec(),
1059 };
1060 let dek_vec = kms.decrypt_dek(&wrapped).await?;
1061 if dek_vec.len() != KEY_LEN {
1062 return Err(SseError::KmsBackend(KmsError::BackendUnavailable {
1067 message: format!(
1068 "KMS returned {} byte DEK; expected {KEY_LEN}",
1069 dek_vec.len()
1070 ),
1071 }));
1072 }
1073 let mut dek = [0u8; KEY_LEN];
1074 dek.copy_from_slice(&dek_vec);
1075
1076 let aad = aad_v4(hdr.key_id.as_bytes(), hdr.wrapped_dek);
1077 let aes_key = Key::<Aes256Gcm>::from_slice(&dek);
1078 let cipher = Aes256Gcm::new(aes_key);
1079 let nonce = Nonce::from_slice(hdr.nonce);
1080 let mut ct_with_tag = Vec::with_capacity(hdr.ciphertext.len() + TAG_LEN);
1081 ct_with_tag.extend_from_slice(hdr.ciphertext);
1082 ct_with_tag.extend_from_slice(hdr.tag);
1083 let plain = cipher
1084 .decrypt(
1085 nonce,
1086 Payload {
1087 msg: &ct_with_tag,
1088 aad: &aad,
1089 },
1090 )
1091 .map_err(|_| SseError::DecryptFailed)?;
1092 Ok(Bytes::from(plain))
1093}
1094
1095fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1096 let algo = body[4];
1097 if algo != ALGO_AES_256_GCM {
1098 return Err(SseError::UnsupportedAlgo { tag: algo });
1099 }
1100 let mut nonce_bytes = [0u8; NONCE_LEN];
1103 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1104 let mut tag_bytes = [0u8; TAG_LEN];
1105 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1106 let ct = &body[SSE_HEADER_BYTES..];
1107
1108 let aad = aad_v1();
1109 let nonce = Nonce::from_slice(&nonce_bytes);
1110 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1111 ct_with_tag.extend_from_slice(ct);
1112 ct_with_tag.extend_from_slice(&tag_bytes);
1113
1114 let (active_id, _active_key) = keyring.active();
1118 let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
1119 ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
1120 for id in ids {
1121 let key = keyring.get(id).expect("id came from keyring iteration");
1122 let cipher = Aes256Gcm::new(key.as_aes_key());
1123 if let Ok(plain) = cipher.decrypt(
1124 nonce,
1125 Payload {
1126 msg: &ct_with_tag,
1127 aad: &aad,
1128 },
1129 ) {
1130 return Ok(Bytes::from(plain));
1131 }
1132 }
1133 Err(SseError::DecryptFailed)
1134}
1135
1136fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1137 let algo = body[4];
1138 if algo != ALGO_AES_256_GCM {
1139 return Err(SseError::UnsupportedAlgo { tag: algo });
1140 }
1141 let key_id = u16::from_be_bytes([body[5], body[6]]);
1142 let key = keyring
1144 .get(key_id)
1145 .ok_or(SseError::KeyNotInKeyring { id: key_id })?;
1146 let mut nonce_bytes = [0u8; NONCE_LEN];
1147 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1148 let mut tag_bytes = [0u8; TAG_LEN];
1149 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1150 let ct = &body[SSE_HEADER_BYTES..];
1151
1152 let aad = aad_v2(key_id);
1153 let nonce = Nonce::from_slice(&nonce_bytes);
1154 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1155 ct_with_tag.extend_from_slice(ct);
1156 ct_with_tag.extend_from_slice(&tag_bytes);
1157 let cipher = Aes256Gcm::new(key.as_aes_key());
1158 let plain = cipher
1159 .decrypt(
1160 nonce,
1161 Payload {
1162 msg: &ct_with_tag,
1163 aad: &aad,
1164 },
1165 )
1166 .map_err(|_| SseError::DecryptFailed)?;
1167 Ok(Bytes::from(plain))
1168}
1169
1170pub fn looks_encrypted(body: &[u8]) -> bool {
1181 if body.len() < SSE_HEADER_BYTES {
1182 return false;
1183 }
1184 let m = &body[..4];
1185 m == SSE_MAGIC_V1
1186 || m == SSE_MAGIC_V2
1187 || m == SSE_MAGIC_V3
1188 || m == SSE_MAGIC_V4
1189 || m == SSE_MAGIC_V5
1190}
1191
1192pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
1203 if body.len() < SSE_HEADER_BYTES {
1204 return None;
1205 }
1206 match &body[..4] {
1207 m if m == SSE_MAGIC_V1 => Some("S4E1"),
1208 m if m == SSE_MAGIC_V2 => Some("S4E2"),
1209 m if m == SSE_MAGIC_V3 => Some("S4E3"),
1210 m if m == SSE_MAGIC_V4 => Some("S4E4"),
1211 m if m == SSE_MAGIC_V5 => Some("S4E5"),
1216 _ => None,
1217 }
1218}
1219
/// Reference-counted handle to a single [`SseKey`]; `SseKeyring::new` and `add` take keys
/// in this form.
pub type SharedSseKey = Arc<SseKey>;
1221
/// Fixed S4E5 header size: magic(4) + algo(1) + key_id(2, BE) + reserved(1) +
/// chunk_size(4, BE) + chunk_count(4, BE) + salt(4). Chunk records follow immediately.
pub const S4E5_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 4;
/// Bytes each S4E5 chunk adds on top of its plaintext: the AES-GCM tag, stored before the
/// chunk's ciphertext.
pub const S4E5_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
/// Fixed first four bytes of every S4E5 chunk nonce (see `nonce_v5`).
const S4E5_NONCE_TAG: [u8; 4] = [b'E', b'5', 0, 0];
1289
1290fn aad_v5(
1295 chunk_index: u32,
1296 total_chunks: u32,
1297 key_id: u16,
1298 salt: &[u8; 4],
1299) -> [u8; 4 + 1 + 4 + 4 + 2 + 4] {
1300 let mut aad = [0u8; 4 + 1 + 4 + 4 + 2 + 4]; aad[..4].copy_from_slice(SSE_MAGIC_V5);
1302 aad[4] = ALGO_AES_256_GCM;
1303 aad[5..9].copy_from_slice(&chunk_index.to_be_bytes());
1304 aad[9..13].copy_from_slice(&total_chunks.to_be_bytes());
1305 aad[13..15].copy_from_slice(&key_id.to_be_bytes());
1306 aad[15..19].copy_from_slice(salt);
1307 aad
1308}
1309
1310fn nonce_v5(salt: &[u8; 4], chunk_index: u32) -> [u8; NONCE_LEN] {
1316 let mut n = [0u8; NONCE_LEN];
1317 n[..4].copy_from_slice(&S4E5_NONCE_TAG);
1318 n[4..8].copy_from_slice(salt);
1319 n[8..12].copy_from_slice(&chunk_index.to_be_bytes());
1320 n
1321}
1322
1323pub fn encrypt_v2_chunked(
1339 plaintext: &[u8],
1340 keyring: &SseKeyring,
1341 chunk_size: usize,
1342) -> Result<Bytes, SseError> {
1343 if chunk_size == 0 {
1344 return Err(SseError::ChunkSizeInvalid);
1345 }
1346 let (key_id, key) = keyring.active();
1347 let cipher = Aes256Gcm::new(key.as_aes_key());
1348 let mut salt = [0u8; 4];
1349 rand::rngs::OsRng.fill_bytes(&mut salt);
1350
1351 let chunk_count: u32 = if plaintext.is_empty() {
1354 1
1355 } else {
1356 plaintext
1357 .len()
1358 .div_ceil(chunk_size)
1359 .try_into()
1360 .expect("chunk_count overflows u32 — plaintext > 16 EiB at min chunk_size")
1361 };
1362
1363 let mut out = Vec::with_capacity(
1364 S4E5_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E5_PER_CHUNK_OVERHEAD),
1365 );
1366 out.extend_from_slice(SSE_MAGIC_V5);
1367 out.push(ALGO_AES_256_GCM);
1368 out.extend_from_slice(&key_id.to_be_bytes());
1369 out.push(0u8); out.extend_from_slice(&(chunk_size as u32).to_be_bytes());
1371 out.extend_from_slice(&chunk_count.to_be_bytes());
1372 out.extend_from_slice(&salt);
1373
1374 for i in 0..chunk_count {
1375 let off = (i as usize).saturating_mul(chunk_size);
1376 let end = off.saturating_add(chunk_size).min(plaintext.len());
1377 let chunk_pt: &[u8] = if off >= plaintext.len() {
1378 &[]
1381 } else {
1382 &plaintext[off..end]
1383 };
1384 let nonce_bytes = nonce_v5(&salt, i);
1385 let nonce = Nonce::from_slice(&nonce_bytes);
1386 let aad = aad_v5(i, chunk_count, key_id, &salt);
1387 let ct_with_tag = cipher
1388 .encrypt(
1389 nonce,
1390 Payload {
1391 msg: chunk_pt,
1392 aad: &aad,
1393 },
1394 )
1395 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
1396 debug_assert!(ct_with_tag.len() >= TAG_LEN);
1397 let split = ct_with_tag.len() - TAG_LEN;
1398 let (ct, tag) = ct_with_tag.split_at(split);
1399 out.extend_from_slice(tag);
1400 out.extend_from_slice(ct);
1401 crate::metrics::record_sse_streaming_chunk("encrypt");
1402 }
1403 Ok(Bytes::from(out))
1404}
1405
/// Decoded fixed-size S4E5 header fields, as returned by `parse_s4e5_header`.
#[derive(Debug, Clone, Copy)]
struct S4E5Header {
    /// Keyring id of the key the frame was encrypted with.
    key_id: u16,
    /// Plaintext bytes per chunk (every chunk except the last is exactly this size); never 0.
    chunk_size: u32,
    /// Number of chunks declared by the header; never 0. It is part of every chunk's AAD.
    chunk_count: u32,
    /// Per-frame random salt, mixed into every chunk nonce and AAD.
    salt: [u8; 4],
    /// Byte offset of the first chunk record within the body.
    chunks_offset: usize,
}
1420
1421fn parse_s4e5_header(body: &[u8]) -> Result<S4E5Header, SseError> {
1422 if body.len() < S4E5_HEADER_BYTES {
1423 return Err(SseError::ChunkFrameTruncated { what: "header" });
1424 }
1425 if &body[..4] != SSE_MAGIC_V5 {
1426 let mut got = [0u8; 4];
1427 got.copy_from_slice(&body[..4]);
1428 return Err(SseError::BadMagic { got });
1429 }
1430 let algo = body[4];
1431 if algo != ALGO_AES_256_GCM {
1432 return Err(SseError::UnsupportedAlgo { tag: algo });
1433 }
1434 let key_id = u16::from_be_bytes([body[5], body[6]]);
1435 let chunk_size = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
1437 let chunk_count = u32::from_be_bytes([body[12], body[13], body[14], body[15]]);
1438 let mut salt = [0u8; 4];
1439 salt.copy_from_slice(&body[16..20]);
1440 if chunk_size == 0 {
1441 return Err(SseError::ChunkSizeInvalid);
1442 }
1443 if chunk_count == 0 {
1444 return Err(SseError::ChunkFrameTruncated {
1445 what: "chunk_count == 0",
1446 });
1447 }
1448 Ok(S4E5Header {
1449 key_id,
1450 chunk_size,
1451 chunk_count,
1452 salt,
1453 chunks_offset: S4E5_HEADER_BYTES,
1454 })
1455}
1456
1457fn decrypt_v5_chunk(
1460 cipher: &Aes256Gcm,
1461 chunk_index: u32,
1462 chunk_count: u32,
1463 key_id: u16,
1464 salt: &[u8; 4],
1465 tag: &[u8; TAG_LEN],
1466 ct: &[u8],
1467) -> Result<Bytes, SseError> {
1468 let nonce_bytes = nonce_v5(salt, chunk_index);
1469 let nonce = Nonce::from_slice(&nonce_bytes);
1470 let aad = aad_v5(chunk_index, chunk_count, key_id, salt);
1471 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1472 ct_with_tag.extend_from_slice(ct);
1473 ct_with_tag.extend_from_slice(tag);
1474 cipher
1475 .decrypt(
1476 nonce,
1477 Payload {
1478 msg: &ct_with_tag,
1479 aad: &aad,
1480 },
1481 )
1482 .map(Bytes::from)
1483 .map_err(|_| SseError::ChunkAuthFailed { chunk_index })
1484}
1485
1486fn walk_s4e5<F: FnMut(Bytes) -> Result<(), SseError>>(
1492 body: &[u8],
1493 keyring: &SseKeyring,
1494 mut emit: F,
1495) -> Result<(), SseError> {
1496 let hdr = parse_s4e5_header(body)?;
1497 let key = keyring
1498 .get(hdr.key_id)
1499 .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
1500 let cipher = Aes256Gcm::new(key.as_aes_key());
1501
1502 let mut cursor = hdr.chunks_offset;
1503 let chunk_size = hdr.chunk_size as usize;
1504 for i in 0..hdr.chunk_count {
1505 if cursor + TAG_LEN > body.len() {
1506 return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
1507 }
1508 let tag_off = cursor;
1509 let ct_off = tag_off + TAG_LEN;
1510 let is_last = i + 1 == hdr.chunk_count;
1511 let ct_len = if is_last {
1512 if ct_off > body.len() {
1513 return Err(SseError::ChunkFrameTruncated {
1514 what: "final chunk ciphertext",
1515 });
1516 }
1517 let remaining = body.len() - ct_off;
1518 if remaining > chunk_size {
1519 return Err(SseError::ChunkFrameTruncated {
1520 what: "trailing bytes after final chunk",
1521 });
1522 }
1523 remaining
1524 } else {
1525 chunk_size
1526 };
1527 let ct_end = ct_off + ct_len;
1528 if ct_end > body.len() {
1529 return Err(SseError::ChunkFrameTruncated {
1530 what: "chunk ciphertext",
1531 });
1532 }
1533 let mut tag = [0u8; TAG_LEN];
1534 tag.copy_from_slice(&body[tag_off..ct_off]);
1535 let ct = &body[ct_off..ct_end];
1536 let plain = decrypt_v5_chunk(
1537 &cipher,
1538 i,
1539 hdr.chunk_count,
1540 hdr.key_id,
1541 &hdr.salt,
1542 &tag,
1543 ct,
1544 )?;
1545 crate::metrics::record_sse_streaming_chunk("decrypt");
1546 emit(plain)?;
1547 cursor = ct_end;
1548 }
1549 if cursor != body.len() {
1550 return Err(SseError::ChunkFrameTruncated {
1551 what: "trailing bytes after declared chunk_count",
1552 });
1553 }
1554 Ok(())
1555}
1556
1557fn decrypt_v5_buffered(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1563 let hdr = parse_s4e5_header(body)?;
1564 let mut out = Vec::with_capacity(hdr.chunk_size as usize * hdr.chunk_count as usize);
1565 walk_s4e5(body, keyring, |chunk| {
1566 out.extend_from_slice(&chunk);
1567 Ok(())
1568 })?;
1569 Ok(Bytes::from(out))
1570}
1571
1572pub fn decrypt_chunked_stream(
1595 body: bytes::Bytes,
1596 keyring: &SseKeyring,
1597) -> impl futures::Stream<Item = Result<Bytes, SseError>> + 'static {
1598 use futures::stream::{self, StreamExt};
1599
1600 let prelude = (|| {
1607 let hdr = parse_s4e5_header(&body)?;
1608 let key = keyring
1609 .get(hdr.key_id)
1610 .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
1611 let cipher = Aes256Gcm::new(key.as_aes_key());
1612 Ok::<_, SseError>((hdr, cipher))
1613 })();
1614
1615 match prelude {
1616 Err(e) => stream::iter(std::iter::once(Err(e))).left_stream(),
1617 Ok((hdr, cipher)) => {
1618 let chunks_offset = hdr.chunks_offset;
1619 let state = ChunkedDecryptState {
1620 body,
1621 cipher,
1622 hdr,
1623 cursor: chunks_offset,
1624 next_index: 0,
1625 };
1626 stream::try_unfold(state, decrypt_next_chunk).right_stream()
1627 }
1628 }
1629}
1630
/// Cursor state threaded through `stream::try_unfold` by
/// `decrypt_chunked_stream`; each `decrypt_next_chunk` call consumes one
/// tagged chunk and hands the advanced state back.
struct ChunkedDecryptState {
    /// The complete S4E5 body (header followed by every `tag || ciphertext` chunk).
    body: bytes::Bytes,
    /// AES-256-GCM cipher built from the keyring entry named by `hdr.key_id`.
    cipher: Aes256Gcm,
    /// Parsed S4E5 header: chunk size/count, key id, salt, and chunks offset.
    hdr: S4E5Header,
    /// Byte offset in `body` of the next chunk's tag; starts at `hdr.chunks_offset`.
    cursor: usize,
    /// Index of the next chunk to decrypt; reaches `hdr.chunk_count` when done.
    next_index: u32,
}
1641
1642async fn decrypt_next_chunk(
1643 mut state: ChunkedDecryptState,
1644) -> Result<Option<(Bytes, ChunkedDecryptState)>, SseError> {
1645 if state.next_index >= state.hdr.chunk_count {
1646 if state.cursor != state.body.len() {
1649 return Err(SseError::ChunkFrameTruncated {
1650 what: "trailing bytes after declared chunk_count",
1651 });
1652 }
1653 return Ok(None);
1654 }
1655 let i = state.next_index;
1656 let chunk_size = state.hdr.chunk_size as usize;
1657 if state.cursor + TAG_LEN > state.body.len() {
1658 return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
1659 }
1660 let tag_off = state.cursor;
1661 let ct_off = tag_off + TAG_LEN;
1662 let is_last = i + 1 == state.hdr.chunk_count;
1663 let ct_len = if is_last {
1664 if ct_off > state.body.len() {
1665 return Err(SseError::ChunkFrameTruncated {
1666 what: "final chunk ciphertext",
1667 });
1668 }
1669 let remaining = state.body.len() - ct_off;
1670 if remaining > chunk_size {
1671 return Err(SseError::ChunkFrameTruncated {
1672 what: "trailing bytes after final chunk",
1673 });
1674 }
1675 remaining
1676 } else {
1677 chunk_size
1678 };
1679 let ct_end = ct_off + ct_len;
1680 if ct_end > state.body.len() {
1681 return Err(SseError::ChunkFrameTruncated {
1682 what: "chunk ciphertext",
1683 });
1684 }
1685 let mut tag = [0u8; TAG_LEN];
1686 tag.copy_from_slice(&state.body[tag_off..ct_off]);
1687 let ct = &state.body[ct_off..ct_end];
1688 let plain = decrypt_v5_chunk(
1689 &state.cipher,
1690 i,
1691 state.hdr.chunk_count,
1692 state.hdr.key_id,
1693 &state.hdr.salt,
1694 &tag,
1695 ct,
1696 )?;
1697 crate::metrics::record_sse_streaming_chunk("decrypt");
1698 state.cursor = ct_end;
1699 state.next_index += 1;
1700 Ok(Some((plain, state)))
1701}
1702
1703#[cfg(test)]
1704mod tests {
1705 use super::*;
1706
1707 fn key32(seed: u8) -> Arc<SseKey> {
1708 Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
1709 }
1710
1711 fn keyring_single(seed: u8) -> SseKeyring {
1712 SseKeyring::new(1, key32(seed))
1713 }
1714
1715 #[test]
1716 fn roundtrip_basic_v1() {
1717 let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
1719 let pt = b"the quick brown fox jumps over the lazy dog";
1720 let ct = encrypt(&k, pt);
1721 assert!(looks_encrypted(&ct));
1722 assert_eq!(&ct[..4], SSE_MAGIC_V1);
1723 assert_eq!(ct[4], ALGO_AES_256_GCM);
1724 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
1725 let kr = SseKeyring::new(1, Arc::new(k));
1727 let pt2 = decrypt(&ct, &kr).unwrap();
1728 assert_eq!(pt2.as_ref(), pt);
1729 }
1730
1731 #[test]
1732 fn s4e2_roundtrip_active_key() {
1733 let kr = keyring_single(7);
1734 let pt = b"S4E2 active-key roundtrip";
1735 let ct = encrypt_v2(pt, &kr);
1736 assert_eq!(&ct[..4], SSE_MAGIC_V2);
1737 assert_eq!(ct[4], ALGO_AES_256_GCM);
1738 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
1739 assert_eq!(ct[7], 0, "reserved byte");
1740 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
1741 assert!(looks_encrypted(&ct));
1742 let pt2 = decrypt(&ct, &kr).unwrap();
1743 assert_eq!(pt2.as_ref(), pt);
1744 }
1745
1746 #[test]
1747 fn decrypt_s4e1_via_active_only_keyring() {
1748 let k_arc = key32(11);
1751 let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
1752 assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
1753 let kr = SseKeyring::new(1, Arc::clone(&k_arc));
1754 let plain = decrypt(&legacy_ct, &kr).unwrap();
1755 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
1756 }
1757
1758 #[test]
1759 fn decrypt_s4e2_under_old_key_after_rotation() {
1760 let k1 = key32(1);
1764 let k2 = key32(2);
1765 let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
1766 let ct = encrypt_v2(b"old-rotation object", &kr_old);
1767 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
1768
1769 kr_old.add(2, Arc::clone(&k2));
1771 let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
1772 kr_new.add(1, Arc::clone(&k1));
1773
1774 let plain = decrypt(&ct, &kr_new).unwrap();
1775 assert_eq!(plain.as_ref(), b"old-rotation object");
1776
1777 let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
1779 assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
1780 let plain_new = decrypt(&new_ct, &kr_new).unwrap();
1781 assert_eq!(plain_new.as_ref(), b"new-rotation object");
1782 }
1783
1784 #[test]
1785 fn s4e2_unknown_key_id_errors() {
1786 let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
1788 let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
1790 assert!(
1791 matches!(err, SseError::KeyNotInKeyring { id: 99 }),
1792 "got {err:?}"
1793 );
1794 }
1795
1796 #[test]
1797 fn s4e2_tampered_key_id_fails_auth() {
1798 let kr = SseKeyring::new(1, key32(4));
1799 let mut kr_with_2 = kr.clone();
1800 kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
1802 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
1806 ct[5] = 0;
1807 ct[6] = 2;
1808 let err = decrypt(&ct, &kr_with_2).unwrap_err();
1809 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
1810 }
1811
1812 #[test]
1813 fn s4e2_tampered_ciphertext_fails() {
1814 let kr = SseKeyring::new(7, key32(9));
1815 let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
1816 let last = ct.len() - 1;
1817 ct[last] ^= 0x01;
1818 let err = decrypt(&ct, &kr).unwrap_err();
1819 assert!(matches!(err, SseError::DecryptFailed));
1820 }
1821
1822 #[test]
1823 fn s4e2_tampered_algo_byte_fails() {
1824 let kr = SseKeyring::new(1, key32(2));
1825 let mut ct = encrypt_v2(b"hi", &kr).to_vec();
1826 ct[4] = 99;
1827 let err = decrypt(&ct, &kr).unwrap_err();
1828 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
1829 }
1830
1831 #[test]
1832 fn wrong_key_fails_v1_via_keyring() {
1833 let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
1835 let ct = encrypt(&k1, b"secret");
1836 let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
1837 let err = decrypt(&ct, &kr_wrong).unwrap_err();
1838 assert!(matches!(err, SseError::DecryptFailed));
1839 }
1840
1841 #[test]
1842 fn rejects_short_body() {
1843 let kr = SseKeyring::new(1, key32(1));
1844 let err = decrypt(b"short", &kr).unwrap_err();
1845 assert!(matches!(err, SseError::TooShort { got: 5 }));
1846 }
1847
1848 #[test]
1849 fn looks_encrypted_passthrough_returns_false() {
1850 let f2 = b"S4F2\x01\x00\x00\x00........................................";
1852 assert!(!looks_encrypted(f2));
1853 assert!(!looks_encrypted(b""));
1854 }
1855
1856 #[test]
1857 fn looks_encrypted_detects_both_v1_and_v2() {
1858 let kr = SseKeyring::new(1, key32(8));
1859 let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
1860 let v2 = encrypt_v2(b"x", &kr);
1861 assert!(looks_encrypted(&v1));
1862 assert!(looks_encrypted(&v2));
1863 }
1864
1865 #[test]
1866 fn key_from_hex_string() {
1867 let bad =
1868 SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
1869 .unwrap_err();
1870 assert!(matches!(bad, SseError::BadKeyLength { .. }));
1871 let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
1872 let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
1873 }
1874
1875 #[test]
1876 fn encrypt_v2_uses_random_nonce() {
1877 let kr = SseKeyring::new(1, key32(3));
1878 let pt = b"deterministic input";
1879 let a = encrypt_v2(pt, &kr);
1880 let b = encrypt_v2(pt, &kr);
1881 assert_ne!(a, b, "nonce must be random per-call");
1882 }
1883
1884 #[test]
1885 fn keyring_active_and_get() {
1886 let k1 = key32(1);
1887 let k2 = key32(2);
1888 let mut kr = SseKeyring::new(1, Arc::clone(&k1));
1889 kr.add(2, Arc::clone(&k2));
1890 let (id, active) = kr.active();
1891 assert_eq!(id, 1);
1892 assert_eq!(active.bytes, [1u8; 32]);
1893 assert!(kr.get(2).is_some());
1894 assert!(kr.get(3).is_none());
1895 }
1896
1897 use base64::Engine as _;
1902
1903 fn cust_key(seed: u8) -> CustomerKeyMaterial {
1904 let key = [seed; KEY_LEN];
1905 let key_md5 = compute_key_md5(&key);
1906 CustomerKeyMaterial { key, key_md5 }
1907 }
1908
1909 #[test]
1910 fn s4e3_roundtrip_happy_path() {
1911 let m = cust_key(42);
1912 let pt = b"top-secret SSE-C payload";
1913 let ct = encrypt_with_source(
1914 pt,
1915 SseSource::CustomerKey {
1916 key: &m.key,
1917 key_md5: &m.key_md5,
1918 },
1919 );
1920 assert_eq!(&ct[..4], SSE_MAGIC_V3);
1922 assert_eq!(ct[4], ALGO_AES_256_GCM);
1923 assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
1924 assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
1925 assert!(looks_encrypted(&ct));
1926 let plain = decrypt(
1928 &ct,
1929 SseSource::CustomerKey {
1930 key: &m.key,
1931 key_md5: &m.key_md5,
1932 },
1933 )
1934 .unwrap();
1935 assert_eq!(plain.as_ref(), pt);
1936 let plain2 = decrypt(&ct, &m).unwrap();
1938 assert_eq!(plain2.as_ref(), pt);
1939 }
1940
1941 #[test]
1942 fn s4e3_wrong_key_yields_wrong_customer_key_error() {
1943 let m = cust_key(1);
1944 let other = cust_key(2);
1945 let ct = encrypt_with_source(b"payload", (&m).into());
1946 let err = decrypt(
1947 &ct,
1948 SseSource::CustomerKey {
1949 key: &other.key,
1950 key_md5: &other.key_md5,
1951 },
1952 )
1953 .unwrap_err();
1954 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
1955 }
1956
1957 #[test]
1958 fn s4e3_tampered_stored_md5_is_caught() {
1959 let m = cust_key(7);
1966 let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
1967 ct[5] ^= 0x55;
1969 let err = decrypt(
1971 &ct,
1972 SseSource::CustomerKey {
1973 key: &m.key,
1974 key_md5: &m.key_md5,
1975 },
1976 )
1977 .unwrap_err();
1978 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
1979 }
1980
1981 #[test]
1982 fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
1983 let m = cust_key(3);
1987 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
1988 ct[5] ^= 0xFF;
1989 let mut bogus_md5 = m.key_md5;
1990 bogus_md5[0] ^= 0xFF;
1991 let err = decrypt(
1992 &ct,
1993 SseSource::CustomerKey {
1994 key: &m.key,
1995 key_md5: &bogus_md5,
1996 },
1997 )
1998 .unwrap_err();
1999 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2000 }
2001
2002 #[test]
2003 fn s4e3_tampered_ciphertext_fails_aead() {
2004 let m = cust_key(8);
2005 let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
2006 let last = ct.len() - 1;
2007 ct[last] ^= 0x01;
2008 let err = decrypt(&ct, &m).unwrap_err();
2009 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2010 }
2011
2012 #[test]
2013 fn s4e3_tampered_algo_byte_rejected() {
2014 let m = cust_key(9);
2015 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
2016 ct[4] = 99;
2017 let err = decrypt(&ct, &m).unwrap_err();
2018 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
2019 }
2020
2021 #[test]
2022 fn s4e3_uses_random_nonce() {
2023 let m = cust_key(10);
2024 let a = encrypt_with_source(b"deterministic input", (&m).into());
2025 let b = encrypt_with_source(b"deterministic input", (&m).into());
2026 assert_ne!(a, b, "nonce must be random per-call");
2027 }
2028
2029 #[test]
2030 fn parse_customer_key_headers_happy_path() {
2031 let key = [11u8; KEY_LEN];
2032 let md5 = compute_key_md5(&key);
2033 let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
2034 let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
2035 let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
2036 assert_eq!(m.key, key);
2037 assert_eq!(m.key_md5, md5);
2038 }
2039
2040 #[test]
2041 fn parse_customer_key_headers_rejects_wrong_algorithm() {
2042 let key = [1u8; KEY_LEN];
2043 let md5 = compute_key_md5(&key);
2044 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2045 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2046 let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
2047 assert!(
2048 matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
2049 "got {err:?}"
2050 );
2051 let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
2053 assert!(
2054 matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
2055 "got {err2:?}"
2056 );
2057 }
2058
2059 #[test]
2060 fn parse_customer_key_headers_rejects_wrong_key_length() {
2061 let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
2063 let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
2064 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2065 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2066 assert!(
2067 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
2068 "got {err:?}"
2069 );
2070 }
2071
2072 #[test]
2073 fn parse_customer_key_headers_rejects_wrong_md5_length() {
2074 let key = [3u8; KEY_LEN];
2075 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2076 let bad_md5 = vec![0u8; 15];
2078 let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
2079 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2080 assert!(
2081 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
2082 "got {err:?}"
2083 );
2084 }
2085
2086 #[test]
2087 fn parse_customer_key_headers_rejects_md5_mismatch() {
2088 let key = [4u8; KEY_LEN];
2089 let other = [5u8; KEY_LEN];
2090 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2091 let wrong_md5 = compute_key_md5(&other);
2092 let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
2093 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2094 assert!(
2095 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
2096 "got {err:?}"
2097 );
2098 }
2099
2100 #[test]
2101 fn parse_customer_key_headers_rejects_bad_base64() {
2102 let valid_key = [0u8; KEY_LEN];
2103 let md5 = compute_key_md5(&valid_key);
2104 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2105 let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
2106 assert!(
2107 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2108 "got {err:?}"
2109 );
2110 let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
2112 let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
2113 assert!(
2114 matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2115 "got {err2:?}"
2116 );
2117 }
2118
2119 #[test]
2120 fn parse_customer_key_headers_trims_whitespace() {
2121 let key = [12u8; KEY_LEN];
2123 let md5 = compute_key_md5(&key);
2124 let kb = format!(
2125 " {}\n",
2126 base64::engine::general_purpose::STANDARD.encode(key)
2127 );
2128 let mb = format!(
2129 "\t{} ",
2130 base64::engine::general_purpose::STANDARD.encode(md5)
2131 );
2132 let m = parse_customer_key_headers("AES256", &kb, &mb).unwrap();
2133 assert_eq!(m.key, key);
2134 }
2135
2136 #[test]
2141 fn back_compat_decrypt_s4e1_with_keyring_source() {
2142 let k = key32(33);
2143 let legacy_ct = encrypt(&k, b"v0.4 vintage object");
2144 let kr = SseKeyring::new(1, Arc::clone(&k));
2145 let plain = decrypt(&legacy_ct, &kr).unwrap();
2148 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
2149 let plain2 = decrypt(&legacy_ct, SseSource::Keyring(&kr)).unwrap();
2150 assert_eq!(plain2.as_ref(), b"v0.4 vintage object");
2151 }
2152
2153 #[test]
2154 fn back_compat_decrypt_s4e2_with_keyring_source() {
2155 let kr = keyring_single(34);
2156 let ct = encrypt_v2(b"v0.5 #29 object", &kr);
2157 let plain = decrypt(&ct, &kr).unwrap();
2158 assert_eq!(plain.as_ref(), b"v0.5 #29 object");
2159 let ct2 = encrypt_with_source(b"v0.5 #29 object", SseSource::Keyring(&kr));
2162 assert_eq!(&ct2[..4], SSE_MAGIC_V2);
2163 let plain2 = decrypt(&ct2, &kr).unwrap();
2164 assert_eq!(plain2.as_ref(), b"v0.5 #29 object");
2165 }
2166
2167 #[test]
2168 fn s4e2_blob_with_customer_key_source_is_rejected() {
2169 let kr = keyring_single(50);
2173 let ct = encrypt_v2(b"server-managed object", &kr);
2174 let m = cust_key(99);
2175 let err = decrypt(
2176 &ct,
2177 SseSource::CustomerKey {
2178 key: &m.key,
2179 key_md5: &m.key_md5,
2180 },
2181 )
2182 .unwrap_err();
2183 assert!(matches!(err, SseError::CustomerKeyUnexpected), "got {err:?}");
2184 }
2185
2186 #[test]
2187 fn s4e3_blob_with_keyring_source_is_rejected() {
2188 let m = cust_key(60);
2191 let ct = encrypt_with_source(b"customer-key object", (&m).into());
2192 let kr = keyring_single(60);
2193 let err = decrypt(&ct, &kr).unwrap_err();
2194 assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
2195 }
2196
2197 #[test]
2198 fn looks_encrypted_detects_s4e3() {
2199 let m = cust_key(13);
2200 let ct = encrypt_with_source(b"x", (&m).into());
2201 assert!(looks_encrypted(&ct));
2202 }
2203
2204 #[test]
2205 fn s4e3_rejects_short_body() {
2206 let mut short = Vec::new();
2209 short.extend_from_slice(SSE_MAGIC_V3);
2210 short.push(ALGO_AES_256_GCM);
2211 short.extend_from_slice(&[0u8; SSE_HEADER_BYTES - 5]);
2214 assert_eq!(short.len(), SSE_HEADER_BYTES);
2215 let m = cust_key(1);
2216 let err = decrypt(
2217 &short,
2218 SseSource::CustomerKey {
2219 key: &m.key,
2220 key_md5: &m.key_md5,
2221 },
2222 )
2223 .unwrap_err();
2224 assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
2225 }
2226
2227 #[test]
2228 fn customer_key_material_debug_redacts_key() {
2229 let m = cust_key(99);
2230 let s = format!("{m:?}");
2231 assert!(s.contains("redacted"));
2232 assert!(!s.contains(&format!("{:?}", m.key.as_slice())));
2233 }
2234
2235 #[test]
2236 fn constant_time_eq_basic() {
2237 assert!(constant_time_eq(b"abc", b"abc"));
2238 assert!(!constant_time_eq(b"abc", b"abd"));
2239 assert!(!constant_time_eq(b"abc", b"abcd"));
2240 assert!(constant_time_eq(b"", b""));
2241 }
2242
2243 #[test]
2244 fn compute_key_md5_known_vector() {
2245 let got = compute_key_md5(b"");
2247 let expected_hex = "d41d8cd98f00b204e9800998ecf8427e";
2248 assert_eq!(hex_lower(&got), expected_hex);
2249 }
2250
2251 use crate::kms::{KmsBackend, LocalKms};
2256 use std::collections::HashMap;
2257 use std::path::PathBuf;
2258
2259 fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
2260 let mut keks: HashMap<String, [u8; 32]> = HashMap::new();
2261 for (id, k) in key_ids {
2262 keks.insert((*id).to_string(), *k);
2263 }
2264 LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
2265 }
2266
2267 #[tokio::test]
2268 async fn s4e4_roundtrip_via_local_kms() {
2269 let kms = local_kms_with(&[("alpha", [42u8; 32])]);
2270 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2271 let mut dek = [0u8; 32];
2272 dek.copy_from_slice(&dek_vec);
2273 let pt = b"SSE-KMS envelope payload across the S4E4 frame";
2274 let ct = encrypt_with_source(
2275 pt,
2276 SseSource::Kms {
2277 dek: &dek,
2278 wrapped: &wrapped,
2279 },
2280 );
2281 assert_eq!(&ct[..4], SSE_MAGIC_V4);
2283 assert_eq!(ct[4], ALGO_AES_256_GCM);
2284 let key_id_len = ct[5] as usize;
2285 assert_eq!(key_id_len, "alpha".len());
2286 assert_eq!(&ct[6..6 + key_id_len], b"alpha");
2287 assert!(looks_encrypted(&ct));
2289 assert_eq!(peek_magic(&ct), Some("S4E4"));
2290 let plain = decrypt_with_kms(&ct, &kms).await.unwrap();
2292 assert_eq!(plain.as_ref(), pt);
2293 }
2294
2295 #[tokio::test]
2296 async fn s4e4_tampered_key_id_fails_aead() {
2297 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2298 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2299 let mut dek = [0u8; 32];
2300 dek.copy_from_slice(&dek_vec);
2301 let mut ct = encrypt_with_source(
2302 b"do not redirect",
2303 SseSource::Kms {
2304 dek: &dek,
2305 wrapped: &wrapped,
2306 },
2307 )
2308 .to_vec();
2309 let key_id_off = 6;
2314 ct[key_id_off] = b'b';
2315 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2316 assert!(
2317 matches!(
2318 err,
2319 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2320 | SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2321 ),
2322 "got {err:?}"
2323 );
2324 }
2325
2326 #[tokio::test]
2327 async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
2328 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2334 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2335 let mut dek = [0u8; 32];
2336 dek.copy_from_slice(&dek_vec);
2337 let mut ct = encrypt_with_source(
2338 b"redirect attempt",
2339 SseSource::Kms {
2340 dek: &dek,
2341 wrapped: &wrapped,
2342 },
2343 )
2344 .to_vec();
2345 let key_id_off = 6;
2348 ct[key_id_off..key_id_off + 5].copy_from_slice(b"beta_");
2349 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2356 assert!(
2357 matches!(
2358 err,
2359 SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2360 ),
2361 "got {err:?}"
2362 );
2363 }
2364
2365 #[tokio::test]
2366 async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
2367 let kms = local_kms_with(&[("k", [3u8; 32])]);
2368 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2369 let mut dek = [0u8; 32];
2370 dek.copy_from_slice(&dek_vec);
2371 let mut ct = encrypt_with_source(
2372 b"target body",
2373 SseSource::Kms {
2374 dek: &dek,
2375 wrapped: &wrapped,
2376 },
2377 )
2378 .to_vec();
2379 let key_id_len = ct[5] as usize;
2383 let wrapped_len_off = 6 + key_id_len;
2384 let wrapped_off = wrapped_len_off + 4;
2385 let mid = wrapped_off + (wrapped.ciphertext.len() / 2);
2386 ct[mid] ^= 0xFF;
2387 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2388 assert!(
2389 matches!(
2390 err,
2391 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2392 ),
2393 "got {err:?}"
2394 );
2395 }
2396
2397 #[tokio::test]
2398 async fn s4e4_tampered_ciphertext_fails_aead() {
2399 let kms = local_kms_with(&[("k", [4u8; 32])]);
2400 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2401 let mut dek = [0u8; 32];
2402 dek.copy_from_slice(&dek_vec);
2403 let mut ct = encrypt_with_source(
2404 b"sealed body",
2405 SseSource::Kms {
2406 dek: &dek,
2407 wrapped: &wrapped,
2408 },
2409 )
2410 .to_vec();
2411 let last = ct.len() - 1;
2412 ct[last] ^= 0x01;
2413 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2414 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2415 }
2416
2417 #[tokio::test]
2418 async fn s4e4_uses_random_nonce_and_dek_per_put() {
2419 let kms = local_kms_with(&[("k", [5u8; 32])]);
2420 let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
2423 let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
2424 let mut dek1 = [0u8; 32];
2425 dek1.copy_from_slice(&dek1_vec);
2426 let mut dek2 = [0u8; 32];
2427 dek2.copy_from_slice(&dek2_vec);
2428 let pt = b"deterministic input";
2429 let a = encrypt_with_source(
2430 pt,
2431 SseSource::Kms {
2432 dek: &dek1,
2433 wrapped: &wrapped1,
2434 },
2435 );
2436 let b = encrypt_with_source(
2437 pt,
2438 SseSource::Kms {
2439 dek: &dek2,
2440 wrapped: &wrapped2,
2441 },
2442 );
2443 assert_ne!(a, b);
2444 let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
2446 let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
2447 assert_eq!(plain_a.as_ref(), pt);
2448 assert_eq!(plain_b.as_ref(), pt);
2449 }
2450
2451 #[tokio::test]
2452 async fn s4e4_sync_decrypt_returns_kms_async_required() {
2453 let kms = local_kms_with(&[("k", [6u8; 32])]);
2458 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2459 let mut dek = [0u8; 32];
2460 dek.copy_from_slice(&dek_vec);
2461 let ct = encrypt_with_source(
2462 b"async only",
2463 SseSource::Kms {
2464 dek: &dek,
2465 wrapped: &wrapped,
2466 },
2467 );
2468 let kr = SseKeyring::new(1, key32(0));
2470 let err = decrypt(&ct, &kr).unwrap_err();
2471 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
2472 }
2473
2474 #[test]
2475 fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
2476 let k = key32(7);
2479 let v1 = encrypt(&k, b"v0.4 vintage");
2480 let kr = SseKeyring::new(1, Arc::clone(&k));
2481 assert_eq!(decrypt(&v1, &kr).unwrap().as_ref(), b"v0.4 vintage");
2482
2483 let v2 = encrypt_v2(b"v0.5 #29 vintage", &kr);
2484 assert_eq!(
2485 decrypt(&v2, &kr).unwrap().as_ref(),
2486 b"v0.5 #29 vintage"
2487 );
2488
2489 let m = cust_key(7);
2490 let v3 = encrypt_with_source(b"v0.5 #27 vintage", (&m).into());
2491 assert_eq!(
2492 decrypt(&v3, &m).unwrap().as_ref(),
2493 b"v0.5 #27 vintage"
2494 );
2495 }
2496
2497 #[test]
2498 fn peek_magic_distinguishes_all_variants() {
2499 let k = key32(9);
2502 let v1 = encrypt(&k, b"x");
2503 assert_eq!(peek_magic(&v1), Some("S4E1"));
2504 let kr = SseKeyring::new(1, Arc::clone(&k));
2505 let v2 = encrypt_v2(b"x", &kr);
2506 assert_eq!(peek_magic(&v2), Some("S4E2"));
2507 let m = cust_key(9);
2508 let v3 = encrypt_with_source(b"x", (&m).into());
2509 assert_eq!(peek_magic(&v3), Some("S4E3"));
2510 let mut v4 = Vec::new();
2515 v4.extend_from_slice(SSE_MAGIC_V4);
2516 v4.extend_from_slice(&[0u8; 40]);
2517 assert_eq!(peek_magic(&v4), Some("S4E4"));
2518 assert!(peek_magic(b"NOPE").is_none());
2520 assert!(peek_magic(b"short").is_none());
2521 assert!(peek_magic(&[0u8; 100]).is_none());
2522 }
2523
2524 #[tokio::test]
2525 async fn s4e4_truncated_frame_errors_cleanly() {
2526 let truncated = b"S4E4\x01\x05hi";
2529 let kms = local_kms_with(&[("k", [1u8; 32])]);
2530 let err = decrypt_with_kms(truncated, &kms).await.unwrap_err();
2531 assert!(
2532 matches!(err, SseError::KmsFrameTooShort { .. }),
2533 "got {err:?}"
2534 );
2535 }
2536
2537 #[tokio::test]
2538 async fn s4e4_oob_key_id_len_errors() {
2539 let mut body = Vec::new();
2543 body.extend_from_slice(SSE_MAGIC_V4);
2544 body.push(ALGO_AES_256_GCM);
2545 body.push(200u8); body.extend_from_slice(&[0u8; 50]);
2550 let kms = local_kms_with(&[("k", [1u8; 32])]);
2551 let err = decrypt_with_kms(&body, &kms).await.unwrap_err();
2552 assert!(
2553 matches!(err, SseError::KmsFrameFieldOob { .. }),
2554 "got {err:?}"
2555 );
2556 }
2557
2558 #[tokio::test]
2559 async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
2560 let kms = local_kms_with(&[("k", [9u8; 32])]);
2566 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2567 let mut dek = [0u8; 32];
2568 dek.copy_from_slice(&dek_vec);
2569 let ct = encrypt_with_source(
2570 b"x",
2571 SseSource::Kms {
2572 dek: &dek,
2573 wrapped: &wrapped,
2574 },
2575 );
2576 let m = cust_key(1);
2577 let err = decrypt(&ct, &m).unwrap_err();
2578 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
2579 }
2580
2581 #[tokio::test]
2582 async fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
2583 let mut not_s4e4 = Vec::new();
2585 not_s4e4.extend_from_slice(b"S4F4");
2586 not_s4e4.extend_from_slice(&[0u8; 60]);
2587 assert!(!looks_encrypted(¬_s4e4));
2588 assert_eq!(peek_magic(¬_s4e4), None);
2589 }
2590
2591 #[tokio::test]
2592 async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
2593 let kms = local_kms_with(&[("kk", [11u8; 32])]);
2600 let (dek_vec, wrapped) = kms.generate_dek("kk").await.unwrap();
2601 let mut dek = [0u8; 32];
2602 dek.copy_from_slice(&dek_vec);
2603 let mut ct = encrypt_with_source(
2604 b"length-shift defense",
2605 SseSource::Kms {
2606 dek: &dek,
2607 wrapped: &wrapped,
2608 },
2609 )
2610 .to_vec();
2611 let key_id_len = ct[5] as usize;
2612 let wrapped_len_off = 6 + key_id_len;
2613 let original_len = u32::from_be_bytes([
2619 ct[wrapped_len_off],
2620 ct[wrapped_len_off + 1],
2621 ct[wrapped_len_off + 2],
2622 ct[wrapped_len_off + 3],
2623 ]);
2624 let new_len = (original_len - 1).to_be_bytes();
2625 ct[wrapped_len_off..wrapped_len_off + 4].copy_from_slice(&new_len);
2626 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2627 assert!(
2630 matches!(
2631 err,
2632 SseError::KmsBackend(_)
2633 | SseError::DecryptFailed
2634 | SseError::KmsFrameFieldOob { .. }
2635 | SseError::KmsFrameTooShort { .. }
2636 ),
2637 "got {err:?}"
2638 );
2639 }
2640
2641 use futures::StreamExt;
2646
2647 async fn collect_chunks(
2650 s: impl futures::Stream<Item = Result<Bytes, SseError>>,
2651 ) -> Result<Vec<Bytes>, SseError> {
2652 let mut out = Vec::new();
2653 let mut s = std::pin::pin!(s);
2654 while let Some(item) = s.next().await {
2655 out.push(item?);
2656 }
2657 Ok(out)
2658 }
2659
2660 #[test]
2661 fn s4e5_encrypt_layout_10mb_at_1mib() {
2662 let kr = keyring_single(0x42);
2666 let chunk_size = 1024 * 1024;
2667 let pt_len = 10 * 1024 * 1024;
2668 let pt = vec![0xAB_u8; pt_len];
2669 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).expect("encrypt ok");
2670 assert_eq!(&ct[..4], SSE_MAGIC_V5);
2671 assert_eq!(ct[4], ALGO_AES_256_GCM);
2672 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE = active id");
2673 assert_eq!(ct[7], 0, "reserved must be 0");
2674 assert_eq!(
2675 u32::from_be_bytes([ct[8], ct[9], ct[10], ct[11]]),
2676 chunk_size as u32,
2677 "chunk_size BE",
2678 );
2679 assert_eq!(
2680 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
2681 10,
2682 "chunk_count BE — 10 MiB / 1 MiB = 10 (no remainder)",
2683 );
2684 assert_eq!(
2685 ct.len(),
2686 S4E5_HEADER_BYTES + 10 * S4E5_PER_CHUNK_OVERHEAD + pt_len,
2687 "total = header + 10 tags + plaintext",
2688 );
2689 assert!(looks_encrypted(&ct), "looks_encrypted must accept S4E5");
2690 assert_eq!(peek_magic(&ct), Some("S4E5"));
2691 }
2692
2693 #[tokio::test]
2694 async fn s4e5_decrypt_chunked_stream_byte_equal() {
2695 let kr = keyring_single(0x55);
2698 let pt: Vec<u8> = (0..(10 * 1024 * 1024_u32)).map(|i| (i & 0xFF) as u8).collect();
2699 let ct = encrypt_v2_chunked(&pt, &kr, 1024 * 1024).unwrap();
2700 let stream = decrypt_chunked_stream(ct, &kr);
2701 let chunks = collect_chunks(stream).await.expect("stream ok");
2702 assert_eq!(chunks.len(), 10, "10 chunks expected for 10 MiB / 1 MiB");
2703 let mut joined = Vec::with_capacity(pt.len());
2704 for c in chunks {
2705 joined.extend_from_slice(&c);
2706 }
2707 assert_eq!(joined.len(), pt.len(), "byte length matches");
2708 assert_eq!(joined, pt, "byte-equal round-trip");
2709 }
2710
2711 #[tokio::test]
2712 async fn s4e5_single_chunk_for_small_object() {
2713 let kr = keyring_single(0x77);
2715 let pt = b"tiny payload, smaller than chunk_size";
2716 let ct = encrypt_v2_chunked(pt, &kr, 1024 * 1024).unwrap();
2717 assert_eq!(
2718 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
2719 1,
2720 "small plaintext = single chunk",
2721 );
2722 let stream = decrypt_chunked_stream(ct, &kr);
2723 let chunks = collect_chunks(stream).await.expect("stream ok");
2724 assert_eq!(chunks.len(), 1);
2725 assert_eq!(chunks[0].as_ref(), pt);
2726 }
2727
2728 #[tokio::test]
2729 async fn s4e5_tampered_chunk_n_reports_chunk_index() {
2730 let kr = keyring_single(0x91);
2734 let chunk_size = 1024;
2735 let pt = vec![0xCD_u8; chunk_size * 8]; let mut ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap().to_vec();
2737 let target = S4E5_HEADER_BYTES + 3 * (TAG_LEN + chunk_size) + TAG_LEN;
2740 ct[target] ^= 0x42;
2741 let stream = decrypt_chunked_stream(bytes::Bytes::from(ct), &kr);
2742 let mut s = std::pin::pin!(stream);
2743 for expected_i in 0..3_u32 {
2745 let item = s.next().await.expect("yield");
2746 item.unwrap_or_else(|e| panic!("chunk {expected_i}: {e:?}"));
2747 }
2748 let err = s.next().await.expect("yield error").unwrap_err();
2750 assert!(
2751 matches!(err, SseError::ChunkAuthFailed { chunk_index: 3 }),
2752 "got {err:?}",
2753 );
2754 }
2755
2756 #[tokio::test]
2757 async fn s4e5_back_compat_s4e2_blob_rejected_with_clear_error() {
2758 let kr = keyring_single(0x12);
2762 let s4e2 = encrypt_v2(b"a v2 blob, not chunked", &kr);
2763 let stream = decrypt_chunked_stream(s4e2, &kr);
2764 let result = collect_chunks(stream).await;
2765 let err = result.unwrap_err();
2766 assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");
2767 }
2768
#[test]
fn s4e5_salt_uniqueness_birthday_smoke() {
    let kr = keyring_single(0x33);
    // The 4-byte header salt lives at bytes 16..20 of every S4E5 frame.
    //
    // Assuming it is drawn uniformly at random (32 bits), the expected number of
    // colliding pairs among 1024 draws is C(1024, 2) / 2^32 ≈ 1.2e-4. Tolerating
    // at most two lost entries therefore keeps the false-failure rate at about
    // 3e-13.
    //
    // The previous `> n / 2` bound was far too loose. A salt with only ~10 bits
    // of entropy yields ≈ 647 unique values out of 1024 and still passed. Under
    // this bound, a 16-bit salt (≈ 8 expected collisions) already fails.
    const MAX_COLLISIONS: usize = 2;
    let n: usize = 1024;
    let mut salts = std::collections::HashSet::with_capacity(n);
    for _ in 0..n {
        let ct = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
        let mut salt = [0u8; 4];
        salt.copy_from_slice(&ct[16..20]);
        salts.insert(salt);
    }
    assert!(
        salts.len() + MAX_COLLISIONS >= n,
        "expected at least {} of the {n} salts to be unique (got {} unique)",
        n - MAX_COLLISIONS,
        salts.len(),
    );
}
2791
#[test]
fn s4e5_chunk_size_zero_is_invalid() {
    let kr = keyring_single(0x66);
    // A zero chunk size must be rejected up front with ChunkSizeInvalid.
    let err = encrypt_v2_chunked(b"hi", &kr, 0).unwrap_err();
    // Print the actual variant on failure, like the sibling S4E5 error tests.
    assert!(matches!(err, SseError::ChunkSizeInvalid), "got {err:?}");
}
2798
2799 #[tokio::test]
2800 async fn s4e5_truncated_body_surfaces_chunk_frame_truncated() {
2801 let kr = keyring_single(0xA1);
2804 let chunk_size = 256;
2805 let pt = vec![0u8; chunk_size * 4];
2806 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
2807 let trunc = S4E5_HEADER_BYTES + 2 * (TAG_LEN + chunk_size) + 8;
2810 let truncated = bytes::Bytes::copy_from_slice(&ct[..trunc]);
2811 let stream = decrypt_chunked_stream(truncated, &kr);
2812 let result = collect_chunks(stream).await;
2813 let err = result.unwrap_err();
2814 assert!(
2815 matches!(err, SseError::ChunkFrameTruncated { .. }),
2816 "got {err:?}",
2817 );
2818 }
2819
#[test]
fn s4e5_decrypt_buffered_round_trip_via_top_level_decrypt() {
    let keyring = keyring_single(0xDE);
    // 832 bytes with a tiny 13-byte chunk size, so many frames go through the
    // synchronous, fully-buffered `decrypt` entry point.
    let plaintext = b"buffered sync decrypt path".repeat(32);
    let framed = encrypt_v2_chunked(&plaintext, &keyring, 13).unwrap();
    let recovered = decrypt(&framed, &keyring).expect("buffered S4E5 decrypt ok");
    assert_eq!(recovered.as_ref(), plaintext.as_slice());
}
2831
2832 #[tokio::test]
2833 async fn s4e5_unknown_key_id_in_frame_errors() {
2834 let kr_put = SseKeyring::new(7, key32(0xCC));
2836 let kr_get = keyring_single(0xCC); let ct = encrypt_v2_chunked(b"orphan key", &kr_put, 64).unwrap();
2838 let err = decrypt(&ct, &kr_get).unwrap_err();
2840 assert!(matches!(err, SseError::KeyNotInKeyring { id: 7 }), "got {err:?}");
2841 let stream = decrypt_chunked_stream(ct, &kr_get);
2843 let result = collect_chunks(stream).await;
2844 assert!(
2845 matches!(result, Err(SseError::KeyNotInKeyring { id: 7 })),
2846 "got {result:?}",
2847 );
2848 }
2849
2850 #[tokio::test]
2851 async fn s4e5_final_chunk_smaller_than_chunk_size() {
2852 let kr = keyring_single(0xEF);
2854 let chunk_size = 100;
2855 let pt: Vec<u8> = (0..250_u32).map(|i| i as u8).collect();
2856 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
2857 assert_eq!(
2858 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
2859 3,
2860 "ceil(250/100) = 3 chunks",
2861 );
2862 assert_eq!(ct.len(), 20 + 48 + 250);
2864 let stream = decrypt_chunked_stream(ct, &kr);
2865 let chunks = collect_chunks(stream).await.expect("stream ok");
2866 assert_eq!(chunks.len(), 3);
2867 assert_eq!(chunks[0].len(), 100);
2868 assert_eq!(chunks[1].len(), 100);
2869 assert_eq!(chunks[2].len(), 50, "final chunk is the remainder");
2870 let joined: Vec<u8> = chunks.iter().flat_map(|c| c.iter().copied()).collect();
2871 assert_eq!(joined, pt);
2872 }
2873}