1use std::collections::HashMap;
122use std::path::Path;
123use std::sync::Arc;
124
125use aes_gcm::aead::{Aead, KeyInit, Payload};
126use aes_gcm::{Aes256Gcm, Key, Nonce};
127use bytes::Bytes;
128use md5::{Digest as Md5Digest, Md5};
129use rand::RngCore;
130use thiserror::Error;
131
132use crate::kms::{KmsBackend, KmsError, WrappedDek};
133
// Frame magics: the first four bytes of every SSE-encrypted body identify the
// on-disk frame layout.
pub const SSE_MAGIC_V1: &[u8; 4] = b"S4E1"; // legacy single-key frame
pub const SSE_MAGIC_V2: &[u8; 4] = b"S4E2"; // keyring frame with embedded key id
pub const SSE_MAGIC_V3: &[u8; 4] = b"S4E3"; // SSE-C (customer-key) frame
pub const SSE_MAGIC_V4: &[u8; 4] = b"S4E4"; // SSE-KMS frame (wrapped DEK in header)
pub const SSE_MAGIC_V5: &[u8; 4] = b"S4E5"; // chunked frame, 4-byte salt
pub const SSE_MAGIC_V6: &[u8; 4] = b"S4E6"; // chunked frame, 8-byte salt
/// Alias kept at the V1 magic.
pub const SSE_MAGIC: &[u8; 4] = SSE_MAGIC_V1;

// V1/V2 header: magic(4) | algo(1) | key_id/pad(3) | nonce(12) | tag(16).
pub const SSE_HEADER_BYTES: usize = 4 + 1 + 3 + 12 + 16;
// V3 header replaces the 3 id/pad bytes with a 16-byte key-MD5 fingerprint.
pub const SSE_HEADER_BYTES_V3: usize = 4 + 1 + KEY_MD5_LEN + 12 + 16;
// Algorithm tag stored at byte offset 4 of every frame.
pub const ALGO_AES_256_GCM: u8 = 1;
const NONCE_LEN: usize = 12; // AES-GCM 96-bit nonce
const TAG_LEN: usize = 16; // AES-GCM 128-bit auth tag
const KEY_LEN: usize = 32; // AES-256 key length
const KEY_MD5_LEN: usize = 16; // MD5 digest length
/// Only value accepted in the SSE-C algorithm header.
pub const SSE_C_ALGORITHM: &str = "AES256";
177
/// Everything that can go wrong while encrypting or decrypting SSE frames.
///
/// Variants are grouped by the frame family that raises them (key loading,
/// keyring frames, SSE-C, SSE-KMS, chunked). Messages are written to be
/// actionable straight from logs.
#[derive(Debug, Error)]
pub enum SseError {
    /// Reading the server key file from disk failed.
    #[error("SSE key file {path:?}: {source}")]
    KeyFileIo {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// Key file content did not parse to exactly 32 bytes (raw/hex/base64).
    #[error(
        "SSE key file must be exactly 32 raw bytes (or 64-char hex / 44-char base64); got {got} bytes after parse"
    )]
    BadKeyLength { got: usize },
    #[error("SSE-encrypted body too short ({got} bytes; need at least {SSE_HEADER_BYTES})")]
    TooShort { got: usize },
    #[error("SSE bad magic: expected S4E1/S4E2/S4E3/S4E4/S4E5/S4E6, got {got:?}")]
    BadMagic { got: [u8; 4] },
    #[error("SSE unsupported algo tag: {tag} (this build only knows AES-256-GCM = 1)")]
    UnsupportedAlgo { tag: u8 },
    /// An S4E2 frame named a key id the keyring no longer holds.
    #[error(
        "SSE key_id {id} (S4E2 frame) not present in keyring; rotation history likely incomplete"
    )]
    KeyNotInKeyring { id: u16 },
    #[error("SSE decryption / authentication failed (key mismatch or ciphertext tampered with)")]
    DecryptFailed,
    // --- SSE-C (customer-key) errors ---
    #[error("SSE-C key MD5 fingerprint mismatch — client supplied a different key than PUT")]
    WrongCustomerKey,
    #[error("SSE-C customer-key headers invalid: {reason}")]
    InvalidCustomerKey { reason: &'static str },
    #[error("SSE-C algorithm {algo:?} unsupported (only {SSE_C_ALGORITHM:?} is allowed)")]
    CustomerKeyAlgorithmUnsupported { algo: String },
    #[error("S4E3 frame requires SseSource::CustomerKey; got Keyring")]
    CustomerKeyRequired,
    #[error("S4E1/S4E2 frame stored without SSE-C; SseSource::CustomerKey is unexpected")]
    CustomerKeyUnexpected,
    // --- SSE-KMS (S4E4) errors ---
    #[error(
        "S4E4 (SSE-KMS) body requires async decrypt — call decrypt_with_kms() instead of decrypt()"
    )]
    KmsAsyncRequired,
    #[error("S4E4 frame too short ({got} bytes; need at least {min})")]
    KmsFrameTooShort { got: usize, min: usize },
    #[error("S4E4 frame field length out of bounds: {what}")]
    KmsFrameFieldOob { what: &'static str },
    #[error("S4E4 key_id is not valid UTF-8")]
    KmsKeyIdNotUtf8,
    #[error(
        "S4E4 SseSource::Kms wrapped DEK key_id {supplied:?} doesn't match frame key_id {stored:?}"
    )]
    KmsWrappedDekMismatch {
        supplied: String,
        stored: String,
    },
    #[error("S4E4 frame requires SseSource::Kms")]
    KmsRequired,
    #[error("KMS unwrap: {0}")]
    KmsBackend(#[from] KmsError),
    // --- Chunked frame (S4E5/S4E6) errors ---
    #[error("S4E5 chunk {chunk_index} auth tag verify failed (key mismatch or chunk tampered with)")]
    ChunkAuthFailed { chunk_index: u32 },
    #[error("S4E5 chunk_size must be > 0 (got 0)")]
    ChunkSizeInvalid,
    #[error("S4E5 frame truncated: {what}")]
    ChunkFrameTruncated { what: &'static str },
    #[error(
        "S4E6 chunk_count {got} exceeds 24-bit max ({max}) — pick a larger --sse-chunk-size"
    )]
    ChunkCountTooLarge { got: u32, max: u32 },
    #[error("S4E5/S4E6 chunked frame declares an over-large size: {details}")]
    ChunkFrameTooLarge { details: &'static str },
}
344
/// Default cap on decrypted body size: 5 GiB.
pub const DEFAULT_MAX_BODY_BYTES: usize = 5 * 1024 * 1024 * 1024;

/// A single 256-bit AES-GCM key.
pub struct SseKey {
    // Raw key material; `Debug` is hand-written below so this stays redacted.
    pub bytes: [u8; 32],
}
360
361impl SseKey {
362 pub fn from_path(path: &Path) -> Result<Self, SseError> {
366 let raw = std::fs::read(path).map_err(|source| SseError::KeyFileIo {
367 path: path.to_path_buf(),
368 source,
369 })?;
370 Self::from_bytes(&raw)
371 }
372
373 pub fn from_bytes(bytes: &[u8]) -> Result<Self, SseError> {
374 if bytes.len() == KEY_LEN {
376 let mut k = [0u8; KEY_LEN];
377 k.copy_from_slice(bytes);
378 return Ok(Self { bytes: k });
379 }
380 let s = std::str::from_utf8(bytes).unwrap_or("").trim();
382 if s.len() == KEY_LEN * 2 && s.chars().all(|c| c.is_ascii_hexdigit()) {
383 let mut k = [0u8; KEY_LEN];
384 for (i, k_byte) in k.iter_mut().enumerate() {
385 *k_byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16)
386 .map_err(|_| SseError::BadKeyLength { got: bytes.len() })?;
387 }
388 return Ok(Self { bytes: k });
389 }
390 if let Ok(decoded) =
391 base64::Engine::decode(&base64::engine::general_purpose::STANDARD, s.as_bytes())
392 && decoded.len() == KEY_LEN
393 {
394 let mut k = [0u8; KEY_LEN];
395 k.copy_from_slice(&decoded);
396 return Ok(Self { bytes: k });
397 }
398 Err(SseError::BadKeyLength { got: bytes.len() })
399 }
400
401 fn as_aes_key(&self) -> &Key<Aes256Gcm> {
402 Key::<Aes256Gcm>::from_slice(&self.bytes)
403 }
404}
405
406impl std::fmt::Debug for SseKey {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 f.debug_struct("SseKey")
409 .field("len", &KEY_LEN)
410 .field("key", &"<redacted>")
411 .finish()
412 }
413}
414
/// An id-addressed set of SSE keys supporting rotation: new writes use the
/// `active` id, while historical ids stay resolvable for decrypting objects
/// written before a rotation.
#[derive(Clone)]
pub struct SseKeyring {
    active: u16, // id used for new encryptions
    keys: HashMap<u16, Arc<SseKey>>, // id -> key, including historical entries
}
424
425impl SseKeyring {
426 pub fn new(active: u16, key: Arc<SseKey>) -> Self {
430 let mut keys = HashMap::new();
431 keys.insert(active, key);
432 Self { active, keys }
433 }
434
435 pub fn add(&mut self, id: u16, key: Arc<SseKey>) {
439 self.keys.insert(id, key);
440 }
441
442 pub fn active(&self) -> (u16, &SseKey) {
445 let id = self.active;
446 let key = self
447 .keys
448 .get(&id)
449 .expect("active key id must be present in keyring (constructor invariant)");
450 (id, key.as_ref())
451 }
452
453 pub fn get(&self, id: u16) -> Option<&SseKey> {
456 self.keys.get(&id).map(Arc::as_ref)
457 }
458}
459
460impl std::fmt::Debug for SseKeyring {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 f.debug_struct("SseKeyring")
463 .field("active", &self.active)
464 .field("key_count", &self.keys.len())
465 .field("key_ids", &self.keys.keys().collect::<Vec<_>>())
466 .finish()
467 }
468}
469
/// Shared, immutable handle to a keyring.
pub type SharedSseKeyring = Arc<SseKeyring>;
471
472pub fn encrypt(key: &SseKey, plaintext: &[u8]) -> Bytes {
479 let cipher = Aes256Gcm::new(key.as_aes_key());
480 let mut nonce_bytes = [0u8; NONCE_LEN];
481 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
482 let nonce = Nonce::from_slice(&nonce_bytes);
483 let mut aad = [0u8; 8];
485 aad[..4].copy_from_slice(SSE_MAGIC_V1);
486 aad[4] = ALGO_AES_256_GCM;
487 let ct_with_tag = cipher
488 .encrypt(
489 nonce,
490 Payload {
491 msg: plaintext,
492 aad: &aad,
493 },
494 )
495 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
496 debug_assert!(ct_with_tag.len() >= TAG_LEN);
497 let split = ct_with_tag.len() - TAG_LEN;
498 let (ct, tag) = ct_with_tag.split_at(split);
499
500 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
501 out.extend_from_slice(SSE_MAGIC_V1);
502 out.push(ALGO_AES_256_GCM);
503 out.extend_from_slice(&[0u8; 3]); out.extend_from_slice(&nonce_bytes);
505 out.extend_from_slice(tag);
506 out.extend_from_slice(ct);
507 Bytes::from(out)
508}
509
510pub fn encrypt_v2(plaintext: &[u8], keyring: &SseKeyring) -> Bytes {
515 let (key_id, key) = keyring.active();
516 let cipher = Aes256Gcm::new(key.as_aes_key());
517 let mut nonce_bytes = [0u8; NONCE_LEN];
518 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
519 let nonce = Nonce::from_slice(&nonce_bytes);
520 let aad = aad_v2(key_id);
521 let ct_with_tag = cipher
522 .encrypt(
523 nonce,
524 Payload {
525 msg: plaintext,
526 aad: &aad,
527 },
528 )
529 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
530 debug_assert!(ct_with_tag.len() >= TAG_LEN);
531 let split = ct_with_tag.len() - TAG_LEN;
532 let (ct, tag) = ct_with_tag.split_at(split);
533
534 let mut out = Vec::with_capacity(SSE_HEADER_BYTES + ct.len());
535 out.extend_from_slice(SSE_MAGIC_V2);
536 out.push(ALGO_AES_256_GCM);
537 out.extend_from_slice(&key_id.to_be_bytes()); out.push(0u8); out.extend_from_slice(&nonce_bytes);
540 out.extend_from_slice(tag);
541 out.extend_from_slice(ct);
542 Bytes::from(out)
543}
544
545fn aad_v1() -> [u8; 8] {
546 let mut aad = [0u8; 8];
547 aad[..4].copy_from_slice(SSE_MAGIC_V1);
548 aad[4] = ALGO_AES_256_GCM;
549 aad
550}
551
552fn aad_v2(key_id: u16) -> [u8; 8] {
553 let mut aad = [0u8; 8];
554 aad[..4].copy_from_slice(SSE_MAGIC_V2);
555 aad[4] = ALGO_AES_256_GCM;
556 aad[5..7].copy_from_slice(&key_id.to_be_bytes());
557 aad[7] = 0u8;
558 aad
559}
560
561fn aad_v3(key_md5: &[u8; KEY_MD5_LEN]) -> [u8; 4 + 1 + KEY_MD5_LEN] {
567 let mut aad = [0u8; 4 + 1 + KEY_MD5_LEN];
568 aad[..4].copy_from_slice(SSE_MAGIC_V3);
569 aad[4] = ALGO_AES_256_GCM;
570 aad[5..5 + KEY_MD5_LEN].copy_from_slice(key_md5);
571 aad
572}
573
/// A validated SSE-C key supplied by the client, plus its MD5 fingerprint
/// (the fingerprint is stored in S4E3 frames to detect wrong-key GETs).
#[derive(Clone)]
pub struct CustomerKeyMaterial {
    pub key: [u8; KEY_LEN],
    pub key_md5: [u8; KEY_MD5_LEN],
}
584
585impl std::fmt::Debug for CustomerKeyMaterial {
586 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
587 f.debug_struct("CustomerKeyMaterial")
590 .field("key", &"<redacted>")
591 .field("key_md5_hex", &hex_lower(&self.key_md5))
592 .finish()
593 }
594}
595
/// Lowercase hex encoding of `bytes` (e.g. `[0x1a, 0xff]` -> `"1aff"`).
///
/// Uses a nibble lookup table; the previous `format!("{b:02x}")` per byte
/// went through the whole formatting machinery (and an intermediate String)
/// for every single byte.
fn hex_lower(bytes: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut s = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        s.push(HEX[(b >> 4) as usize] as char);
        s.push(HEX[(b & 0x0f) as usize] as char);
    }
    s
}
603
/// Where the key material for an encrypt/decrypt operation comes from.
#[derive(Debug, Clone, Copy)]
pub enum SseSource<'a> {
    /// Server-managed keyring — used by S4E1/S4E2 and chunked S4E5/S4E6 frames.
    Keyring(&'a SseKeyring),
    /// Client-supplied SSE-C key — required for S4E3 frames.
    CustomerKey {
        key: &'a [u8; KEY_LEN],
        key_md5: &'a [u8; KEY_MD5_LEN],
    },
    /// KMS data-encryption key — used for S4E4 frames.
    Kms {
        // Plaintext DEK (already unwrapped).
        dek: &'a [u8; KEY_LEN],
        // Wrapped DEK + key id as stored in the frame header.
        wrapped: &'a WrappedDek,
    },
}
636
637impl<'a> From<&'a SseKeyring> for SseSource<'a> {
644 fn from(kr: &'a SseKeyring) -> Self {
645 SseSource::Keyring(kr)
646 }
647}
648
649impl<'a> From<&'a Arc<SseKeyring>> for SseSource<'a> {
653 fn from(kr: &'a Arc<SseKeyring>) -> Self {
654 SseSource::Keyring(kr.as_ref())
655 }
656}
657
658impl<'a> From<&'a CustomerKeyMaterial> for SseSource<'a> {
659 fn from(m: &'a CustomerKeyMaterial) -> Self {
660 SseSource::CustomerKey {
661 key: &m.key,
662 key_md5: &m.key_md5,
663 }
664 }
665}
666
667pub fn parse_customer_key_headers(
679 algorithm: &str,
680 key_base64: &str,
681 key_md5_base64: &str,
682) -> Result<CustomerKeyMaterial, SseError> {
683 use base64::Engine as _;
684 if algorithm != SSE_C_ALGORITHM {
685 return Err(SseError::CustomerKeyAlgorithmUnsupported {
686 algo: algorithm.to_string(),
687 });
688 }
689 let key_bytes = base64::engine::general_purpose::STANDARD
690 .decode(key_base64.trim().as_bytes())
691 .map_err(|_| SseError::InvalidCustomerKey {
692 reason: "base64 decode of key",
693 })?;
694 if key_bytes.len() != KEY_LEN {
695 return Err(SseError::InvalidCustomerKey {
696 reason: "key length (must be 32 bytes after base64 decode)",
697 });
698 }
699 let supplied_md5 = base64::engine::general_purpose::STANDARD
700 .decode(key_md5_base64.trim().as_bytes())
701 .map_err(|_| SseError::InvalidCustomerKey {
702 reason: "base64 decode of key MD5",
703 })?;
704 if supplied_md5.len() != KEY_MD5_LEN {
705 return Err(SseError::InvalidCustomerKey {
706 reason: "key MD5 length (must be 16 bytes after base64 decode)",
707 });
708 }
709 let actual_md5 = compute_key_md5(&key_bytes);
710 if !constant_time_eq(&actual_md5, &supplied_md5) {
713 return Err(SseError::InvalidCustomerKey {
714 reason: "supplied MD5 does not match MD5 of supplied key",
715 });
716 }
717 let mut key = [0u8; KEY_LEN];
718 key.copy_from_slice(&key_bytes);
719 let mut key_md5 = [0u8; KEY_MD5_LEN];
720 key_md5.copy_from_slice(&actual_md5);
721 Ok(CustomerKeyMaterial { key, key_md5 })
722}
723
724pub fn compute_key_md5(key: &[u8]) -> [u8; KEY_MD5_LEN] {
729 let mut h = Md5::new();
730 h.update(key);
731 let out = h.finalize();
732 let mut md5 = [0u8; KEY_MD5_LEN];
733 md5.copy_from_slice(&out);
734 md5
735}
736
/// Timing-safe byte comparison: ORs together every byte difference and only
/// compares to zero at the end, so runtime does not depend on where the
/// first mismatch occurs. (Length is compared up front, as in any
/// fixed-fingerprint comparison — only contents are secret.)
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b.iter()).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
749
750pub fn encrypt_with_source(plaintext: &[u8], source: SseSource<'_>) -> Bytes {
760 match source {
761 SseSource::Keyring(kr) => encrypt_v2(plaintext, kr),
762 SseSource::CustomerKey { key, key_md5 } => encrypt_v3(plaintext, key, key_md5),
763 SseSource::Kms { dek, wrapped } => encrypt_v4(plaintext, dek, wrapped),
764 }
765}
766
767fn encrypt_v3(
768 plaintext: &[u8],
769 key: &[u8; KEY_LEN],
770 key_md5: &[u8; KEY_MD5_LEN],
771) -> Bytes {
772 let aes_key = Key::<Aes256Gcm>::from_slice(key);
773 let cipher = Aes256Gcm::new(aes_key);
774 let mut nonce_bytes = [0u8; NONCE_LEN];
775 rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
776 let nonce = Nonce::from_slice(&nonce_bytes);
777 let aad = aad_v3(key_md5);
778 let ct_with_tag = cipher
779 .encrypt(
780 nonce,
781 Payload {
782 msg: plaintext,
783 aad: &aad,
784 },
785 )
786 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
787 debug_assert!(ct_with_tag.len() >= TAG_LEN);
788 let split = ct_with_tag.len() - TAG_LEN;
789 let (ct, tag) = ct_with_tag.split_at(split);
790
791 let mut out = Vec::with_capacity(SSE_HEADER_BYTES_V3 + ct.len());
792 out.extend_from_slice(SSE_MAGIC_V3);
793 out.push(ALGO_AES_256_GCM);
794 out.extend_from_slice(key_md5);
795 out.extend_from_slice(&nonce_bytes);
796 out.extend_from_slice(tag);
797 out.extend_from_slice(ct);
798 Bytes::from(out)
799}
800
/// Decrypt any synchronously-decryptable SSE frame, dispatching on the
/// 4-byte magic.
///
/// S4E1/S4E2 and chunked S4E5/S4E6 frames need `SseSource::Keyring`;
/// S4E3 needs `SseSource::CustomerKey`. S4E4 cannot be handled here because
/// unwrapping the DEK is async — it returns `KmsAsyncRequired` to steer the
/// caller to `decrypt_with_kms`.
pub fn decrypt<'a, S: Into<SseSource<'a>>>(body: &[u8], source: S) -> Result<Bytes, SseError> {
    let source = source.into();
    // SSE_HEADER_BYTES is the smallest header accepted here, so the
    // per-variant helpers below can slice fixed offsets safely.
    if body.len() < SSE_HEADER_BYTES {
        return Err(SseError::TooShort { got: body.len() });
    }
    let mut magic = [0u8; 4];
    magic.copy_from_slice(&body[..4]);
    match &magic {
        // Keyring-sealed frames; any other source is a caller mix-up.
        m if m == SSE_MAGIC_V1 || m == SSE_MAGIC_V2 => {
            let keyring = match source {
                SseSource::Keyring(kr) => kr,
                SseSource::CustomerKey { .. } => return Err(SseError::CustomerKeyUnexpected),
                SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
            };
            if m == SSE_MAGIC_V1 {
                decrypt_v1_with_keyring(body, keyring)
            } else {
                decrypt_v2_with_keyring(body, keyring)
            }
        }
        m if m == SSE_MAGIC_V3 => {
            // V3 has a wider header (16-byte key MD5), so re-check the length.
            if body.len() < SSE_HEADER_BYTES_V3 {
                return Err(SseError::TooShort { got: body.len() });
            }
            let (key, key_md5) = match source {
                SseSource::CustomerKey { key, key_md5 } => (key, key_md5),
                SseSource::Keyring(_) => return Err(SseError::CustomerKeyRequired),
                SseSource::Kms { .. } => return Err(SseError::CustomerKeyRequired),
            };
            decrypt_v3(body, key, key_md5)
        }
        m if m == SSE_MAGIC_V4 => {
            // KMS DEK unwrap is async; this sync entry point cannot do it.
            Err(SseError::KmsAsyncRequired)
        }
        m if m == SSE_MAGIC_V5 || m == SSE_MAGIC_V6 => {
            let keyring = match source {
                SseSource::Keyring(kr) => kr,
                SseSource::CustomerKey { .. } => {
                    return Err(SseError::CustomerKeyUnexpected);
                }
                SseSource::Kms { .. } => return Err(SseError::CustomerKeyUnexpected),
            };
            decrypt_chunked_buffered_default(body, keyring)
        }
        _ => Err(SseError::BadMagic { got: magic }),
    }
}
894
/// Decrypt an S4E3 (SSE-C) frame with the caller-supplied key.
///
/// The caller (`decrypt`) has already verified the magic and that
/// `body.len() >= SSE_HEADER_BYTES_V3`, so the fixed-offset slicing below
/// cannot panic.
fn decrypt_v3(
    body: &[u8],
    key: &[u8; KEY_LEN],
    supplied_md5: &[u8; KEY_MD5_LEN],
) -> Result<Bytes, SseError> {
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let mut stored_md5 = [0u8; KEY_MD5_LEN];
    stored_md5.copy_from_slice(&body[5..5 + KEY_MD5_LEN]);
    // Fingerprint check first: a wrong key gets the precise WrongCustomerKey
    // error instead of the generic GCM DecryptFailed.
    if !constant_time_eq(supplied_md5, &stored_md5) {
        return Err(SseError::WrongCustomerKey);
    }
    // Layout after magic+algo: key_md5(16) | nonce(12) | tag(16) | ciphertext.
    let nonce_off = 5 + KEY_MD5_LEN;
    let tag_off = nonce_off + NONCE_LEN;
    let mut nonce_bytes = [0u8; NONCE_LEN];
    nonce_bytes.copy_from_slice(&body[nonce_off..nonce_off + NONCE_LEN]);
    let mut tag_bytes = [0u8; TAG_LEN];
    tag_bytes.copy_from_slice(&body[tag_off..tag_off + TAG_LEN]);
    let ct = &body[SSE_HEADER_BYTES_V3..];

    let aad = aad_v3(&stored_md5);
    let nonce = Nonce::from_slice(&nonce_bytes);
    // aes-gcm's decrypt wants ciphertext || tag as a single buffer.
    let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
    ct_with_tag.extend_from_slice(ct);
    ct_with_tag.extend_from_slice(&tag_bytes);

    let aes_key = Key::<Aes256Gcm>::from_slice(key);
    let cipher = Aes256Gcm::new(aes_key);
    let plain = cipher
        .decrypt(
            nonce,
            Payload {
                msg: &ct_with_tag,
                aad: &aad,
            },
        )
        .map_err(|_| SseError::DecryptFailed)?;
    Ok(Bytes::from(plain))
}
941
942fn aad_v4(key_id: &[u8], wrapped_dek: &[u8]) -> Vec<u8> {
953 let mut aad = Vec::with_capacity(4 + 1 + 1 + key_id.len() + 4 + wrapped_dek.len());
954 aad.extend_from_slice(SSE_MAGIC_V4);
955 aad.push(ALGO_AES_256_GCM);
956 aad.push(key_id.len() as u8);
957 aad.extend_from_slice(key_id);
958 aad.extend_from_slice(&(wrapped_dek.len() as u32).to_be_bytes());
959 aad.extend_from_slice(wrapped_dek);
960 aad
961}
962
/// Encrypt into an S4E4 (SSE-KMS) frame using an already-unwrapped DEK.
///
/// Layout: magic(4) | algo(1) | key_id_len(1) | key_id | wrapped_len(4, BE) |
/// wrapped_dek | nonce(12) | tag(16) | ciphertext. The whole variable-length
/// header is bound into the GCM tag via `aad_v4`.
///
/// # Panics
/// If `wrapped.key_id` is empty or longer than 255 bytes, or the wrapped DEK
/// exceeds `u32::MAX` bytes — both are producer-side invariants, not runtime
/// input.
fn encrypt_v4(plaintext: &[u8], dek: &[u8; KEY_LEN], wrapped: &WrappedDek) -> Bytes {
    assert!(
        !wrapped.key_id.is_empty() && wrapped.key_id.len() <= u8::MAX as usize,
        "S4E4 key_id must be 1..=255 bytes (got {})",
        wrapped.key_id.len()
    );
    assert!(
        wrapped.ciphertext.len() <= u32::MAX as usize,
        "S4E4 wrapped_dek longer than u32::MAX",
    );

    let aes_key = Key::<Aes256Gcm>::from_slice(dek);
    let cipher = Aes256Gcm::new(aes_key);
    // Fresh random 96-bit nonce per object.
    let mut nonce_bytes = [0u8; NONCE_LEN];
    rand::rngs::OsRng.fill_bytes(&mut nonce_bytes);
    let nonce = Nonce::from_slice(&nonce_bytes);
    let aad = aad_v4(wrapped.key_id.as_bytes(), &wrapped.ciphertext);
    let ct_with_tag = cipher
        .encrypt(
            nonce,
            Payload {
                msg: plaintext,
                aad: &aad,
            },
        )
        .expect("aes-gcm encrypt cannot fail with a 32-byte key");
    debug_assert!(ct_with_tag.len() >= TAG_LEN);
    let split = ct_with_tag.len() - TAG_LEN;
    let (ct, tag) = ct_with_tag.split_at(split);

    let key_id_bytes = wrapped.key_id.as_bytes();
    let mut out = Vec::with_capacity(
        4 + 1 + 1 + key_id_bytes.len() + 4 + wrapped.ciphertext.len() + NONCE_LEN + TAG_LEN + ct.len(),
    );
    out.extend_from_slice(SSE_MAGIC_V4);
    out.push(ALGO_AES_256_GCM);
    out.push(key_id_bytes.len() as u8);
    out.extend_from_slice(key_id_bytes);
    out.extend_from_slice(&(wrapped.ciphertext.len() as u32).to_be_bytes());
    out.extend_from_slice(&wrapped.ciphertext);
    out.extend_from_slice(&nonce_bytes);
    out.extend_from_slice(tag);
    out.extend_from_slice(ct);
    Bytes::from(out)
}
1015
/// Borrowed view of a parsed S4E4 frame; every field points into the
/// original body slice (zero-copy).
#[derive(Debug)]
pub struct S4E4Header<'a> {
    pub key_id: &'a str, // KMS key id (validated UTF-8)
    pub wrapped_dek: &'a [u8], // wrapped DEK ciphertext as stored
    pub nonce: &'a [u8], // 12-byte GCM nonce
    pub tag: &'a [u8], // 16-byte GCM tag
    pub ciphertext: &'a [u8], // object ciphertext (may be empty)
}
1029
/// Zero-copy parse of an S4E4 (SSE-KMS) frame header.
///
/// Layout: magic(4) | algo(1) | key_id_len(1) | key_id | wrapped_len(4, BE) |
/// wrapped_dek | nonce(12) | tag(16) | ciphertext. Every variable-length
/// field is bounds-checked before slicing, so malformed frames return
/// errors rather than panicking.
pub fn parse_s4e4_header(body: &[u8]) -> Result<S4E4Header<'_>, SseError> {
    // Smallest possible frame: zero-length key_id, wrapped DEK, and ct.
    const S4E4_MIN: usize = 4 + 1 + 1 + 4 + NONCE_LEN + TAG_LEN;
    if body.len() < S4E4_MIN {
        return Err(SseError::KmsFrameTooShort {
            got: body.len(),
            min: S4E4_MIN,
        });
    }
    let magic = &body[..4];
    if magic != SSE_MAGIC_V4 {
        let mut got = [0u8; 4];
        got.copy_from_slice(magic);
        return Err(SseError::BadMagic { got });
    }
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let key_id_len = body[5] as usize;
    let key_id_off: usize = 6;
    let key_id_end = key_id_off
        .checked_add(key_id_len)
        .ok_or(SseError::KmsFrameFieldOob { what: "key_id_len" })?;
    // Both the key id and the 4-byte wrapped-DEK length field must fit.
    if key_id_end + 4 > body.len() {
        return Err(SseError::KmsFrameFieldOob { what: "key_id" });
    }
    let key_id = std::str::from_utf8(&body[key_id_off..key_id_end])
        .map_err(|_| SseError::KmsKeyIdNotUtf8)?;
    let wrapped_len_off = key_id_end;
    let wrapped_dek_len = u32::from_be_bytes([
        body[wrapped_len_off],
        body[wrapped_len_off + 1],
        body[wrapped_len_off + 2],
        body[wrapped_len_off + 3],
    ]) as usize;
    let wrapped_off = wrapped_len_off + 4;
    let wrapped_end = wrapped_off
        .checked_add(wrapped_dek_len)
        .ok_or(SseError::KmsFrameFieldOob { what: "wrapped_dek_len" })?;
    // The wrapped DEK plus the fixed nonce+tag suffix must fit in the body.
    if wrapped_end + NONCE_LEN + TAG_LEN > body.len() {
        return Err(SseError::KmsFrameFieldOob { what: "wrapped_dek" });
    }
    let wrapped_dek = &body[wrapped_off..wrapped_end];
    let nonce_off = wrapped_end;
    let tag_off = nonce_off + NONCE_LEN;
    let ct_off = tag_off + TAG_LEN;
    let nonce = &body[nonce_off..nonce_off + NONCE_LEN];
    let tag = &body[tag_off..tag_off + TAG_LEN];
    let ciphertext = &body[ct_off..];
    Ok(S4E4Header {
        key_id,
        wrapped_dek,
        nonce,
        tag,
        ciphertext,
    })
}
1095
/// Decrypt an S4E4 (SSE-KMS) frame, asking `kms` to unwrap the DEK stored in
/// the frame header.
///
/// The wrapped DEK and key id are read from the frame itself, so the caller
/// supplies no key material — only a backend able to unwrap it.
///
/// # Errors
/// Frame-parsing errors from [`parse_s4e4_header`]; `KmsBackend` if the
/// unwrap call fails or returns a DEK that is not 32 bytes; `DecryptFailed`
/// if GCM authentication rejects the payload.
pub async fn decrypt_with_kms(
    body: &[u8],
    kms: &dyn KmsBackend,
) -> Result<Bytes, SseError> {
    let hdr = parse_s4e4_header(body)?;
    let wrapped = WrappedDek {
        key_id: hdr.key_id.to_string(),
        ciphertext: hdr.wrapped_dek.to_vec(),
    };
    let dek_vec = kms.decrypt_dek(&wrapped).await?;
    // A malformed backend response must not reach the cipher constructor.
    if dek_vec.len() != KEY_LEN {
        return Err(SseError::KmsBackend(KmsError::BackendUnavailable {
            message: format!(
                "KMS returned {} byte DEK; expected {KEY_LEN}",
                dek_vec.len()
            ),
        }));
    }
    let mut dek = [0u8; KEY_LEN];
    dek.copy_from_slice(&dek_vec);

    // The AAD binds the stored key id + wrapped DEK into the tag, so any
    // header tampering also fails authentication here.
    let aad = aad_v4(hdr.key_id.as_bytes(), hdr.wrapped_dek);
    let aes_key = Key::<Aes256Gcm>::from_slice(&dek);
    let cipher = Aes256Gcm::new(aes_key);
    let nonce = Nonce::from_slice(hdr.nonce);
    let mut ct_with_tag = Vec::with_capacity(hdr.ciphertext.len() + TAG_LEN);
    ct_with_tag.extend_from_slice(hdr.ciphertext);
    ct_with_tag.extend_from_slice(hdr.tag);
    let plain = cipher
        .decrypt(
            nonce,
            Payload {
                msg: &ct_with_tag,
                aad: &aad,
            },
        )
        .map_err(|_| SseError::DecryptFailed)?;
    Ok(Bytes::from(plain))
}
1154
1155fn decrypt_v1_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1156 let algo = body[4];
1157 if algo != ALGO_AES_256_GCM {
1158 return Err(SseError::UnsupportedAlgo { tag: algo });
1159 }
1160 let mut nonce_bytes = [0u8; NONCE_LEN];
1163 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1164 let mut tag_bytes = [0u8; TAG_LEN];
1165 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1166 let ct = &body[SSE_HEADER_BYTES..];
1167
1168 let aad = aad_v1();
1169 let nonce = Nonce::from_slice(&nonce_bytes);
1170 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1171 ct_with_tag.extend_from_slice(ct);
1172 ct_with_tag.extend_from_slice(&tag_bytes);
1173
1174 let (active_id, _active_key) = keyring.active();
1178 let mut ids: Vec<u16> = keyring.keys.keys().copied().collect();
1179 ids.sort_by_key(|id| if *id == active_id { 0 } else { 1 });
1180 for id in ids {
1181 let key = keyring.get(id).expect("id came from keyring iteration");
1182 let cipher = Aes256Gcm::new(key.as_aes_key());
1183 if let Ok(plain) = cipher.decrypt(
1184 nonce,
1185 Payload {
1186 msg: &ct_with_tag,
1187 aad: &aad,
1188 },
1189 ) {
1190 return Ok(Bytes::from(plain));
1191 }
1192 }
1193 Err(SseError::DecryptFailed)
1194}
1195
1196fn decrypt_v2_with_keyring(body: &[u8], keyring: &SseKeyring) -> Result<Bytes, SseError> {
1197 let algo = body[4];
1198 if algo != ALGO_AES_256_GCM {
1199 return Err(SseError::UnsupportedAlgo { tag: algo });
1200 }
1201 let key_id = u16::from_be_bytes([body[5], body[6]]);
1202 let key = keyring
1204 .get(key_id)
1205 .ok_or(SseError::KeyNotInKeyring { id: key_id })?;
1206 let mut nonce_bytes = [0u8; NONCE_LEN];
1207 nonce_bytes.copy_from_slice(&body[8..8 + NONCE_LEN]);
1208 let mut tag_bytes = [0u8; TAG_LEN];
1209 tag_bytes.copy_from_slice(&body[8 + NONCE_LEN..SSE_HEADER_BYTES]);
1210 let ct = &body[SSE_HEADER_BYTES..];
1211
1212 let aad = aad_v2(key_id);
1213 let nonce = Nonce::from_slice(&nonce_bytes);
1214 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1215 ct_with_tag.extend_from_slice(ct);
1216 ct_with_tag.extend_from_slice(&tag_bytes);
1217 let cipher = Aes256Gcm::new(key.as_aes_key());
1218 let plain = cipher
1219 .decrypt(
1220 nonce,
1221 Payload {
1222 msg: &ct_with_tag,
1223 aad: &aad,
1224 },
1225 )
1226 .map_err(|_| SseError::DecryptFailed)?;
1227 Ok(Bytes::from(plain))
1228}
1229
1230pub fn looks_encrypted(body: &[u8]) -> bool {
1241 if body.len() < SSE_HEADER_BYTES {
1242 return false;
1243 }
1244 let m = &body[..4];
1245 m == SSE_MAGIC_V1
1246 || m == SSE_MAGIC_V2
1247 || m == SSE_MAGIC_V3
1248 || m == SSE_MAGIC_V4
1249 || m == SSE_MAGIC_V5
1250 || m == SSE_MAGIC_V6
1251}
1252
1253pub fn peek_magic(body: &[u8]) -> Option<&'static str> {
1264 if body.len() < SSE_HEADER_BYTES {
1265 return None;
1266 }
1267 match &body[..4] {
1268 m if m == SSE_MAGIC_V1 => Some("S4E1"),
1269 m if m == SSE_MAGIC_V2 => Some("S4E2"),
1270 m if m == SSE_MAGIC_V3 => Some("S4E3"),
1271 m if m == SSE_MAGIC_V4 => Some("S4E4"),
1272 m if m == SSE_MAGIC_V5 => Some("S4E5"),
1277 m if m == SSE_MAGIC_V6 => Some("S4E6"),
1279 _ => None,
1280 }
1281}
1282
/// Shared, immutable handle to a single SSE key.
pub type SharedSseKey = Arc<SseKey>;

// S4E5 header: magic(4) | algo(1) | key_id(2) | reserved(1) | chunk_size(4)
// | chunk_count(4) | salt(4).
pub const S4E5_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 4;
// Each chunk carries one GCM auth tag of overhead.
pub const S4E5_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
// S4E6 header is the same layout with a wider 8-byte salt.
pub const S4E6_HEADER_BYTES: usize = 4 + 1 + 2 + 1 + 4 + 4 + 8;
pub const S4E6_PER_CHUNK_OVERHEAD: usize = TAG_LEN;
// S4E6 packs the chunk index into 3 nonce bytes, capping counts at 2^24 - 1.
pub const S4E6_MAX_CHUNK_COUNT: u32 = (1u32 << 24) - 1;
// Fixed 4-byte prefix of every S4E5 chunk nonce.
const S4E5_NONCE_TAG: [u8; 4] = [b'E', b'5', 0, 0];

// First byte of every S4E6 chunk nonce.
const S4E6_NONCE_PREFIX: u8 = b'E';
1413
/// Which chunked frame layout a body uses; the two differ in salt width
/// (4 bytes for V5, 8 for V6) and therefore in header size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChunkedVariant {
    V5,
    V6,
}

impl ChunkedVariant {
    /// Fixed header length for this variant.
    fn header_bytes(self) -> usize {
        match self {
            ChunkedVariant::V5 => S4E5_HEADER_BYTES,
            ChunkedVariant::V6 => S4E6_HEADER_BYTES,
        }
    }
}
1432
1433fn aad_v5(
1438 chunk_index: u32,
1439 total_chunks: u32,
1440 key_id: u16,
1441 salt: &[u8; 4],
1442) -> [u8; 4 + 1 + 4 + 4 + 2 + 4] {
1443 let mut aad = [0u8; 4 + 1 + 4 + 4 + 2 + 4]; aad[..4].copy_from_slice(SSE_MAGIC_V5);
1445 aad[4] = ALGO_AES_256_GCM;
1446 aad[5..9].copy_from_slice(&chunk_index.to_be_bytes());
1447 aad[9..13].copy_from_slice(&total_chunks.to_be_bytes());
1448 aad[13..15].copy_from_slice(&key_id.to_be_bytes());
1449 aad[15..19].copy_from_slice(salt);
1450 aad
1451}
1452
1453fn aad_v6(
1459 chunk_index: u32,
1460 total_chunks: u32,
1461 key_id: u16,
1462 salt: &[u8; 8],
1463) -> [u8; 4 + 1 + 4 + 4 + 2 + 8] {
1464 let mut aad = [0u8; 4 + 1 + 4 + 4 + 2 + 8]; aad[..4].copy_from_slice(SSE_MAGIC_V6);
1466 aad[4] = ALGO_AES_256_GCM;
1467 aad[5..9].copy_from_slice(&chunk_index.to_be_bytes());
1468 aad[9..13].copy_from_slice(&total_chunks.to_be_bytes());
1469 aad[13..15].copy_from_slice(&key_id.to_be_bytes());
1470 aad[15..23].copy_from_slice(salt);
1471 aad
1472}
1473
1474fn nonce_v5(salt: &[u8; 4], chunk_index: u32) -> [u8; NONCE_LEN] {
1480 let mut n = [0u8; NONCE_LEN];
1481 n[..4].copy_from_slice(&S4E5_NONCE_TAG);
1482 n[4..8].copy_from_slice(salt);
1483 n[8..12].copy_from_slice(&chunk_index.to_be_bytes());
1484 n
1485}
1486
1487fn nonce_v6(salt: &[u8; 8], chunk_index: u32) -> [u8; NONCE_LEN] {
1495 debug_assert!(
1496 chunk_index <= S4E6_MAX_CHUNK_COUNT,
1497 "S4E6 chunk_index {chunk_index} exceeds 24-bit cap (caller MUST validate)",
1498 );
1499 let mut n = [0u8; NONCE_LEN];
1500 n[0] = S4E6_NONCE_PREFIX;
1501 n[1..9].copy_from_slice(salt);
1502 let be = chunk_index.to_be_bytes(); n[9..12].copy_from_slice(&be[1..4]);
1506 n
1507}
1508
1509pub fn encrypt_v2_chunked(
1532 plaintext: &[u8],
1533 keyring: &SseKeyring,
1534 chunk_size: usize,
1535) -> Result<Bytes, SseError> {
1536 if chunk_size == 0 {
1537 return Err(SseError::ChunkSizeInvalid);
1538 }
1539 let (key_id, key) = keyring.active();
1540 let cipher = Aes256Gcm::new(key.as_aes_key());
1541 let mut salt = [0u8; 8];
1542 rand::rngs::OsRng.fill_bytes(&mut salt);
1543
1544 let chunk_count_usize = if plaintext.is_empty() {
1547 1
1548 } else {
1549 plaintext.len().div_ceil(chunk_size)
1550 };
1551 let chunk_count: u32 = u32::try_from(chunk_count_usize).unwrap_or(u32::MAX);
1555 if chunk_count > S4E6_MAX_CHUNK_COUNT {
1556 return Err(SseError::ChunkCountTooLarge {
1557 got: chunk_count,
1558 max: S4E6_MAX_CHUNK_COUNT,
1559 });
1560 }
1561
1562 let mut out = Vec::with_capacity(
1563 S4E6_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E6_PER_CHUNK_OVERHEAD),
1564 );
1565 out.extend_from_slice(SSE_MAGIC_V6);
1566 out.push(ALGO_AES_256_GCM);
1567 out.extend_from_slice(&key_id.to_be_bytes());
1568 out.push(0u8); out.extend_from_slice(&(chunk_size as u32).to_be_bytes());
1570 out.extend_from_slice(&chunk_count.to_be_bytes());
1571 out.extend_from_slice(&salt);
1572
1573 for i in 0..chunk_count {
1574 let off = (i as usize).saturating_mul(chunk_size);
1575 let end = off.saturating_add(chunk_size).min(plaintext.len());
1576 let chunk_pt: &[u8] = if off >= plaintext.len() {
1577 &[]
1580 } else {
1581 &plaintext[off..end]
1582 };
1583 let nonce_bytes = nonce_v6(&salt, i);
1584 let nonce = Nonce::from_slice(&nonce_bytes);
1585 let aad = aad_v6(i, chunk_count, key_id, &salt);
1586 let ct_with_tag = cipher
1587 .encrypt(
1588 nonce,
1589 Payload {
1590 msg: chunk_pt,
1591 aad: &aad,
1592 },
1593 )
1594 .expect("aes-gcm encrypt cannot fail with a 32-byte key");
1595 debug_assert!(ct_with_tag.len() >= TAG_LEN);
1596 let split = ct_with_tag.len() - TAG_LEN;
1597 let (ct, tag) = ct_with_tag.split_at(split);
1598 out.extend_from_slice(tag);
1599 out.extend_from_slice(ct);
1600 crate::metrics::record_sse_streaming_chunk("encrypt");
1601 }
1602 Ok(Bytes::from(out))
1603}
1604
/// Salt carried in a chunked header: 4 bytes for S4E5, 8 for S4E6.
#[derive(Debug, Clone, Copy)]
enum ChunkedSalt {
    V5([u8; 4]),
    V6([u8; 8]),
}

/// Parsed, validated chunked-frame header (internal decrypt-side view).
#[derive(Debug, Clone, Copy)]
struct ChunkedHeader {
    #[allow(dead_code)]
    variant: ChunkedVariant, // which layout produced this header
    key_id: u16, // keyring id that sealed the chunks
    chunk_size: u32, // plaintext bytes per chunk (last chunk may be shorter)
    chunk_count: u32, // total chunks declared by the frame
    salt: ChunkedSalt, // per-object salt feeding nonces and AAD
    chunks_offset: usize, // byte offset of the first chunk within the body
}

/// Public, zero-copy view of an S4E6 header (8-byte-salt variant).
#[derive(Debug, Clone, Copy)]
pub struct S4E6Header<'a> {
    pub key_id: u16,
    pub chunk_size: u32,
    pub chunk_count: u32,
    pub salt: &'a [u8; 8],
}
1650
/// Parse and validate an S4E6 chunked-frame header without copying.
///
/// Layout: magic(4) | algo(1) | key_id(2, BE) | reserved(1) |
/// chunk_size(4, BE) | chunk_count(4, BE) | salt(8).
pub fn parse_s4e6_header(blob: &[u8]) -> Result<S4E6Header<'_>, SseError> {
    if blob.len() < S4E6_HEADER_BYTES {
        return Err(SseError::ChunkFrameTruncated { what: "header" });
    }
    if &blob[..4] != SSE_MAGIC_V6 {
        let mut got = [0u8; 4];
        got.copy_from_slice(&blob[..4]);
        return Err(SseError::BadMagic { got });
    }
    let algo = blob[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let key_id = u16::from_be_bytes([blob[5], blob[6]]);
    // Byte 7 is reserved padding; fields resume at offset 8.
    let chunk_size = u32::from_be_bytes([blob[8], blob[9], blob[10], blob[11]]);
    let chunk_count = u32::from_be_bytes([blob[12], blob[13], blob[14], blob[15]]);
    if chunk_size == 0 {
        return Err(SseError::ChunkSizeInvalid);
    }
    if chunk_count == 0 {
        return Err(SseError::ChunkFrameTruncated {
            what: "chunk_count == 0",
        });
    }
    // 24-bit cap: the chunk index is packed into 3 nonce bytes.
    if chunk_count > S4E6_MAX_CHUNK_COUNT {
        return Err(SseError::ChunkCountTooLarge {
            got: chunk_count,
            max: S4E6_MAX_CHUNK_COUNT,
        });
    }
    // Length was checked up front, so this conversion cannot fail.
    let salt: &[u8; 8] = (&blob[16..24]).try_into().expect("8B salt slice");
    Ok(S4E6Header {
        key_id,
        chunk_size,
        chunk_count,
        salt,
    })
}
1693
/// Parse and validate the header of an S4E5 or S4E6 chunked frame.
///
/// Accepts either magic, decodes the geometry fields, and runs
/// overflow-checked sanity checks so later offset arithmetic cannot wrap:
/// the declared plaintext must fit within `max_body_bytes`, and the body must
/// not extend past the largest size the declared geometry could occupy.
///
/// Returns the decoded [`ChunkedHeader`], including `chunks_offset` — the
/// byte offset where the first `[tag || ciphertext]` record begins.
fn parse_chunked_header(
    body: &[u8],
    max_body_bytes: usize,
) -> Result<ChunkedHeader, SseError> {
    if body.len() < 4 {
        return Err(SseError::ChunkFrameTruncated { what: "magic" });
    }
    let magic = &body[..4];
    let variant = if magic == SSE_MAGIC_V5 {
        ChunkedVariant::V5
    } else if magic == SSE_MAGIC_V6 {
        ChunkedVariant::V6
    } else {
        let mut got = [0u8; 4];
        got.copy_from_slice(magic);
        return Err(SseError::BadMagic { got });
    };
    // Header length differs per variant (V5 carries a 4B salt, V6 an 8B salt).
    let header_bytes = variant.header_bytes();
    if body.len() < header_bytes {
        return Err(SseError::ChunkFrameTruncated { what: "header" });
    }
    let algo = body[4];
    if algo != ALGO_AES_256_GCM {
        return Err(SseError::UnsupportedAlgo { tag: algo });
    }
    let key_id = u16::from_be_bytes([body[5], body[6]]);
    // Byte 7 is reserved; chunk_size and chunk_count are big-endian u32s.
    let chunk_size = u32::from_be_bytes([body[8], body[9], body[10], body[11]]);
    let chunk_count = u32::from_be_bytes([body[12], body[13], body[14], body[15]]);
    if chunk_size == 0 {
        return Err(SseError::ChunkSizeInvalid);
    }
    if chunk_count == 0 {
        return Err(SseError::ChunkFrameTruncated {
            what: "chunk_count == 0",
        });
    }
    let salt = match variant {
        ChunkedVariant::V5 => {
            let mut s = [0u8; 4];
            s.copy_from_slice(&body[16..20]);
            ChunkedSalt::V5(s)
        }
        ChunkedVariant::V6 => {
            // NOTE(review): only V6 enforces S4E6_MAX_CHUNK_COUNT; V5 counts
            // are bounded only indirectly by the max_body_bytes check below —
            // presumably intentional for legacy frames, confirm.
            if chunk_count > S4E6_MAX_CHUNK_COUNT {
                return Err(SseError::ChunkCountTooLarge {
                    got: chunk_count,
                    max: S4E6_MAX_CHUNK_COUNT,
                });
            }
            let mut s = [0u8; 8];
            s.copy_from_slice(&body[16..24]);
            ChunkedSalt::V6(s)
        }
    };

    // All geometry arithmetic below is checked in u64 so a hostile header
    // cannot cause a wrap that defeats the size limits.
    let chunk_size_u64 = chunk_size as u64;
    let chunk_count_u64 = chunk_count as u64;
    let expected_plain_size = chunk_size_u64
        .checked_mul(chunk_count_u64)
        .ok_or(SseError::ChunkFrameTooLarge {
            details: "chunk_size * chunk_count overflows u64",
        })?;
    let per_chunk_overhead = S4E5_PER_CHUNK_OVERHEAD as u64; let total_tag_overhead = per_chunk_overhead.checked_mul(chunk_count_u64).ok_or(
        SseError::ChunkFrameTooLarge {
            details: "tag_len * chunk_count overflows u64",
        },
    )?;
    // Largest body the declared geometry could legitimately occupy:
    // header + full plaintext + per-chunk tag overhead.
    let max_total = expected_plain_size
        .checked_add(total_tag_overhead)
        .and_then(|t| t.checked_add(header_bytes as u64))
        .ok_or(SseError::ChunkFrameTooLarge {
            details: "header + plaintext + tag overhead overflows u64",
        })?;
    if (body.len() as u64) > max_total {
        return Err(SseError::ChunkFrameTruncated {
            what: "trailing bytes past declared chunk geometry",
        });
    }
    if expected_plain_size > max_body_bytes as u64 {
        return Err(SseError::ChunkFrameTooLarge {
            details: "declared plaintext exceeds gateway max_body_bytes",
        });
    }

    Ok(ChunkedHeader {
        variant,
        key_id,
        chunk_size,
        chunk_count,
        salt,
        chunks_offset: header_bytes,
    })
}
1832
1833fn decrypt_chunked_chunk(
1837 cipher: &Aes256Gcm,
1838 chunk_index: u32,
1839 chunk_count: u32,
1840 key_id: u16,
1841 salt: &ChunkedSalt,
1842 tag: &[u8; TAG_LEN],
1843 ct: &[u8],
1844) -> Result<Bytes, SseError> {
1845 let nonce_bytes = match salt {
1846 ChunkedSalt::V5(s) => nonce_v5(s, chunk_index),
1847 ChunkedSalt::V6(s) => nonce_v6(s, chunk_index),
1848 };
1849 let nonce = Nonce::from_slice(&nonce_bytes);
1850 let mut ct_with_tag = Vec::with_capacity(ct.len() + TAG_LEN);
1851 ct_with_tag.extend_from_slice(ct);
1852 ct_with_tag.extend_from_slice(tag);
1853 let result = match salt {
1854 ChunkedSalt::V5(s) => {
1855 let aad = aad_v5(chunk_index, chunk_count, key_id, s);
1856 cipher.decrypt(
1857 nonce,
1858 Payload {
1859 msg: &ct_with_tag,
1860 aad: &aad,
1861 },
1862 )
1863 }
1864 ChunkedSalt::V6(s) => {
1865 let aad = aad_v6(chunk_index, chunk_count, key_id, s);
1866 cipher.decrypt(
1867 nonce,
1868 Payload {
1869 msg: &ct_with_tag,
1870 aad: &aad,
1871 },
1872 )
1873 }
1874 };
1875 result
1876 .map(Bytes::from)
1877 .map_err(|_| SseError::ChunkAuthFailed { chunk_index })
1878}
1879
/// Walk every `[tag || ciphertext]` record of a chunked frame in order,
/// decrypting each chunk and handing the plaintext to `emit`.
///
/// Shared core of the buffered decrypt path; the streaming path duplicates
/// this loop in `decrypt_next_chunk` so it can suspend between chunks.
///
/// # Errors
/// Header/geometry errors from `parse_chunked_header`; `KeyNotInKeyring` if
/// the frame's key_id is unknown; `ChunkFrameTruncated` whenever an offset
/// would run past the body (or bytes remain after the declared chunks);
/// `ChunkAuthFailed` when a chunk fails AEAD authentication; plus whatever
/// `emit` returns.
fn walk_chunked<F: FnMut(Bytes) -> Result<(), SseError>>(
    body: &[u8],
    keyring: &SseKeyring,
    max_body_bytes: usize,
    mut emit: F,
) -> Result<(), SseError> {
    let hdr = parse_chunked_header(body, max_body_bytes)?;
    let key = keyring
        .get(hdr.key_id)
        .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
    let cipher = Aes256Gcm::new(key.as_aes_key());

    let mut cursor = hdr.chunks_offset;
    let chunk_size = hdr.chunk_size as usize;
    for i in 0..hdr.chunk_count {
        // Each record is TAG_LEN bytes of tag followed by the ciphertext.
        if cursor + TAG_LEN > body.len() {
            return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
        }
        let tag_off = cursor;
        let ct_off = tag_off + TAG_LEN;
        let is_last = i + 1 == hdr.chunk_count;
        // Every chunk is exactly chunk_size except the last, which carries
        // whatever remains — possibly zero bytes, never more than chunk_size.
        let ct_len = if is_last {
            if ct_off > body.len() {
                return Err(SseError::ChunkFrameTruncated {
                    what: "final chunk ciphertext",
                });
            }
            let remaining = body.len() - ct_off;
            if remaining > chunk_size {
                return Err(SseError::ChunkFrameTruncated {
                    what: "trailing bytes after final chunk",
                });
            }
            remaining
        } else {
            chunk_size
        };
        let ct_end = ct_off + ct_len;
        if ct_end > body.len() {
            return Err(SseError::ChunkFrameTruncated {
                what: "chunk ciphertext",
            });
        }
        let mut tag = [0u8; TAG_LEN];
        tag.copy_from_slice(&body[tag_off..ct_off]);
        let ct = &body[ct_off..ct_end];
        let plain = decrypt_chunked_chunk(
            &cipher,
            i,
            hdr.chunk_count,
            hdr.key_id,
            &hdr.salt,
            &tag,
            ct,
        )?;
        crate::metrics::record_sse_streaming_chunk("decrypt");
        emit(plain)?;
        cursor = ct_end;
    }
    // Any bytes left after the final declared chunk indicate a framing error.
    if cursor != body.len() {
        return Err(SseError::ChunkFrameTruncated {
            what: "trailing bytes after declared chunk_count",
        });
    }
    Ok(())
}
1951
1952pub fn decrypt_chunked_buffered(
1965 body: &[u8],
1966 keyring: &SseKeyring,
1967 max_body_bytes: usize,
1968) -> Result<Bytes, SseError> {
1969 let hdr = parse_chunked_header(body, max_body_bytes)?;
1970 let mut out = Vec::with_capacity(hdr.chunk_size as usize * hdr.chunk_count as usize);
1976 walk_chunked(body, keyring, max_body_bytes, |chunk| {
1977 out.extend_from_slice(&chunk);
1978 Ok(())
1979 })?;
1980 Ok(Bytes::from(out))
1981}
1982
/// Buffered chunked decryption using the gateway-wide default body-size cap.
///
/// Thin convenience wrapper over [`decrypt_chunked_buffered`] that supplies
/// `DEFAULT_MAX_BODY_BYTES` as the declared-plaintext limit.
pub fn decrypt_chunked_buffered_default(
    body: &[u8],
    keyring: &SseKeyring,
) -> Result<Bytes, SseError> {
    decrypt_chunked_buffered(body, keyring, DEFAULT_MAX_BODY_BYTES)
}
1995
/// Decrypt a chunked frame lazily, yielding one plaintext chunk per poll.
///
/// Header parsing and key lookup happen eagerly; any setup error becomes the
/// stream's single item. Successful setup hands off to `decrypt_next_chunk`
/// via `try_unfold`, which decrypts one chunk per step.
///
/// NOTE(review): `usize::MAX` is passed as `max_body_bytes`, disabling the
/// declared-plaintext cap — presumably acceptable because this path holds
/// only one chunk in memory at a time; confirm callers enforce their own
/// total-size limit.
pub fn decrypt_chunked_stream(
    body: bytes::Bytes,
    keyring: &SseKeyring,
) -> impl futures::Stream<Item = Result<Bytes, SseError>> + 'static {
    use futures::stream::{self, StreamExt};

    // Immediately-invoked closure so `?` can be used for fallible setup.
    let prelude = (|| {
        let hdr = parse_chunked_header(&body, usize::MAX)?;
        let key = keyring
            .get(hdr.key_id)
            .ok_or(SseError::KeyNotInKeyring { id: hdr.key_id })?;
        let cipher = Aes256Gcm::new(key.as_aes_key());
        Ok::<_, SseError>((hdr, cipher))
    })();

    match prelude {
        // left_stream/right_stream unify the two branches into one `Either`
        // stream type so the function can return `impl Stream`.
        Err(e) => stream::iter(std::iter::once(Err(e))).left_stream(),
        Ok((hdr, cipher)) => {
            let chunks_offset = hdr.chunks_offset;
            let state = ChunkedDecryptState {
                body,
                cipher,
                hdr,
                cursor: chunks_offset,
                next_index: 0,
            };
            stream::try_unfold(state, decrypt_next_chunk).right_stream()
        }
    }
}
2065
/// Mutable state threaded through `stream::try_unfold` by
/// `decrypt_chunked_stream` / `decrypt_next_chunk`.
struct ChunkedDecryptState {
    /// Full encrypted frame (header + chunk records); owned so the stream is 'static.
    body: bytes::Bytes,
    /// AES-GCM cipher, constructed once from the keyring key the header names.
    cipher: Aes256Gcm,
    /// Validated frame header (geometry, salt, key id).
    hdr: ChunkedHeader,
    /// Byte offset of the next unread `[tag || ciphertext]` record in `body`.
    cursor: usize,
    /// Index of the next chunk to decrypt, in `0..hdr.chunk_count`.
    next_index: u32,
}
2076
/// `try_unfold` step function: decrypt the next chunk, or end the stream once
/// every declared chunk has been produced.
///
/// `async` only to satisfy `stream::try_unfold`'s signature — the body never
/// awaits. The offset/length logic mirrors the loop body of `walk_chunked`.
async fn decrypt_next_chunk(
    mut state: ChunkedDecryptState,
) -> Result<Option<(Bytes, ChunkedDecryptState)>, SseError> {
    if state.next_index >= state.hdr.chunk_count {
        // All chunks emitted; reject bytes left past the declared geometry.
        if state.cursor != state.body.len() {
            return Err(SseError::ChunkFrameTruncated {
                what: "trailing bytes after declared chunk_count",
            });
        }
        return Ok(None);
    }
    let i = state.next_index;
    let chunk_size = state.hdr.chunk_size as usize;
    // Each record is TAG_LEN bytes of tag followed by the ciphertext.
    if state.cursor + TAG_LEN > state.body.len() {
        return Err(SseError::ChunkFrameTruncated { what: "chunk tag" });
    }
    let tag_off = state.cursor;
    let ct_off = tag_off + TAG_LEN;
    let is_last = i + 1 == state.hdr.chunk_count;
    // Every chunk is exactly chunk_size except the last, which carries
    // whatever remains — possibly zero bytes, never more than chunk_size.
    let ct_len = if is_last {
        if ct_off > state.body.len() {
            return Err(SseError::ChunkFrameTruncated {
                what: "final chunk ciphertext",
            });
        }
        let remaining = state.body.len() - ct_off;
        if remaining > chunk_size {
            return Err(SseError::ChunkFrameTruncated {
                what: "trailing bytes after final chunk",
            });
        }
        remaining
    } else {
        chunk_size
    };
    let ct_end = ct_off + ct_len;
    if ct_end > state.body.len() {
        return Err(SseError::ChunkFrameTruncated {
            what: "chunk ciphertext",
        });
    }
    let mut tag = [0u8; TAG_LEN];
    tag.copy_from_slice(&state.body[tag_off..ct_off]);
    let ct = &state.body[ct_off..ct_end];
    let plain = decrypt_chunked_chunk(
        &state.cipher,
        i,
        state.hdr.chunk_count,
        state.hdr.key_id,
        &state.hdr.salt,
        &tag,
        ct,
    )?;
    crate::metrics::record_sse_streaming_chunk("decrypt");
    state.cursor = ct_end;
    state.next_index += 1;
    Ok(Some((plain, state)))
}
2137
#[cfg(test)]
/// Test-only S4E5 frame encoder: splits `plaintext` into `chunk_size` pieces,
/// encrypts each under the keyring's active key, and emits the wire frame
/// `header || [tag || ciphertext]*`.
///
/// Empty plaintext still produces a single (empty) chunk so the decrypt paths
/// always see a well-formed frame.
fn encrypt_v2_chunked_s4e5_for_test(
    plaintext: &[u8],
    keyring: &SseKeyring,
    chunk_size: usize,
) -> Result<Bytes, SseError> {
    if chunk_size == 0 {
        return Err(SseError::ChunkSizeInvalid);
    }
    let (key_id, key) = keyring.active();
    let cipher = Aes256Gcm::new(key.as_aes_key());
    let mut salt = [0u8; 4];
    rand::rngs::OsRng.fill_bytes(&mut salt);

    let chunk_count: u32 = if plaintext.is_empty() {
        1
    } else {
        plaintext
            .len()
            .div_ceil(chunk_size)
            .try_into()
            .expect("chunk_count overflows u32")
    };

    let mut out = Vec::with_capacity(
        S4E5_HEADER_BYTES + plaintext.len() + (chunk_count as usize * S4E5_PER_CHUNK_OVERHEAD),
    );
    // Header: magic | algo | key_id (BE) | reserved | chunk_size (BE) |
    // chunk_count (BE) | 4B salt.
    out.extend_from_slice(SSE_MAGIC_V5);
    out.push(ALGO_AES_256_GCM);
    out.extend_from_slice(&key_id.to_be_bytes());
    out.push(0u8);
    out.extend_from_slice(&(chunk_size as u32).to_be_bytes());
    out.extend_from_slice(&chunk_count.to_be_bytes());
    out.extend_from_slice(&salt);

    for i in 0..chunk_count {
        let off = (i as usize).saturating_mul(chunk_size);
        let end = off.saturating_add(chunk_size).min(plaintext.len());
        // Past-the-end offsets yield an empty chunk (the empty-plaintext case).
        let chunk_pt: &[u8] = if off >= plaintext.len() {
            &[]
        } else {
            &plaintext[off..end]
        };
        let nonce_bytes = nonce_v5(&salt, i);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let aad = aad_v5(i, chunk_count, key_id, &salt);
        let ct_with_tag = cipher
            .encrypt(
                nonce,
                Payload {
                    msg: chunk_pt,
                    aad: &aad,
                },
            )
            .expect("aes-gcm encrypt cannot fail with a 32-byte key");
        // AEAD output is `ciphertext || tag`; the wire format stores the tag
        // first, so split and swap.
        let split = ct_with_tag.len() - TAG_LEN;
        let (ct, tag) = ct_with_tag.split_at(split);
        out.extend_from_slice(tag);
        out.extend_from_slice(ct);
    }
    Ok(Bytes::from(out))
}
2205
2206#[cfg(test)]
2207mod tests {
2208 use super::*;
2209
2210 fn key32(seed: u8) -> Arc<SseKey> {
2211 Arc::new(SseKey::from_bytes(&[seed; 32]).unwrap())
2212 }
2213
2214 fn keyring_single(seed: u8) -> SseKeyring {
2215 SseKeyring::new(1, key32(seed))
2216 }
2217
2218 #[test]
2219 fn roundtrip_basic_v1() {
2220 let k = SseKey::from_bytes(&[7u8; 32]).unwrap();
2222 let pt = b"the quick brown fox jumps over the lazy dog";
2223 let ct = encrypt(&k, pt);
2224 assert!(looks_encrypted(&ct));
2225 assert_eq!(&ct[..4], SSE_MAGIC_V1);
2226 assert_eq!(ct[4], ALGO_AES_256_GCM);
2227 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
2228 let kr = SseKeyring::new(1, Arc::new(k));
2230 let pt2 = decrypt(&ct, &kr).unwrap();
2231 assert_eq!(pt2.as_ref(), pt);
2232 }
2233
2234 #[test]
2235 fn s4e2_roundtrip_active_key() {
2236 let kr = keyring_single(7);
2237 let pt = b"S4E2 active-key roundtrip";
2238 let ct = encrypt_v2(pt, &kr);
2239 assert_eq!(&ct[..4], SSE_MAGIC_V2);
2240 assert_eq!(ct[4], ALGO_AES_256_GCM);
2241 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE");
2242 assert_eq!(ct[7], 0, "reserved byte");
2243 assert_eq!(ct.len(), SSE_HEADER_BYTES + pt.len());
2244 assert!(looks_encrypted(&ct));
2245 let pt2 = decrypt(&ct, &kr).unwrap();
2246 assert_eq!(pt2.as_ref(), pt);
2247 }
2248
2249 #[test]
2250 fn decrypt_s4e1_via_active_only_keyring() {
2251 let k_arc = key32(11);
2254 let legacy_ct = encrypt(&k_arc, b"v0.4 vintage object");
2255 assert_eq!(&legacy_ct[..4], SSE_MAGIC_V1);
2256 let kr = SseKeyring::new(1, Arc::clone(&k_arc));
2257 let plain = decrypt(&legacy_ct, &kr).unwrap();
2258 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
2259 }
2260
2261 #[test]
2262 fn decrypt_s4e2_under_old_key_after_rotation() {
2263 let k1 = key32(1);
2267 let k2 = key32(2);
2268 let mut kr_old = SseKeyring::new(1, Arc::clone(&k1));
2269 let ct = encrypt_v2(b"old-rotation object", &kr_old);
2270 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
2271
2272 kr_old.add(2, Arc::clone(&k2));
2274 let mut kr_new = SseKeyring::new(2, Arc::clone(&k2));
2275 kr_new.add(1, Arc::clone(&k1));
2276
2277 let plain = decrypt(&ct, &kr_new).unwrap();
2278 assert_eq!(plain.as_ref(), b"old-rotation object");
2279
2280 let new_ct = encrypt_v2(b"new-rotation object", &kr_new);
2282 assert_eq!(u16::from_be_bytes([new_ct[5], new_ct[6]]), 2);
2283 let plain_new = decrypt(&new_ct, &kr_new).unwrap();
2284 assert_eq!(plain_new.as_ref(), b"new-rotation object");
2285 }
2286
2287 #[test]
2288 fn s4e2_unknown_key_id_errors() {
2289 let kr = keyring_single(3); let kr_other = SseKeyring::new(99, key32(3));
2291 let ct = encrypt_v2(b"x", &kr_other); let err = decrypt(&ct, &kr).unwrap_err();
2293 assert!(
2294 matches!(err, SseError::KeyNotInKeyring { id: 99 }),
2295 "got {err:?}"
2296 );
2297 }
2298
2299 #[test]
2300 fn s4e2_tampered_key_id_fails_auth() {
2301 let kr = SseKeyring::new(1, key32(4));
2302 let mut kr_with_2 = kr.clone();
2303 kr_with_2.add(2, key32(5)); let mut ct = encrypt_v2(b"do not flip my key id", &kr).to_vec();
2305 assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1);
2309 ct[5] = 0;
2310 ct[6] = 2;
2311 let err = decrypt(&ct, &kr_with_2).unwrap_err();
2312 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2313 }
2314
2315 #[test]
2316 fn s4e2_tampered_ciphertext_fails() {
2317 let kr = SseKeyring::new(7, key32(9));
2318 let mut ct = encrypt_v2(b"secret message v2", &kr).to_vec();
2319 let last = ct.len() - 1;
2320 ct[last] ^= 0x01;
2321 let err = decrypt(&ct, &kr).unwrap_err();
2322 assert!(matches!(err, SseError::DecryptFailed));
2323 }
2324
2325 #[test]
2326 fn s4e2_tampered_algo_byte_fails() {
2327 let kr = SseKeyring::new(1, key32(2));
2328 let mut ct = encrypt_v2(b"hi", &kr).to_vec();
2329 ct[4] = 99;
2330 let err = decrypt(&ct, &kr).unwrap_err();
2331 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
2332 }
2333
2334 #[test]
2335 fn wrong_key_fails_v1_via_keyring() {
2336 let k1 = SseKey::from_bytes(&[1u8; 32]).unwrap();
2338 let ct = encrypt(&k1, b"secret");
2339 let kr_wrong = SseKeyring::new(1, Arc::new(SseKey::from_bytes(&[2u8; 32]).unwrap()));
2340 let err = decrypt(&ct, &kr_wrong).unwrap_err();
2341 assert!(matches!(err, SseError::DecryptFailed));
2342 }
2343
2344 #[test]
2345 fn rejects_short_body() {
2346 let kr = SseKeyring::new(1, key32(1));
2347 let err = decrypt(b"short", &kr).unwrap_err();
2348 assert!(matches!(err, SseError::TooShort { got: 5 }));
2349 }
2350
2351 #[test]
2352 fn looks_encrypted_passthrough_returns_false() {
2353 let f2 = b"S4F2\x01\x00\x00\x00........................................";
2355 assert!(!looks_encrypted(f2));
2356 assert!(!looks_encrypted(b""));
2357 }
2358
2359 #[test]
2360 fn looks_encrypted_detects_both_v1_and_v2() {
2361 let kr = SseKeyring::new(1, key32(8));
2362 let v1 = encrypt(&SseKey::from_bytes(&[8u8; 32]).unwrap(), b"x");
2363 let v2 = encrypt_v2(b"x", &kr);
2364 assert!(looks_encrypted(&v1));
2365 assert!(looks_encrypted(&v2));
2366 }
2367
2368 #[test]
2369 fn key_from_hex_string() {
2370 let bad =
2371 SseKey::from_bytes(b"0102030405060708090a0b0c0d0e0f10111213141516171819202122232425")
2372 .unwrap_err();
2373 assert!(matches!(bad, SseError::BadKeyLength { .. }));
2374 let good = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
2375 let _ = SseKey::from_bytes(good).expect("64-char hex should parse");
2376 }
2377
2378 #[test]
2379 fn encrypt_v2_uses_random_nonce() {
2380 let kr = SseKeyring::new(1, key32(3));
2381 let pt = b"deterministic input";
2382 let a = encrypt_v2(pt, &kr);
2383 let b = encrypt_v2(pt, &kr);
2384 assert_ne!(a, b, "nonce must be random per-call");
2385 }
2386
2387 #[test]
2388 fn keyring_active_and_get() {
2389 let k1 = key32(1);
2390 let k2 = key32(2);
2391 let mut kr = SseKeyring::new(1, Arc::clone(&k1));
2392 kr.add(2, Arc::clone(&k2));
2393 let (id, active) = kr.active();
2394 assert_eq!(id, 1);
2395 assert_eq!(active.bytes, [1u8; 32]);
2396 assert!(kr.get(2).is_some());
2397 assert!(kr.get(3).is_none());
2398 }
2399
2400 use base64::Engine as _;
2405
2406 fn cust_key(seed: u8) -> CustomerKeyMaterial {
2407 let key = [seed; KEY_LEN];
2408 let key_md5 = compute_key_md5(&key);
2409 CustomerKeyMaterial { key, key_md5 }
2410 }
2411
2412 #[test]
2413 fn s4e3_roundtrip_happy_path() {
2414 let m = cust_key(42);
2415 let pt = b"top-secret SSE-C payload";
2416 let ct = encrypt_with_source(
2417 pt,
2418 SseSource::CustomerKey {
2419 key: &m.key,
2420 key_md5: &m.key_md5,
2421 },
2422 );
2423 assert_eq!(&ct[..4], SSE_MAGIC_V3);
2425 assert_eq!(ct[4], ALGO_AES_256_GCM);
2426 assert_eq!(&ct[5..5 + KEY_MD5_LEN], &m.key_md5);
2427 assert_eq!(ct.len(), SSE_HEADER_BYTES_V3 + pt.len());
2428 assert!(looks_encrypted(&ct));
2429 let plain = decrypt(
2431 &ct,
2432 SseSource::CustomerKey {
2433 key: &m.key,
2434 key_md5: &m.key_md5,
2435 },
2436 )
2437 .unwrap();
2438 assert_eq!(plain.as_ref(), pt);
2439 let plain2 = decrypt(&ct, &m).unwrap();
2441 assert_eq!(plain2.as_ref(), pt);
2442 }
2443
2444 #[test]
2445 fn s4e3_wrong_key_yields_wrong_customer_key_error() {
2446 let m = cust_key(1);
2447 let other = cust_key(2);
2448 let ct = encrypt_with_source(b"payload", (&m).into());
2449 let err = decrypt(
2450 &ct,
2451 SseSource::CustomerKey {
2452 key: &other.key,
2453 key_md5: &other.key_md5,
2454 },
2455 )
2456 .unwrap_err();
2457 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
2458 }
2459
2460 #[test]
2461 fn s4e3_tampered_stored_md5_is_caught() {
2462 let m = cust_key(7);
2469 let mut ct = encrypt_with_source(b"victim payload", (&m).into()).to_vec();
2470 ct[5] ^= 0x55;
2472 let err = decrypt(
2474 &ct,
2475 SseSource::CustomerKey {
2476 key: &m.key,
2477 key_md5: &m.key_md5,
2478 },
2479 )
2480 .unwrap_err();
2481 assert!(matches!(err, SseError::WrongCustomerKey), "got {err:?}");
2482 }
2483
2484 #[test]
2485 fn s4e3_tampered_md5_with_matching_supplied_md5_fails_aead() {
2486 let m = cust_key(3);
2490 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
2491 ct[5] ^= 0xFF;
2492 let mut bogus_md5 = m.key_md5;
2493 bogus_md5[0] ^= 0xFF;
2494 let err = decrypt(
2495 &ct,
2496 SseSource::CustomerKey {
2497 key: &m.key,
2498 key_md5: &bogus_md5,
2499 },
2500 )
2501 .unwrap_err();
2502 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2503 }
2504
2505 #[test]
2506 fn s4e3_tampered_ciphertext_fails_aead() {
2507 let m = cust_key(8);
2508 let mut ct = encrypt_with_source(b"sealed message", (&m).into()).to_vec();
2509 let last = ct.len() - 1;
2510 ct[last] ^= 0x01;
2511 let err = decrypt(&ct, &m).unwrap_err();
2512 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2513 }
2514
2515 #[test]
2516 fn s4e3_tampered_algo_byte_rejected() {
2517 let m = cust_key(9);
2518 let mut ct = encrypt_with_source(b"x", (&m).into()).to_vec();
2519 ct[4] = 99;
2520 let err = decrypt(&ct, &m).unwrap_err();
2521 assert!(matches!(err, SseError::UnsupportedAlgo { tag: 99 }));
2522 }
2523
2524 #[test]
2525 fn s4e3_uses_random_nonce() {
2526 let m = cust_key(10);
2527 let a = encrypt_with_source(b"deterministic input", (&m).into());
2528 let b = encrypt_with_source(b"deterministic input", (&m).into());
2529 assert_ne!(a, b, "nonce must be random per-call");
2530 }
2531
2532 #[test]
2533 fn parse_customer_key_headers_happy_path() {
2534 let key = [11u8; KEY_LEN];
2535 let md5 = compute_key_md5(&key);
2536 let key_b64 = base64::engine::general_purpose::STANDARD.encode(key);
2537 let md5_b64 = base64::engine::general_purpose::STANDARD.encode(md5);
2538 let m = parse_customer_key_headers("AES256", &key_b64, &md5_b64).unwrap();
2539 assert_eq!(m.key, key);
2540 assert_eq!(m.key_md5, md5);
2541 }
2542
2543 #[test]
2544 fn parse_customer_key_headers_rejects_wrong_algorithm() {
2545 let key = [1u8; KEY_LEN];
2546 let md5 = compute_key_md5(&key);
2547 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2548 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2549 let err = parse_customer_key_headers("AES128", &kb, &mb).unwrap_err();
2550 assert!(
2551 matches!(err, SseError::CustomerKeyAlgorithmUnsupported { ref algo } if algo == "AES128"),
2552 "got {err:?}"
2553 );
2554 let err2 = parse_customer_key_headers("aes256", &kb, &mb).unwrap_err();
2556 assert!(
2557 matches!(err2, SseError::CustomerKeyAlgorithmUnsupported { .. }),
2558 "got {err2:?}"
2559 );
2560 }
2561
2562 #[test]
2563 fn parse_customer_key_headers_rejects_wrong_key_length() {
2564 let short_key = vec![5u8; 16]; let md5 = compute_key_md5(&short_key);
2566 let kb = base64::engine::general_purpose::STANDARD.encode(&short_key);
2567 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2568 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2569 assert!(
2570 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("key length")),
2571 "got {err:?}"
2572 );
2573 }
2574
2575 #[test]
2576 fn parse_customer_key_headers_rejects_wrong_md5_length() {
2577 let key = [3u8; KEY_LEN];
2578 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2579 let bad_md5 = vec![0u8; 15];
2581 let mb = base64::engine::general_purpose::STANDARD.encode(bad_md5);
2582 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2583 assert!(
2584 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 length")),
2585 "got {err:?}"
2586 );
2587 }
2588
2589 #[test]
2590 fn parse_customer_key_headers_rejects_md5_mismatch() {
2591 let key = [4u8; KEY_LEN];
2592 let other = [5u8; KEY_LEN];
2593 let kb = base64::engine::general_purpose::STANDARD.encode(key);
2594 let wrong_md5 = compute_key_md5(&other);
2595 let mb = base64::engine::general_purpose::STANDARD.encode(wrong_md5);
2596 let err = parse_customer_key_headers("AES256", &kb, &mb).unwrap_err();
2597 assert!(
2598 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("MD5 does not match")),
2599 "got {err:?}"
2600 );
2601 }
2602
2603 #[test]
2604 fn parse_customer_key_headers_rejects_bad_base64() {
2605 let valid_key = [0u8; KEY_LEN];
2606 let md5 = compute_key_md5(&valid_key);
2607 let mb = base64::engine::general_purpose::STANDARD.encode(md5);
2608 let err = parse_customer_key_headers("AES256", "!!!not-base64!!!", &mb).unwrap_err();
2609 assert!(
2610 matches!(err, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2611 "got {err:?}"
2612 );
2613 let kb = base64::engine::general_purpose::STANDARD.encode(valid_key);
2615 let err2 = parse_customer_key_headers("AES256", &kb, "??not-base64??").unwrap_err();
2616 assert!(
2617 matches!(err2, SseError::InvalidCustomerKey { reason } if reason.contains("base64")),
2618 "got {err2:?}"
2619 );
2620 }
2621
2622 #[test]
2623 fn parse_customer_key_headers_trims_whitespace() {
2624 let key = [12u8; KEY_LEN];
2626 let md5 = compute_key_md5(&key);
2627 let kb = format!(
2628 " {}\n",
2629 base64::engine::general_purpose::STANDARD.encode(key)
2630 );
2631 let mb = format!(
2632 "\t{} ",
2633 base64::engine::general_purpose::STANDARD.encode(md5)
2634 );
2635 let m = parse_customer_key_headers("AES256", &kb, &mb).unwrap();
2636 assert_eq!(m.key, key);
2637 }
2638
2639 #[test]
2644 fn back_compat_decrypt_s4e1_with_keyring_source() {
2645 let k = key32(33);
2646 let legacy_ct = encrypt(&k, b"v0.4 vintage object");
2647 let kr = SseKeyring::new(1, Arc::clone(&k));
2648 let plain = decrypt(&legacy_ct, &kr).unwrap();
2651 assert_eq!(plain.as_ref(), b"v0.4 vintage object");
2652 let plain2 = decrypt(&legacy_ct, SseSource::Keyring(&kr)).unwrap();
2653 assert_eq!(plain2.as_ref(), b"v0.4 vintage object");
2654 }
2655
2656 #[test]
2657 fn back_compat_decrypt_s4e2_with_keyring_source() {
2658 let kr = keyring_single(34);
2659 let ct = encrypt_v2(b"v0.5 #29 object", &kr);
2660 let plain = decrypt(&ct, &kr).unwrap();
2661 assert_eq!(plain.as_ref(), b"v0.5 #29 object");
2662 let ct2 = encrypt_with_source(b"v0.5 #29 object", SseSource::Keyring(&kr));
2665 assert_eq!(&ct2[..4], SSE_MAGIC_V2);
2666 let plain2 = decrypt(&ct2, &kr).unwrap();
2667 assert_eq!(plain2.as_ref(), b"v0.5 #29 object");
2668 }
2669
2670 #[test]
2671 fn s4e2_blob_with_customer_key_source_is_rejected() {
2672 let kr = keyring_single(50);
2676 let ct = encrypt_v2(b"server-managed object", &kr);
2677 let m = cust_key(99);
2678 let err = decrypt(
2679 &ct,
2680 SseSource::CustomerKey {
2681 key: &m.key,
2682 key_md5: &m.key_md5,
2683 },
2684 )
2685 .unwrap_err();
2686 assert!(matches!(err, SseError::CustomerKeyUnexpected), "got {err:?}");
2687 }
2688
2689 #[test]
2690 fn s4e3_blob_with_keyring_source_is_rejected() {
2691 let m = cust_key(60);
2694 let ct = encrypt_with_source(b"customer-key object", (&m).into());
2695 let kr = keyring_single(60);
2696 let err = decrypt(&ct, &kr).unwrap_err();
2697 assert!(matches!(err, SseError::CustomerKeyRequired), "got {err:?}");
2698 }
2699
2700 #[test]
2701 fn looks_encrypted_detects_s4e3() {
2702 let m = cust_key(13);
2703 let ct = encrypt_with_source(b"x", (&m).into());
2704 assert!(looks_encrypted(&ct));
2705 }
2706
2707 #[test]
2708 fn s4e3_rejects_short_body() {
2709 let mut short = Vec::new();
2712 short.extend_from_slice(SSE_MAGIC_V3);
2713 short.push(ALGO_AES_256_GCM);
2714 short.extend_from_slice(&[0u8; SSE_HEADER_BYTES - 5]);
2717 assert_eq!(short.len(), SSE_HEADER_BYTES);
2718 let m = cust_key(1);
2719 let err = decrypt(
2720 &short,
2721 SseSource::CustomerKey {
2722 key: &m.key,
2723 key_md5: &m.key_md5,
2724 },
2725 )
2726 .unwrap_err();
2727 assert!(matches!(err, SseError::TooShort { .. }), "got {err:?}");
2728 }
2729
2730 #[test]
2731 fn customer_key_material_debug_redacts_key() {
2732 let m = cust_key(99);
2733 let s = format!("{m:?}");
2734 assert!(s.contains("redacted"));
2735 assert!(!s.contains(&format!("{:?}", m.key.as_slice())));
2736 }
2737
2738 #[test]
2739 fn constant_time_eq_basic() {
2740 assert!(constant_time_eq(b"abc", b"abc"));
2741 assert!(!constant_time_eq(b"abc", b"abd"));
2742 assert!(!constant_time_eq(b"abc", b"abcd"));
2743 assert!(constant_time_eq(b"", b""));
2744 }
2745
2746 #[test]
2747 fn compute_key_md5_known_vector() {
2748 let got = compute_key_md5(b"");
2750 let expected_hex = "d41d8cd98f00b204e9800998ecf8427e";
2751 assert_eq!(hex_lower(&got), expected_hex);
2752 }
2753
2754 use crate::kms::{KmsBackend, LocalKms};
2759 use std::collections::HashMap;
2760 use std::path::PathBuf;
2761
2762 fn local_kms_with(key_ids: &[(&str, [u8; 32])]) -> LocalKms {
2763 let mut keks: HashMap<String, [u8; 32]> = HashMap::new();
2764 for (id, k) in key_ids {
2765 keks.insert((*id).to_string(), *k);
2766 }
2767 LocalKms::from_keks(PathBuf::from("/tmp/none"), keks)
2768 }
2769
2770 #[tokio::test]
2771 async fn s4e4_roundtrip_via_local_kms() {
2772 let kms = local_kms_with(&[("alpha", [42u8; 32])]);
2773 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2774 let mut dek = [0u8; 32];
2775 dek.copy_from_slice(&dek_vec);
2776 let pt = b"SSE-KMS envelope payload across the S4E4 frame";
2777 let ct = encrypt_with_source(
2778 pt,
2779 SseSource::Kms {
2780 dek: &dek,
2781 wrapped: &wrapped,
2782 },
2783 );
2784 assert_eq!(&ct[..4], SSE_MAGIC_V4);
2786 assert_eq!(ct[4], ALGO_AES_256_GCM);
2787 let key_id_len = ct[5] as usize;
2788 assert_eq!(key_id_len, "alpha".len());
2789 assert_eq!(&ct[6..6 + key_id_len], b"alpha");
2790 assert!(looks_encrypted(&ct));
2792 assert_eq!(peek_magic(&ct), Some("S4E4"));
2793 let plain = decrypt_with_kms(&ct, &kms).await.unwrap();
2795 assert_eq!(plain.as_ref(), pt);
2796 }
2797
2798 #[tokio::test]
2799 async fn s4e4_tampered_key_id_fails_aead() {
2800 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2801 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2802 let mut dek = [0u8; 32];
2803 dek.copy_from_slice(&dek_vec);
2804 let mut ct = encrypt_with_source(
2805 b"do not redirect",
2806 SseSource::Kms {
2807 dek: &dek,
2808 wrapped: &wrapped,
2809 },
2810 )
2811 .to_vec();
2812 let key_id_off = 6;
2817 ct[key_id_off] = b'b';
2818 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2819 assert!(
2820 matches!(
2821 err,
2822 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2823 | SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2824 ),
2825 "got {err:?}"
2826 );
2827 }
2828
2829 #[tokio::test]
2830 async fn s4e4_tampered_key_id_to_real_other_id_still_fails() {
2831 let kms = local_kms_with(&[("alpha", [1u8; 32]), ("beta", [2u8; 32])]);
2837 let (dek_vec, wrapped) = kms.generate_dek("alpha").await.unwrap();
2838 let mut dek = [0u8; 32];
2839 dek.copy_from_slice(&dek_vec);
2840 let mut ct = encrypt_with_source(
2841 b"redirect attempt",
2842 SseSource::Kms {
2843 dek: &dek,
2844 wrapped: &wrapped,
2845 },
2846 )
2847 .to_vec();
2848 let key_id_off = 6;
2851 ct[key_id_off..key_id_off + 5].copy_from_slice(b"beta_");
2852 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2859 assert!(
2860 matches!(
2861 err,
2862 SseError::KmsBackend(crate::kms::KmsError::KeyNotFound { .. })
2863 ),
2864 "got {err:?}"
2865 );
2866 }
2867
2868 #[tokio::test]
2869 async fn s4e4_tampered_wrapped_dek_fails_unwrap() {
2870 let kms = local_kms_with(&[("k", [3u8; 32])]);
2871 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2872 let mut dek = [0u8; 32];
2873 dek.copy_from_slice(&dek_vec);
2874 let mut ct = encrypt_with_source(
2875 b"target body",
2876 SseSource::Kms {
2877 dek: &dek,
2878 wrapped: &wrapped,
2879 },
2880 )
2881 .to_vec();
2882 let key_id_len = ct[5] as usize;
2886 let wrapped_len_off = 6 + key_id_len;
2887 let wrapped_off = wrapped_len_off + 4;
2888 let mid = wrapped_off + (wrapped.ciphertext.len() / 2);
2889 ct[mid] ^= 0xFF;
2890 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2891 assert!(
2892 matches!(
2893 err,
2894 SseError::KmsBackend(crate::kms::KmsError::UnwrapFailed { .. })
2895 ),
2896 "got {err:?}"
2897 );
2898 }
2899
2900 #[tokio::test]
2901 async fn s4e4_tampered_ciphertext_fails_aead() {
2902 let kms = local_kms_with(&[("k", [4u8; 32])]);
2903 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2904 let mut dek = [0u8; 32];
2905 dek.copy_from_slice(&dek_vec);
2906 let mut ct = encrypt_with_source(
2907 b"sealed body",
2908 SseSource::Kms {
2909 dek: &dek,
2910 wrapped: &wrapped,
2911 },
2912 )
2913 .to_vec();
2914 let last = ct.len() - 1;
2915 ct[last] ^= 0x01;
2916 let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
2917 assert!(matches!(err, SseError::DecryptFailed), "got {err:?}");
2918 }
2919
2920 #[tokio::test]
2921 async fn s4e4_uses_random_nonce_and_dek_per_put() {
2922 let kms = local_kms_with(&[("k", [5u8; 32])]);
2923 let (dek1_vec, wrapped1) = kms.generate_dek("k").await.unwrap();
2926 let (dek2_vec, wrapped2) = kms.generate_dek("k").await.unwrap();
2927 let mut dek1 = [0u8; 32];
2928 dek1.copy_from_slice(&dek1_vec);
2929 let mut dek2 = [0u8; 32];
2930 dek2.copy_from_slice(&dek2_vec);
2931 let pt = b"deterministic input";
2932 let a = encrypt_with_source(
2933 pt,
2934 SseSource::Kms {
2935 dek: &dek1,
2936 wrapped: &wrapped1,
2937 },
2938 );
2939 let b = encrypt_with_source(
2940 pt,
2941 SseSource::Kms {
2942 dek: &dek2,
2943 wrapped: &wrapped2,
2944 },
2945 );
2946 assert_ne!(a, b);
2947 let plain_a = decrypt_with_kms(&a, &kms).await.unwrap();
2949 let plain_b = decrypt_with_kms(&b, &kms).await.unwrap();
2950 assert_eq!(plain_a.as_ref(), pt);
2951 assert_eq!(plain_b.as_ref(), pt);
2952 }
2953
2954 #[tokio::test]
2955 async fn s4e4_sync_decrypt_returns_kms_async_required() {
2956 let kms = local_kms_with(&[("k", [6u8; 32])]);
2961 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
2962 let mut dek = [0u8; 32];
2963 dek.copy_from_slice(&dek_vec);
2964 let ct = encrypt_with_source(
2965 b"async only",
2966 SseSource::Kms {
2967 dek: &dek,
2968 wrapped: &wrapped,
2969 },
2970 );
2971 let kr = SseKeyring::new(1, key32(0));
2973 let err = decrypt(&ct, &kr).unwrap_err();
2974 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
2975 }
2976
2977 #[test]
2978 fn back_compat_s4e1_e2_e3_still_decrypt_via_sync() {
2979 let k = key32(7);
2982 let v1 = encrypt(&k, b"v0.4 vintage");
2983 let kr = SseKeyring::new(1, Arc::clone(&k));
2984 assert_eq!(decrypt(&v1, &kr).unwrap().as_ref(), b"v0.4 vintage");
2985
2986 let v2 = encrypt_v2(b"v0.5 #29 vintage", &kr);
2987 assert_eq!(
2988 decrypt(&v2, &kr).unwrap().as_ref(),
2989 b"v0.5 #29 vintage"
2990 );
2991
2992 let m = cust_key(7);
2993 let v3 = encrypt_with_source(b"v0.5 #27 vintage", (&m).into());
2994 assert_eq!(
2995 decrypt(&v3, &m).unwrap().as_ref(),
2996 b"v0.5 #27 vintage"
2997 );
2998 }
2999
3000 #[test]
3001 fn peek_magic_distinguishes_all_variants() {
3002 let k = key32(9);
3005 let v1 = encrypt(&k, b"x");
3006 assert_eq!(peek_magic(&v1), Some("S4E1"));
3007 let kr = SseKeyring::new(1, Arc::clone(&k));
3008 let v2 = encrypt_v2(b"x", &kr);
3009 assert_eq!(peek_magic(&v2), Some("S4E2"));
3010 let m = cust_key(9);
3011 let v3 = encrypt_with_source(b"x", (&m).into());
3012 assert_eq!(peek_magic(&v3), Some("S4E3"));
3013 let mut v4 = Vec::new();
3018 v4.extend_from_slice(SSE_MAGIC_V4);
3019 v4.extend_from_slice(&[0u8; 40]);
3020 assert_eq!(peek_magic(&v4), Some("S4E4"));
3021 assert!(peek_magic(b"NOPE").is_none());
3023 assert!(peek_magic(b"short").is_none());
3024 assert!(peek_magic(&[0u8; 100]).is_none());
3025 }
3026
3027 #[tokio::test]
3028 async fn s4e4_truncated_frame_errors_cleanly() {
3029 let truncated = b"S4E4\x01\x05hi";
3032 let kms = local_kms_with(&[("k", [1u8; 32])]);
3033 let err = decrypt_with_kms(truncated, &kms).await.unwrap_err();
3034 assert!(
3035 matches!(err, SseError::KmsFrameTooShort { .. }),
3036 "got {err:?}"
3037 );
3038 }
3039
    /// A frame whose key_id_len byte (200) points past the end of the body
    /// must be rejected with a field-out-of-bounds error, not a slice panic.
    #[tokio::test]
    async fn s4e4_oob_key_id_len_errors() {
        let mut body = Vec::new();
        body.extend_from_slice(SSE_MAGIC_V4);
        body.push(ALGO_AES_256_GCM);
        body.push(200u8); // key_id_len far larger than the 50 bytes that follow
        body.extend_from_slice(&[0u8; 50]);
        let kms = local_kms_with(&[("k", [1u8; 32])]);
        let err = decrypt_with_kms(&body, &kms).await.unwrap_err();
        assert!(
            matches!(err, SseError::KmsFrameFieldOob { .. }),
            "got {err:?}"
        );
    }
3060
3061 #[tokio::test]
3062 async fn s4e4_via_keyring_source_into_sync_decrypt_is_kms_async_required() {
3063 let kms = local_kms_with(&[("k", [9u8; 32])]);
3069 let (dek_vec, wrapped) = kms.generate_dek("k").await.unwrap();
3070 let mut dek = [0u8; 32];
3071 dek.copy_from_slice(&dek_vec);
3072 let ct = encrypt_with_source(
3073 b"x",
3074 SseSource::Kms {
3075 dek: &dek,
3076 wrapped: &wrapped,
3077 },
3078 );
3079 let m = cust_key(1);
3080 let err = decrypt(&ct, &m).unwrap_err();
3081 assert!(matches!(err, SseError::KmsAsyncRequired), "got {err:?}");
3082 }
3083
3084 #[tokio::test]
3085 async fn s4e4_looks_encrypted_passthrough_returns_false_for_synthetic() {
3086 let mut not_s4e4 = Vec::new();
3088 not_s4e4.extend_from_slice(b"S4F4");
3089 not_s4e4.extend_from_slice(&[0u8; 60]);
3090 assert!(!looks_encrypted(¬_s4e4));
3091 assert_eq!(peek_magic(¬_s4e4), None);
3092 }
3093
    /// Shrinking the wrapped-DEK length prefix by one byte shifts every
    /// subsequent field of the frame; the decryptor may fail at any of
    /// several layers, but must never accept the shifted frame.
    #[tokio::test]
    async fn s4e4_aad_length_prefix_prevents_byte_shifting() {
        let kms = local_kms_with(&[("kk", [11u8; 32])]);
        let (dek_vec, wrapped) = kms.generate_dek("kk").await.unwrap();
        let mut dek = [0u8; 32];
        dek.copy_from_slice(&dek_vec);
        let mut ct = encrypt_with_source(
            b"length-shift defense",
            SseSource::Kms {
                dek: &dek,
                wrapped: &wrapped,
            },
        )
        .to_vec();
        // S4E4 offsets used here: [5] = key_id_len, then the key id, then a
        // 4-byte BE wrapped-DEK length at 6 + key_id_len.
        let key_id_len = ct[5] as usize;
        let wrapped_len_off = 6 + key_id_len;
        let original_len = u32::from_be_bytes([
            ct[wrapped_len_off],
            ct[wrapped_len_off + 1],
            ct[wrapped_len_off + 2],
            ct[wrapped_len_off + 3],
        ]);
        // Under-declare the wrapped-DEK length by one byte.
        let new_len = (original_len - 1).to_be_bytes();
        ct[wrapped_len_off..wrapped_len_off + 4].copy_from_slice(&new_len);
        let err = decrypt_with_kms(&ct, &kms).await.unwrap_err();
        // Any of these failure modes is acceptable; acceptance is not.
        assert!(
            matches!(
                err,
                SseError::KmsBackend(_)
                    | SseError::DecryptFailed
                    | SseError::KmsFrameFieldOob { .. }
                    | SseError::KmsFrameTooShort { .. }
            ),
            "got {err:?}"
        );
    }
3143
3144 use futures::StreamExt;
3149
3150 async fn collect_chunks(
3153 s: impl futures::Stream<Item = Result<Bytes, SseError>>,
3154 ) -> Result<Vec<Bytes>, SseError> {
3155 let mut out = Vec::new();
3156 let mut s = std::pin::pin!(s);
3157 while let Some(item) = s.next().await {
3158 out.push(item?);
3159 }
3160 Ok(out)
3161 }
3162
    /// Byte-level layout check of an S4E6 frame for a 10 MiB PUT at 1 MiB
    /// chunks: magic, algo, key_id, reserved, chunk_size, chunk_count, salt,
    /// then ten tagged chunks.
    #[test]
    fn s4e6_encrypt_layout_10mb_at_1mib() {
        let kr = keyring_single(0x42);
        let chunk_size = 1024 * 1024;
        let pt_len = 10 * 1024 * 1024;
        let pt = vec![0xAB_u8; pt_len];
        let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).expect("encrypt ok");
        // Header fields, in frame order.
        assert_eq!(&ct[..4], SSE_MAGIC_V6, "new PUTs emit S4E6 (v0.8.1 #57)");
        assert_eq!(ct[4], ALGO_AES_256_GCM);
        assert_eq!(u16::from_be_bytes([ct[5], ct[6]]), 1, "key_id BE = active id");
        assert_eq!(ct[7], 0, "reserved must be 0");
        assert_eq!(
            u32::from_be_bytes([ct[8], ct[9], ct[10], ct[11]]),
            chunk_size as u32,
            "chunk_size BE",
        );
        assert_eq!(
            u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
            10,
            "chunk_count BE — 10 MiB / 1 MiB = 10 (no remainder)",
        );
        // Salt occupies bytes 16..24 and must be populated from the RNG.
        assert_eq!(&ct[16..24].len(), &8, "S4E6 salt slot is 8 bytes");
        assert_ne!(&ct[16..24], &[0u8; 8], "S4E6 salt must be random, not zeros");
        assert_eq!(
            ct.len(),
            S4E6_HEADER_BYTES + 10 * S4E6_PER_CHUNK_OVERHEAD + pt_len,
            "total = header (24) + 10 tags + plaintext",
        );
        assert!(looks_encrypted(&ct), "looks_encrypted must accept S4E6");
        assert_eq!(peek_magic(&ct), Some("S4E6"));
    }
3201
3202 #[tokio::test]
3203 async fn s4e6_decrypt_chunked_stream_byte_equal() {
3204 let kr = keyring_single(0x55);
3207 let pt: Vec<u8> = (0..(10 * 1024 * 1024_u32)).map(|i| (i & 0xFF) as u8).collect();
3208 let ct = encrypt_v2_chunked(&pt, &kr, 1024 * 1024).unwrap();
3209 assert_eq!(&ct[..4], SSE_MAGIC_V6, "new emit is S4E6");
3211 let stream = decrypt_chunked_stream(ct, &kr);
3212 let chunks = collect_chunks(stream).await.expect("stream ok");
3213 assert_eq!(chunks.len(), 10, "10 chunks expected for 10 MiB / 1 MiB");
3214 let mut joined = Vec::with_capacity(pt.len());
3215 for c in chunks {
3216 joined.extend_from_slice(&c);
3217 }
3218 assert_eq!(joined.len(), pt.len(), "byte length matches");
3219 assert_eq!(joined, pt, "byte-equal round-trip");
3220 }
3221
3222 #[tokio::test]
3223 async fn s4e6_single_chunk_for_small_object() {
3224 let kr = keyring_single(0x77);
3228 let pt = b"tiny payload, smaller than chunk_size";
3229 let ct = encrypt_v2_chunked(pt, &kr, 1024 * 1024).unwrap();
3230 assert_eq!(
3231 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
3232 1,
3233 "small plaintext = single chunk",
3234 );
3235 let stream = decrypt_chunked_stream(ct, &kr);
3236 let chunks = collect_chunks(stream).await.expect("stream ok");
3237 assert_eq!(chunks.len(), 1);
3238 assert_eq!(chunks[0].as_ref(), pt);
3239 }
3240
    /// Corrupting a byte inside chunk 3's ciphertext must let chunks 0..=2
    /// stream out cleanly and then fail with
    /// `ChunkAuthFailed { chunk_index: 3 }`.
    #[tokio::test]
    async fn s4e6_tampered_chunk_n_reports_chunk_index() {
        let kr = keyring_single(0x91);
        let chunk_size = 1024;
        let pt = vec![0xCD_u8; chunk_size * 8]; let mut ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap().to_vec();
        // Offset of a byte inside chunk 3: header + 3 full framed chunks,
        // plus one more TAG_LEN (per-chunk framing assumed tag-adjacent from
        // this arithmetic — confirm against the S4E6 writer).
        let target = S4E6_HEADER_BYTES + 3 * (TAG_LEN + chunk_size) + TAG_LEN;
        ct[target] ^= 0x42;
        let stream = decrypt_chunked_stream(bytes::Bytes::from(ct), &kr);
        let mut s = std::pin::pin!(stream);
        // The three untouched chunks must decrypt cleanly first.
        for expected_i in 0..3_u32 {
            let item = s.next().await.expect("yield");
            item.unwrap_or_else(|e| panic!("chunk {expected_i}: {e:?}"));
        }
        let err = s.next().await.expect("yield error").unwrap_err();
        assert!(
            matches!(err, SseError::ChunkAuthFailed { chunk_index: 3 }),
            "got {err:?}",
        );
    }
3269
3270 #[tokio::test]
3271 async fn s4e5_back_compat_s4e2_blob_rejected_with_clear_error() {
3272 let kr = keyring_single(0x12);
3276 let s4e2 = encrypt_v2(b"a v2 blob, not chunked", &kr);
3277 let stream = decrypt_chunked_stream(s4e2, &kr);
3278 let result = collect_chunks(stream).await;
3279 let err = result.unwrap_err();
3280 assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");
3281 }
3282
3283 #[test]
3284 fn s4e6_salt_randomness_smoke() {
3285 let kr = keyring_single(0x33);
3292 let mut salts = std::collections::HashSet::new();
3293 let n = 1024;
3294 for _ in 0..n {
3295 let ct = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
3296 let mut salt = [0u8; 8];
3297 salt.copy_from_slice(&ct[16..24]);
3298 salts.insert(salt);
3299 }
3300 assert!(
3301 salts.len() > n / 2,
3302 "expected most of the {n} salts to be unique (got {} unique)",
3303 salts.len(),
3304 );
3305 }
3306
3307 #[test]
3308 fn s4e6_chunk_size_zero_invalid() {
3309 let kr = keyring_single(0x66);
3310 let err = encrypt_v2_chunked(b"hi", &kr, 0).unwrap_err();
3311 assert!(matches!(err, SseError::ChunkSizeInvalid));
3312 }
3313
    /// Truncating the body partway through the third chunk must surface
    /// `ChunkFrameTruncated` from the stream — not a panic or a short read.
    #[tokio::test]
    async fn s4e6_truncated_body_reports_frame_truncated() {
        let kr = keyring_single(0xA1);
        let chunk_size = 256;
        let pt = vec![0u8; chunk_size * 4];
        let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
        // Cut 8 bytes into chunk 2's frame; two full chunks survive intact.
        let trunc = S4E6_HEADER_BYTES + 2 * (TAG_LEN + chunk_size) + 8;
        let truncated = bytes::Bytes::copy_from_slice(&ct[..trunc]);
        let stream = decrypt_chunked_stream(truncated, &kr);
        let result = collect_chunks(stream).await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, SseError::ChunkFrameTruncated { .. }),
            "got {err:?}",
        );
    }
3334
3335 #[test]
3336 fn s4e6_decrypt_buffered_round_trip_via_top_level_decrypt() {
3337 let kr = keyring_single(0xDE);
3341 let pt = b"buffered sync decrypt path".repeat(32);
3342 let ct = encrypt_v2_chunked(&pt, &kr, 13).unwrap();
3343 let plain = decrypt(&ct, &kr).expect("buffered S4E6 decrypt ok");
3344 assert_eq!(plain.as_ref(), pt.as_slice());
3345 }
3346
3347 #[tokio::test]
3348 async fn s4e6_unknown_key_id_in_frame_errors() {
3349 let kr_put = SseKeyring::new(7, key32(0xCC));
3351 let kr_get = keyring_single(0xCC); let ct = encrypt_v2_chunked(b"orphan key", &kr_put, 64).unwrap();
3353 let err = decrypt(&ct, &kr_get).unwrap_err();
3355 assert!(matches!(err, SseError::KeyNotInKeyring { id: 7 }), "got {err:?}");
3356 let stream = decrypt_chunked_stream(ct, &kr_get);
3358 let result = collect_chunks(stream).await;
3359 assert!(
3360 matches!(result, Err(SseError::KeyNotInKeyring { id: 7 })),
3361 "got {result:?}",
3362 );
3363 }
3364
3365 #[tokio::test]
3366 async fn s4e6_final_chunk_smaller_than_chunk_size() {
3367 let kr = keyring_single(0xEF);
3370 let chunk_size = 100;
3371 let pt: Vec<u8> = (0..250_u32).map(|i| i as u8).collect();
3372 let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
3373 assert_eq!(
3374 u32::from_be_bytes([ct[12], ct[13], ct[14], ct[15]]),
3375 3,
3376 "ceil(250/100) = 3 chunks",
3377 );
3378 assert_eq!(ct.len(), S4E6_HEADER_BYTES + 48 + 250);
3380 let stream = decrypt_chunked_stream(ct, &kr);
3381 let chunks = collect_chunks(stream).await.expect("stream ok");
3382 assert_eq!(chunks.len(), 3);
3383 assert_eq!(chunks[0].len(), 100);
3384 assert_eq!(chunks[1].len(), 100);
3385 assert_eq!(chunks[2].len(), 50, "final chunk is the remainder");
3386 let joined: Vec<u8> = chunks.iter().flat_map(|c| c.iter().copied()).collect();
3387 assert_eq!(joined, pt);
3388 }
3389
    /// S4E5 blobs written by v0.8.0 must still decrypt on both the sync
    /// (buffered) path and the streaming path; the fixture comes from a
    /// test-only S4E5 encoder.
    #[test]
    fn s4e6_back_compat_read_s4e5_blob() {
        let kr = keyring_single(0x57);
        let pt = b"v0.8.0 vintage chunked SSE-S4 object".repeat(64);
        // 91-byte chunks: deliberately odd-sized so the tail chunk is ragged.
        let s4e5 = encrypt_v2_chunked_s4e5_for_test(&pt, &kr, 91).unwrap();
        assert_eq!(&s4e5[..4], SSE_MAGIC_V5, "fixture must be S4E5");
        assert_eq!(peek_magic(&s4e5), Some("S4E5"));
        let plain_sync = decrypt(&s4e5, &kr).expect("sync S4E5 decrypt ok");
        assert_eq!(plain_sync.as_ref(), pt.as_slice());
        // Streaming path, driven to completion on the current thread.
        let collected = futures::executor::block_on(async {
            let stream = decrypt_chunked_stream(s4e5.clone(), &kr);
            collect_chunks(stream).await
        })
        .expect("stream S4E5 decrypt ok");
        let mut joined = Vec::with_capacity(pt.len());
        for c in collected {
            joined.extend_from_slice(&c);
        }
        assert_eq!(joined, pt, "S4E5 streaming round-trip byte-equal");
    }
3426
3427 #[test]
3428 fn s4e6_layout_24_bytes_header() {
3429 assert_eq!(S4E6_HEADER_BYTES, 24);
3433 assert_eq!(S4E6_PER_CHUNK_OVERHEAD, TAG_LEN);
3434 assert_eq!(S4E6_HEADER_BYTES, S4E5_HEADER_BYTES + 4);
3435 }
3436
    /// `parse_s4e6_header` must reflect back the fields a real PUT wrote,
    /// reject a wrong magic, and report truncation on a 10-byte prefix.
    #[test]
    fn s4e6_parse_header_round_trip() {
        let kr = keyring_single(0xAB);
        let chunk_size = 256;
        let pt = vec![1u8; 7 * chunk_size];
        let ct = encrypt_v2_chunked(&pt, &kr, chunk_size).unwrap();
        let hdr = parse_s4e6_header(&ct).expect("parse ok");
        assert_eq!(hdr.key_id, 1);
        assert_eq!(hdr.chunk_size, chunk_size as u32);
        assert_eq!(hdr.chunk_count, 7);
        assert_eq!(hdr.salt.len(), 8);
        // Long enough to be a header, but carries the S4E2 magic.
        let bogus = b"S4E2\x01\x00\x00\x00........................";
        let err = parse_s4e6_header(bogus).unwrap_err();
        assert!(matches!(err, SseError::BadMagic { .. }), "got {err:?}");
        // Ten bytes cannot hold the 24-byte S4E6 header.
        let err2 = parse_s4e6_header(&ct[..10]).unwrap_err();
        assert!(matches!(err2, SseError::ChunkFrameTruncated { .. }), "got {err2:?}");
    }
3459
    /// 16k PUTs must yield 16k distinct 8-byte salts. Alongside, count how
    /// many collisions a hypothetical 4-byte salt (the S4E5 width) would
    /// have produced — reported in the log line only, not asserted.
    #[test]
    fn s4e6_salt_uniqueness_smoke_16m() {
        let kr = keyring_single(0xA6);
        let mut salts = std::collections::HashSet::with_capacity(16384);
        let n = 16384_usize;
        let mut collisions_top4 = 0usize;
        let mut top4_seen = std::collections::HashSet::with_capacity(16384);
        for _ in 0..n {
            let ct = encrypt_v2_chunked(b"x", &kr, 64).unwrap();
            let mut salt = [0u8; 8];
            salt.copy_from_slice(&ct[16..24]);
            salts.insert(salt);
            // Simulate the narrower 4-byte salt that S4E5 would have used.
            let mut top4 = [0u8; 4];
            top4.copy_from_slice(&salt[..4]);
            if !top4_seen.insert(top4) {
                collisions_top4 += 1;
            }
        }
        assert_eq!(
            salts.len(),
            n,
            "all 8-byte salts must be unique across {n} PUTs (got {} unique)",
            salts.len(),
        );
        eprintln!(
            "s4e6_salt_uniqueness_smoke_16m: 16k PUTs, full 8B salts \
            all unique ({}/{}), simulated 4B-truncated salt yielded \
            {} collisions (this is what S4E5 would have shipped)",
            salts.len(),
            n,
            collisions_top4,
        );
    }
3527
    /// The 24-bit chunk-count cap: constant sanity, PUT-side rejection one
    /// past the cap, an under-cap PUT succeeding, and GET-side rejection of
    /// a tampered header claiming cap + 1.
    #[test]
    fn s4e6_max_chunks_24bit() {
        assert_eq!(S4E6_MAX_CHUNK_COUNT, (1u32 << 24) - 1);
        assert_eq!(S4E6_MAX_CHUNK_COUNT, 16_777_215);

        let kr = keyring_single(0xC4);
        // chunk_size = 1 makes chunk_count == plaintext length: cap + 1.
        let pt = vec![0u8; (S4E6_MAX_CHUNK_COUNT as usize) + 1]; let err = encrypt_v2_chunked(&pt, &kr, 1).unwrap_err();
        assert!(
            matches!(
                err,
                SseError::ChunkCountTooLarge {
                    got: 16_777_216,
                    max: 16_777_215
                }
            ),
            "got {err:?}",
        );

        // Well under the cap: must succeed and round-trip the count.
        let pt_ok = vec![0u8; 1023];
        let ct = encrypt_v2_chunked(&pt_ok, &kr, 1).expect("under-cap PUT must succeed");
        let hdr = parse_s4e6_header(&ct).unwrap();
        assert_eq!(hdr.chunk_count, 1023);

        // Tamper the chunk_count field (bytes 12..16, BE) up to cap + 1;
        // the parser must reject before any decryption work.
        let mut tampered = ct.to_vec();
        let bad = (S4E6_MAX_CHUNK_COUNT + 1).to_be_bytes();
        tampered[12..16].copy_from_slice(&bad);
        let err2 = parse_s4e6_header(&tampered).unwrap_err();
        assert!(
            matches!(
                err2,
                SseError::ChunkCountTooLarge { got: 16_777_216, max: 16_777_215 }
            ),
            "got {err2:?}",
        );
    }
3587
3588 #[test]
3589 fn s4e6_nonce_v6_layout() {
3590 let salt = [0xAA_u8; 8];
3594 let n0 = nonce_v6(&salt, 0);
3595 assert_eq!(n0[0], b'E');
3596 assert_eq!(&n0[1..9], &salt);
3597 assert_eq!(&n0[9..12], &[0, 0, 0]);
3598 let n1 = nonce_v6(&salt, 1);
3599 assert_eq!(&n1[9..12], &[0, 0, 1]);
3600 let n_mid = nonce_v6(&salt, 0x123456);
3601 assert_eq!(&n_mid[9..12], &[0x12, 0x34, 0x56]);
3602 let n_max = nonce_v6(&salt, S4E6_MAX_CHUNK_COUNT);
3603 assert_eq!(&n_max[9..12], &[0xFF, 0xFF, 0xFF]);
3604 }
3605
3606 #[tokio::test]
3607 async fn s4e6_tampered_salt_byte_fails_aead() {
3608 let kr = keyring_single(0xB6);
3613 let pt = b"salt-in-aad coverage".repeat(64);
3614 let mut ct = encrypt_v2_chunked(&pt, &kr, 128).unwrap().to_vec();
3615 ct[20] ^= 0x01;
3617 let err = decrypt(&ct, &kr).unwrap_err();
3618 assert!(
3619 matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
3620 "got {err:?}",
3621 );
3622 }
3623
3624 fn synth_s4e6_header(chunk_size: u32, chunk_count: u32) -> Vec<u8> {
3639 let mut blob = Vec::with_capacity(S4E6_HEADER_BYTES);
3640 blob.extend_from_slice(SSE_MAGIC_V6);
3641 blob.push(ALGO_AES_256_GCM);
3642 blob.extend_from_slice(&1_u16.to_be_bytes()); blob.push(0); blob.extend_from_slice(&chunk_size.to_be_bytes());
3645 blob.extend_from_slice(&chunk_count.to_be_bytes());
3646 blob.extend_from_slice(&[0u8; 8]); debug_assert_eq!(blob.len(), S4E6_HEADER_BYTES);
3648 blob
3649 }
3650
3651 #[test]
3652 fn s4e6_header_claims_huge_size_rejected_pre_alloc() {
3653 let kr = keyring_single(0x01);
3659 let chunk_size: u32 = 1 << 30; let chunk_count: u32 = 100;
3661 let mut blob = synth_s4e6_header(chunk_size, chunk_count);
3662 blob.extend_from_slice(&[0u8; 100]);
3665 let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
3666 assert!(
3667 matches!(err, SseError::ChunkFrameTooLarge { .. }),
3668 "expected ChunkFrameTooLarge (declared 100 GiB > 5 GiB cap), got {err:?}",
3669 );
3670 let err2 = decrypt_chunked_buffered(&blob, &kr, 1024 * 1024).unwrap_err();
3673 assert!(
3674 matches!(err2, SseError::ChunkFrameTooLarge { .. }),
3675 "expected ChunkFrameTooLarge under tighter cap, got {err2:?}",
3676 );
3677 }
3678
    /// An S4E5 header declaring u32::MAX chunks of u32::MAX bytes overflows
    /// a naive `chunk_size * chunk_count`; both the buffered entry point and
    /// the header parser must reject it with checked arithmetic.
    #[test]
    fn s4e6_header_chunk_size_x_chunk_count_overflows_u64() {
        let kr = keyring_single(0x02);
        // Hand-built S4E5 header: magic, algo, key_id(2 BE), reserved,
        // chunk_size, chunk_count, then the 4-byte S4E5 salt.
        let mut blob = Vec::with_capacity(S4E5_HEADER_BYTES);
        blob.extend_from_slice(SSE_MAGIC_V5);
        blob.push(ALGO_AES_256_GCM);
        blob.extend_from_slice(&1_u16.to_be_bytes());
        blob.push(0);
        blob.extend_from_slice(&u32::MAX.to_be_bytes()); // chunk_size = u32::MAX
        blob.extend_from_slice(&u32::MAX.to_be_bytes()); // chunk_count = u32::MAX
        blob.extend_from_slice(&[0u8; 4]); // 4-byte S4E5 salt
        debug_assert_eq!(blob.len(), S4E5_HEADER_BYTES);
        let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
        assert!(
            matches!(err, SseError::ChunkFrameTooLarge { .. }),
            "expected ChunkFrameTooLarge for u64 overflow, got {err:?}",
        );

        // Even with an effectively unlimited cap, the multiplication itself
        // must be overflow-checked on the streaming parse path.
        let direct = parse_chunked_header(&blob, usize::MAX).unwrap_err();
        assert!(
            matches!(direct, SseError::ChunkFrameTooLarge { .. }),
            "streaming path: expected ChunkFrameTooLarge, got {direct:?}",
        );
    }
3719
3720 #[test]
3721 fn s4e6_header_within_max_body_bytes_passes() {
3722 let kr = keyring_single(0x03);
3729 let chunk_size: u32 = 1024 * 1024; let chunk_count: u32 = 100;
3731 let mut blob = synth_s4e6_header(chunk_size, chunk_count);
3732 let chunk_array_size =
3737 (chunk_count as usize) * (S4E6_PER_CHUNK_OVERHEAD + chunk_size as usize);
3738 blob.resize(blob.len() + chunk_array_size, 0);
3739 let err =
3740 decrypt_chunked_buffered(&blob, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
3741 assert!(
3749 matches!(err, SseError::ChunkAuthFailed { chunk_index: 0 }),
3750 "expected ChunkAuthFailed (guard let it through), got {err:?}",
3751 );
3752 }
3753
3754 #[test]
3755 fn s4e6_header_exceeds_max_body_bytes_rejected() {
3756 let kr = keyring_single(0x04);
3763 let chunk_size: u32 = 1024 * 1024; let chunk_count: u32 = 6000;
3765 let blob = synth_s4e6_header(chunk_size, chunk_count);
3766 let err = decrypt_chunked_buffered(&blob, &kr, DEFAULT_MAX_BODY_BYTES).unwrap_err();
3770 assert!(
3771 matches!(err, SseError::ChunkFrameTooLarge { .. }),
3772 "expected ChunkFrameTooLarge (6 GiB declared > 5 GiB cap), got {err:?}",
3773 );
3774
3775 let chunk_size_b: u32 = 1024 * 1024; let chunk_count_b: u32 = 100;
3780 let mut blob_b = synth_s4e6_header(chunk_size_b, chunk_count_b);
3781 let pad_b =
3782 (chunk_count_b as usize) * (S4E6_PER_CHUNK_OVERHEAD + chunk_size_b as usize);
3783 blob_b.resize(blob_b.len() + pad_b, 0);
3784 let err_b = decrypt_chunked_buffered(&blob_b, &kr, 1024 * 1024).unwrap_err();
3786 assert!(
3787 matches!(err_b, SseError::ChunkFrameTooLarge { .. }),
3788 "expected ChunkFrameTooLarge (cap < declared), got {err_b:?}",
3789 );
3790 }
3791
    /// Fuzz-ish smoke: `parse_chunked_header` must never panic on random
    /// bodies (sometimes stamped with a real S4E5/S4E6 magic) under a
    /// cycling set of size caps. Only the absence of panics is asserted.
    #[test]
    fn s4e6_random_header_never_panics() {
        use rand::{Rng, SeedableRng, rngs::StdRng};
        // Fixed seed keeps failures reproducible.
        let mut rng = StdRng::seed_from_u64(0xC0FF_EE64_6464_64DE);
        let mut max_body_bytes_choices = [
            0_usize,
            1024,
            1024 * 1024,
            DEFAULT_MAX_BODY_BYTES,
            usize::MAX,
        ]
        .iter()
        .copied()
        .cycle();
        for _ in 0..100_000 {
            let body_len = rng.gen_range(0..=256_usize);
            let mut body = vec![0u8; body_len];
            rng.fill(body.as_mut_slice());
            // A quarter of the time, install a valid magic so the parser
            // gets past the first gate and exercises the field checks.
            if body_len >= 4 && rng.gen_bool(0.25) {
                if rng.gen_bool(0.5) {
                    body[..4].copy_from_slice(SSE_MAGIC_V5);
                } else {
                    body[..4].copy_from_slice(SSE_MAGIC_V6);
                }
            }
            let max_cap = max_body_bytes_choices.next().unwrap();
            // Result ignored: any Ok/Err is fine; panicking is not.
            let _ = parse_chunked_header(&body, max_cap);
        }
    }
3842
    /// A hand-built S4E5 header with both chunk_size and chunk_count at
    /// u32::MAX must be rejected as too large through the default-cap
    /// buffered entry point.
    #[test]
    fn s4e5_extreme_overflow_chunk_count_u32_max() {
        let kr = keyring_single(0x05);
        // S4E5 header: magic, algo, key_id(2 BE), reserved, chunk_size,
        // chunk_count, 4-byte salt.
        let mut blob = Vec::with_capacity(S4E5_HEADER_BYTES);
        blob.extend_from_slice(SSE_MAGIC_V5);
        blob.push(ALGO_AES_256_GCM);
        blob.extend_from_slice(&1_u16.to_be_bytes());
        blob.push(0);
        blob.extend_from_slice(&u32::MAX.to_be_bytes());
        blob.extend_from_slice(&u32::MAX.to_be_bytes());
        blob.extend_from_slice(&[0u8; 4]);
        let err = decrypt_chunked_buffered_default(&blob, &kr).unwrap_err();
        assert!(
            matches!(err, SseError::ChunkFrameTooLarge { .. }),
            "expected ChunkFrameTooLarge for extreme overflow, got {err:?}",
        );
    }
3869}