Skip to main content

s2_common/record/
encryption.rs

1//! Encrypted record storage, wire format, and raw cryptography.
2//!
3//! ```text
4//! [format_id: 1 byte] [nonce] [ciphertext] [tag]
5//! ```
6//!
7//! | format_id | Format         | Nonce  | Tag  |
8//! |-----------|----------------|--------|------|
9//! | 0x01      | AEGIS-256 v1   | 32 B   | 16 B |
10//! | 0x02      | AES-256-GCM v1 | 12 B   | 16 B |
11//!
12//! The leading format byte identifies the full encrypted record framing,
13//! including the framing version and encryption algorithm. This leaves room for
14//! future layout changes without a separate version byte.
15//!
16//! AAD is caller-supplied associated data and is not stored in the encoded
17//! record.
18//!
19//! Plaintext records are stored as `StoredRecord::Plaintext(Record)` and use
20//! the same command/envelope framing as the logical record layer.
21//!
22//! Encrypted envelope records are stored as `StoredRecord::Encrypted`. Their
23//! outer record type is `RecordType::EncryptedEnvelope`, and the encoded body is
24//! an [`EncryptedRecord`] containing encrypted bytes for the byte-for-byte
25//! plaintext [`EnvelopeRecord`](super::EnvelopeRecord) encoding.
26//!
27//! The stored `metered_size` remains the logical plaintext metered size rather
28//! than the encoded encrypted record size, so protection does not change
29//! append/read metering, limits, or accounting.
30
31use aegis::aegis256::Aegis256;
32use aes_gcm::{Aes256Gcm, KeyInit, aead::AeadInPlace};
33use bytes::{BufMut, Bytes, BytesMut};
34use rand::random;
35
36use super::{Encodable, Metered, MeteredSize, Record, RecordDecodeError, SeqNum, StoredRecord};
37use crate::{
38    deep_size::DeepSize,
39    encryption::{EncryptionAlgorithm, EncryptionSpec},
40    record::MeteredExt as _,
41};
42
/// Number of bytes occupied by the leading format identifier of an encoded record.
const FORMAT_ID_LEN: usize = 1;

/// Format identifier byte for AEGIS-256 framing, version 1.
const FORMAT_ID_AEGIS256_V1: u8 = 0x01;
/// Format identifier byte for AES-256-GCM framing, version 1.
const FORMAT_ID_AES256GCM_V1: u8 = 0x02;
47
/// Framing used by an encrypted record: one variant per (layout version, algorithm) pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum EncryptedRecordFormat {
    /// AEGIS-256 framing, version 1.
    Aegis256V1,
    /// AES-256-GCM framing, version 1.
    Aes256GcmV1,
}
53
54impl EncryptedRecordFormat {
55    const fn try_from_format_id(format_id: u8) -> Result<Self, RecordDecodeError> {
56        match format_id {
57            FORMAT_ID_AEGIS256_V1 => Ok(Self::Aegis256V1),
58            FORMAT_ID_AES256GCM_V1 => Ok(Self::Aes256GcmV1),
59            _ => Err(RecordDecodeError::InvalidValue(
60                "EncryptedRecord",
61                "invalid encrypted record format id",
62            )),
63        }
64    }
65
66    const fn format_id(self) -> u8 {
67        match self {
68            Self::Aegis256V1 => FORMAT_ID_AEGIS256_V1,
69            Self::Aes256GcmV1 => FORMAT_ID_AES256GCM_V1,
70        }
71    }
72
73    const fn algorithm(self) -> EncryptionAlgorithm {
74        match self {
75            Self::Aegis256V1 => EncryptionAlgorithm::Aegis256,
76            Self::Aes256GcmV1 => EncryptionAlgorithm::Aes256Gcm,
77        }
78    }
79
80    const fn nonce_len(self) -> usize {
81        match self {
82            Self::Aegis256V1 => 32,
83            Self::Aes256GcmV1 => 12,
84        }
85    }
86
87    const fn tag_len(self) -> usize {
88        match self {
89            Self::Aegis256V1 => 16,
90            Self::Aes256GcmV1 => 16,
91        }
92    }
93
94    fn put_random_nonce(self, buf: &mut impl BufMut) {
95        match self {
96            Self::Aegis256V1 => buf.put_slice(&random::<[u8; 32]>()),
97            Self::Aes256GcmV1 => buf.put_slice(&random::<[u8; 12]>()),
98        }
99    }
100
101    const fn max_assignable_seq_num(self) -> SeqNum {
102        match self {
103            Self::Aegis256V1 => SeqNum::MAX,
104            Self::Aes256GcmV1 => (1u64 << 32) - 1,
105        }
106    }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
110pub enum RecordDecryptionError {
111    #[error("record encryption algorithm mismatch")]
112    AlgorithmMismatch {
113        expected: Option<EncryptionAlgorithm>,
114        actual: Option<EncryptionAlgorithm>,
115    },
116    #[error("record decryption failed")]
117    AuthenticationFailed,
118    #[error("malformed encrypted record")]
119    MalformedEncryptedRecord,
120    #[error("decrypted record metered size mismatch: stored {stored}, actual {actual}")]
121    MeteredSizeMismatch { stored: usize, actual: usize },
122    #[error("malformed decrypted record: {0}")]
123    MalformedDecryptedRecord(#[from] RecordDecodeError),
124}
125
126#[derive(PartialEq, Eq, Clone)]
127pub struct EncryptedRecord {
128    encoded: Bytes,
129    format: EncryptedRecordFormat,
130}
131
132impl EncryptedRecord {
133    fn new(encoded: Bytes, format: EncryptedRecordFormat) -> Self {
134        debug_assert!(!encoded.is_empty());
135        debug_assert_eq!(encoded[0], format.format_id());
136        debug_assert!(encoded.len() >= FORMAT_ID_LEN + format.nonce_len() + format.tag_len());
137        Self { encoded, format }
138    }
139
140    pub fn algorithm(&self) -> EncryptionAlgorithm {
141        self.format.algorithm()
142    }
143
144    pub fn max_assignable_seq_num(&self) -> SeqNum {
145        self.format.max_assignable_seq_num()
146    }
147
148    pub(crate) fn nonce(&self) -> &[u8] {
149        let start = FORMAT_ID_LEN;
150        let end = start + self.format.nonce_len();
151        &self.encoded[start..end]
152    }
153
154    pub(crate) fn ciphertext(&self) -> &[u8] {
155        let start = FORMAT_ID_LEN + self.format.nonce_len();
156        let end = self.encoded.len() - self.format.tag_len();
157        &self.encoded[start..end]
158    }
159
160    pub(crate) fn tag(&self) -> &[u8] {
161        let start = self.encoded.len() - self.format.tag_len();
162        let end = self.encoded.len();
163        &self.encoded[start..end]
164    }
165
166    fn into_mut_encoded(self) -> BytesMut {
167        self.encoded
168            .try_into_mut()
169            .unwrap_or_else(|encoded| BytesMut::from(encoded.as_ref()))
170    }
171}
172
173impl std::fmt::Debug for EncryptedRecord {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("EncryptedRecord")
176            .field("format_id", &self.encoded[0])
177            .field("format", &self.format)
178            .field("algorithm", &self.format.algorithm())
179            .field("nonce.len", &self.nonce().len())
180            .field("ciphertext.len", &self.ciphertext().len())
181            .field("tag.len", &self.tag().len())
182            .finish()
183    }
184}
185
186impl DeepSize for EncryptedRecord {
187    fn deep_size(&self) -> usize {
188        self.encoded.len()
189    }
190}
191
192impl Encodable for EncryptedRecord {
193    fn encoded_size(&self) -> usize {
194        self.encoded.len()
195    }
196
197    fn encode_into(&self, buf: &mut impl BufMut) {
198        buf.put_slice(self.encoded.as_ref());
199    }
200}
201
202pub fn encrypt_record(
203    record: Metered<Record>,
204    encryption: &EncryptionSpec,
205    aad: &[u8],
206) -> Metered<StoredRecord> {
207    let metered_size = record.metered_size();
208    let record = match (record.into_inner(), encryption) {
209        (record @ Record::Command(_), _) => StoredRecord::Plaintext(record),
210        (record @ Record::Envelope(_), EncryptionSpec::Plain) => StoredRecord::Plaintext(record),
211        (Record::Envelope(envelope), EncryptionSpec::Aegis256(key)) => {
212            let format = EncryptedRecordFormat::Aegis256V1;
213            let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
214            let (prefix, payload) = encoded.split_at_mut(payload_start);
215            let nonce: &[u8; 32] = prefix[FORMAT_ID_LEN..]
216                .try_into()
217                .expect("AEGIS-256 nonce must be 32 bytes");
218            let tag =
219                Aegis256::<16>::new(key.expose_secret(), nonce).encrypt_in_place(payload, aad);
220            encoded.put_slice(tag.as_ref());
221
222            let encrypted = EncryptedRecord::new(encoded.freeze(), format);
223            StoredRecord::encrypted(encrypted, metered_size)
224        }
225        (Record::Envelope(envelope), EncryptionSpec::Aes256Gcm(key)) => {
226            let format = EncryptedRecordFormat::Aes256GcmV1;
227            let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
228            let (prefix, payload) = encoded.split_at_mut(payload_start);
229            let nonce = aes_gcm::Nonce::from_slice(&prefix[FORMAT_ID_LEN..]);
230            let tag = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()))
231                .encrypt_in_place_detached(nonce, aad, payload)
232                .expect("AES-256-GCM encryption should not fail on size validation");
233            encoded.put_slice(tag.as_ref());
234
235            let encrypted = EncryptedRecord::new(encoded.freeze(), format);
236            StoredRecord::encrypted(encrypted, metered_size)
237        }
238    };
239    Metered::with_size(metered_size, record)
240}
241
242fn prep_encryption_buffer(
243    envelope: &super::EnvelopeRecord,
244    format: EncryptedRecordFormat,
245) -> (BytesMut, usize) {
246    let payload_start = FORMAT_ID_LEN + format.nonce_len();
247    let mut encoded =
248        BytesMut::with_capacity(payload_start + envelope.encoded_size() + format.tag_len());
249    encoded.put_u8(format.format_id());
250    format.put_random_nonce(&mut encoded);
251    envelope.encode_into(&mut encoded);
252    (encoded, payload_start)
253}
254
255impl TryFrom<Bytes> for EncryptedRecord {
256    type Error = RecordDecodeError;
257
258    fn try_from(encoded: Bytes) -> Result<Self, Self::Error> {
259        if encoded.len() < FORMAT_ID_LEN {
260            return Err(RecordDecodeError::Truncated("EncryptedRecordFormatId"));
261        }
262
263        let format = EncryptedRecordFormat::try_from_format_id(encoded[0])?;
264        let nonce_len = format.nonce_len();
265        let tag_len = format.tag_len();
266        if encoded.len() < FORMAT_ID_LEN + nonce_len + tag_len {
267            return Err(RecordDecodeError::Truncated("EncryptedRecordFrame"));
268        }
269
270        Ok(Self::new(encoded, format))
271    }
272}
273
274pub fn decrypt_stored_record(
275    record: StoredRecord,
276    encryption: &EncryptionSpec,
277    aad: &[u8],
278) -> Result<Metered<Record>, RecordDecryptionError> {
279    match record {
280        StoredRecord::Plaintext(record @ Record::Command(_)) => Ok(record.metered()),
281        StoredRecord::Plaintext(record @ Record::Envelope(_)) => match encryption {
282            EncryptionSpec::Plain => Ok(record.metered()),
283            EncryptionSpec::Aegis256(_) => Err(RecordDecryptionError::AlgorithmMismatch {
284                expected: Some(EncryptionAlgorithm::Aegis256),
285                actual: None,
286            }),
287            EncryptionSpec::Aes256Gcm(_) => Err(RecordDecryptionError::AlgorithmMismatch {
288                expected: Some(EncryptionAlgorithm::Aes256Gcm),
289                actual: None,
290            }),
291        },
292        StoredRecord::Encrypted {
293            metered_size,
294            record: encrypted,
295        } => {
296            let plaintext = decrypt_payload(encrypted, encryption, aad)?;
297            let record = Record::Envelope(plaintext.try_into()?);
298            let actual_metered_size = record.metered_size();
299            if metered_size != actual_metered_size {
300                return Err(RecordDecryptionError::MeteredSizeMismatch {
301                    stored: metered_size,
302                    actual: actual_metered_size,
303                });
304            }
305            Ok(Metered::with_size(metered_size, record))
306        }
307    }
308}
309
310fn decrypt_payload(
311    record: EncryptedRecord,
312    encryption: &EncryptionSpec,
313    aad: &[u8],
314) -> Result<Bytes, RecordDecryptionError> {
315    let format = record.format;
316    let (mut encoded, payload_start, payload_end) = decryption_layout(record, format)?;
317    let plaintext_len = payload_end - payload_start;
318
319    match (format, encryption) {
320        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aegis256(key)) => {
321            let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
322            let nonce: &[u8; 32] = prefix
323                .get(FORMAT_ID_LEN..)
324                .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
325                .try_into()
326                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
327            let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
328            let tag: &[u8; 16] = tag
329                .as_ref()
330                .try_into()
331                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
332            Aegis256::<16>::new(key.expose_secret(), nonce)
333                .decrypt_in_place(ciphertext, tag, aad)
334                .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
335            Ok(decryption_finish(encoded, payload_start, plaintext_len))
336        }
337        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Plain) => {
338            Err(RecordDecryptionError::AlgorithmMismatch {
339                expected: None,
340                actual: Some(EncryptionAlgorithm::Aegis256),
341            })
342        }
343        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aes256Gcm(_)) => {
344            Err(RecordDecryptionError::AlgorithmMismatch {
345                expected: Some(EncryptionAlgorithm::Aes256Gcm),
346                actual: Some(EncryptionAlgorithm::Aegis256),
347            })
348        }
349        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aes256Gcm(key)) => {
350            let cipher = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()));
351            let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
352            let nonce: &[u8; 12] = prefix
353                .get(FORMAT_ID_LEN..)
354                .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
355                .try_into()
356                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
357            let nonce = aes_gcm::Nonce::from_slice(nonce);
358            let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
359            let tag: &[u8; 16] = tag
360                .as_ref()
361                .try_into()
362                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
363            let tag = aes_gcm::Tag::from_slice(tag);
364            cipher
365                .decrypt_in_place_detached(nonce, aad, ciphertext, tag)
366                .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
367            Ok(decryption_finish(encoded, payload_start, plaintext_len))
368        }
369        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Plain) => {
370            Err(RecordDecryptionError::AlgorithmMismatch {
371                expected: None,
372                actual: Some(EncryptionAlgorithm::Aes256Gcm),
373            })
374        }
375        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aegis256(_)) => {
376            Err(RecordDecryptionError::AlgorithmMismatch {
377                expected: Some(EncryptionAlgorithm::Aegis256),
378                actual: Some(EncryptionAlgorithm::Aes256Gcm),
379            })
380        }
381    }
382}
383
384fn decryption_layout(
385    record: EncryptedRecord,
386    format: EncryptedRecordFormat,
387) -> Result<(BytesMut, usize, usize), RecordDecryptionError> {
388    let payload_start = FORMAT_ID_LEN + format.nonce_len();
389    let payload_end = record
390        .encoded
391        .len()
392        .checked_sub(format.tag_len())
393        .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?;
394    if payload_start > payload_end {
395        return Err(RecordDecryptionError::MalformedEncryptedRecord);
396    }
397    Ok((record.into_mut_encoded(), payload_start, payload_end))
398}
399
400fn decryption_finish(mut encoded: BytesMut, payload_start: usize, plaintext_len: usize) -> Bytes {
401    let _ = encoded.split_to(payload_start);
402    encoded.truncate(plaintext_len);
403    encoded.freeze()
404}
405
406#[cfg(test)]
407mod tests {
408    use bytes::Bytes;
409    use rstest::rstest;
410
411    use super::*;
412    use crate::record::{CommandRecord, EnvelopeRecord, Header, MeteredExt};
413
414    const TEST_KEY: [u8; 32] = [0x42; 32];
415    const OTHER_TEST_KEY: [u8; 32] = [0x99; 32];
416
417    fn test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
418        match alg {
419            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(TEST_KEY),
420            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(TEST_KEY),
421        }
422    }
423
424    fn other_test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
425        match alg {
426            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(OTHER_TEST_KEY),
427            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(OTHER_TEST_KEY),
428        }
429    }
430
431    fn encrypt_test_record(
432        plaintext: EnvelopeRecord,
433        alg: EncryptionAlgorithm,
434        aad: &[u8],
435    ) -> EncryptedRecord {
436        let stored = encrypt_record(
437            Record::Envelope(plaintext).metered(),
438            &test_encryption(alg),
439            aad,
440        )
441        .into_inner();
442        let StoredRecord::Encrypted { record, .. } = stored else {
443            panic!("expected encrypted envelope record");
444        };
445        record
446    }
447
448    fn make_encrypted_record(
449        format: EncryptedRecordFormat,
450        nonce: impl AsRef<[u8]>,
451        ciphertext: impl AsRef<[u8]>,
452        tag: impl AsRef<[u8]>,
453    ) -> EncryptedRecord {
454        let nonce = nonce.as_ref();
455        let ciphertext = ciphertext.as_ref();
456        let tag = tag.as_ref();
457
458        assert_eq!(nonce.len(), format.nonce_len());
459        assert_eq!(tag.len(), format.tag_len());
460
461        let mut encoded =
462            BytesMut::with_capacity(FORMAT_ID_LEN + nonce.len() + ciphertext.len() + tag.len());
463        encoded.put_u8(format.format_id());
464        encoded.put_slice(nonce);
465        encoded.put_slice(ciphertext);
466        encoded.put_slice(tag);
467
468        EncryptedRecord::new(encoded.freeze(), format)
469    }
470
471    fn aad() -> [u8; 32] {
472        [0xA5; 32]
473    }
474
475    fn make_envelope(headers: Vec<Header>, body: Bytes) -> EnvelopeRecord {
476        EnvelopeRecord::try_from_parts(headers, body).unwrap()
477    }
478
479    fn make_plaintext_envelope(headers: Vec<Header>, body: Bytes) -> Record {
480        Record::Envelope(make_envelope(headers, body))
481    }
482
483    fn make_encrypted_stored_record(
484        encryption: &EncryptionSpec,
485        headers: Vec<Header>,
486        body: Bytes,
487        aad: &[u8],
488    ) -> StoredRecord {
489        let stored = encrypt_record(
490            make_plaintext_envelope(headers, body).metered(),
491            encryption,
492            aad,
493        )
494        .into_inner();
495        let StoredRecord::Encrypted { .. } = &stored else {
496            panic!("plain encryption should not produce an encrypted record");
497        };
498        stored
499    }
500
501    #[rstest]
502    #[case::aegis_unique(EncryptionAlgorithm::Aegis256, false)]
503    #[case::aegis_shared(EncryptionAlgorithm::Aegis256, true)]
504    #[case::aes_unique(EncryptionAlgorithm::Aes256Gcm, false)]
505    #[case::aes_shared(EncryptionAlgorithm::Aes256Gcm, true)]
506    fn encrypted_payload_roundtrips(
507        #[case] algorithm: EncryptionAlgorithm,
508        #[case] shared_encoded_record_buffer: bool,
509    ) {
510        let headers = vec![Header {
511            name: Bytes::from_static(b"x-test"),
512            value: Bytes::from_static(b"hello"),
513        }];
514        let body = Bytes::from_static(b"secret payload");
515
516        let aad = aad();
517        let plaintext = make_envelope(headers.clone(), body.clone());
518        let encryption = test_encryption(algorithm);
519        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
520        let encrypted_record = if shared_encoded_record_buffer {
521            let shared = encrypted_record.encoded.clone();
522            EncryptedRecord::try_from(shared).unwrap()
523        } else {
524            encrypted_record
525        };
526        let decrypted = decrypt_payload(encrypted_record, &encryption, &aad).unwrap();
527        let (out_headers, out_body) = EnvelopeRecord::try_from(decrypted).unwrap().into_parts();
528
529        assert_eq!(out_headers, headers);
530        assert_eq!(out_body, body);
531    }
532
533    #[rstest]
534    #[case(EncryptionAlgorithm::Aegis256)]
535    #[case(EncryptionAlgorithm::Aes256Gcm)]
536    fn wrong_key_fails(#[case] algorithm: EncryptionAlgorithm) {
537        let aad = aad();
538        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
539        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
540        let result = decrypt_payload(encrypted_record, &other_test_encryption(algorithm), &aad);
541        assert!(matches!(
542            result,
543            Err(RecordDecryptionError::AuthenticationFailed)
544        ));
545    }
546
547    #[test]
548    fn empty_body_fails() {
549        let result = EncryptedRecord::try_from(Bytes::new());
550        assert!(matches!(
551            result,
552            Err(RecordDecodeError::Truncated("EncryptedRecordFormatId"))
553        ));
554    }
555
556    #[test]
557    fn format_id_byte_present() {
558        let aad = aad();
559        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
560        let encrypted_record = encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad);
561        let encoded = encrypted_record.to_bytes();
562        assert_eq!(encrypted_record.format, EncryptedRecordFormat::Aegis256V1);
563        assert_eq!(encrypted_record.algorithm(), EncryptionAlgorithm::Aegis256);
564        assert_eq!(encoded[0], 0x01);
565    }
566
567    #[test]
568    fn format_id_flip_detected() {
569        let aad = aad();
570        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
571        let mut encoded_record =
572            encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad)
573                .to_bytes()
574                .to_vec();
575        assert_eq!(encoded_record[0], 0x01);
576        encoded_record[0] = 0x02;
577        let encrypted_record = EncryptedRecord::try_from(Bytes::from(encoded_record)).unwrap();
578        let result = decrypt_payload(
579            encrypted_record,
580            &test_encryption(EncryptionAlgorithm::Aegis256),
581            &aad,
582        );
583        assert!(matches!(
584            result,
585            Err(RecordDecryptionError::AlgorithmMismatch {
586                expected: Some(EncryptionAlgorithm::Aegis256),
587                actual: Some(EncryptionAlgorithm::Aes256Gcm),
588            })
589        ));
590    }
591
592    #[test]
593    fn wrong_aad_fails() {
594        let aad = aad();
595        let other_aad = [0x5A; 32];
596        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
597        let encrypted_record = encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad);
598        let result = decrypt_payload(
599            encrypted_record,
600            &test_encryption(EncryptionAlgorithm::Aegis256),
601            &other_aad,
602        );
603        assert!(matches!(
604            result,
605            Err(RecordDecryptionError::AuthenticationFailed)
606        ));
607    }
608
609    #[test]
610    fn malformed_encrypted_record_layout_returns_error_instead_of_panicking() {
611        let aad = aad();
612        let record = EncryptedRecord {
613            encoded: Bytes::from_static(b"\x01short"),
614            format: EncryptedRecordFormat::Aegis256V1,
615        };
616
617        let result = decrypt_payload(
618            record,
619            &test_encryption(EncryptionAlgorithm::Aegis256),
620            &aad,
621        );
622
623        assert!(matches!(
624            result,
625            Err(RecordDecryptionError::MalformedEncryptedRecord)
626        ));
627    }
628
629    #[test]
630    fn encrypted_record_roundtrips_aes256gcm() {
631        let record = make_encrypted_record(
632            EncryptedRecordFormat::Aes256GcmV1,
633            Bytes::from_static(b"0123456789ab"),
634            Bytes::from_static(b"ciphertext"),
635            Bytes::from_static(b"0123456789abcdef"),
636        );
637
638        let bytes = record.to_bytes();
639        let decoded = EncryptedRecord::try_from(bytes).unwrap();
640
641        assert_eq!(decoded, record);
642        assert_eq!(decoded.format, EncryptedRecordFormat::Aes256GcmV1);
643        assert_eq!(decoded.encoded[0], FORMAT_ID_AES256GCM_V1);
644        assert_eq!(decoded.nonce(), b"0123456789ab");
645        assert_eq!(decoded.ciphertext(), b"ciphertext");
646        assert_eq!(decoded.tag(), b"0123456789abcdef");
647    }
648
649    #[test]
650    fn rejects_invalid_format_id() {
651        let err = EncryptedRecord::try_from(Bytes::from_static(b"\xFFpayload")).unwrap_err();
652        assert_eq!(
653            err,
654            RecordDecodeError::InvalidValue(
655                "EncryptedRecord",
656                "invalid encrypted record format id"
657            )
658        );
659    }
660
661    #[test]
662    fn rejects_truncated_layout() {
663        let err = EncryptedRecord::try_from(Bytes::from_static(b"\x01tiny")).unwrap_err();
664        assert_eq!(err, RecordDecodeError::Truncated("EncryptedRecordFrame"));
665    }
666
667    #[test]
668    fn encrypt_record_encrypts_envelope_records() {
669        let aad = aad();
670        let encryption = test_encryption(EncryptionAlgorithm::Aegis256);
671        let headers = vec![Header {
672            name: Bytes::from_static(b"x-test"),
673            value: Bytes::from_static(b"hello"),
674        }];
675        let body = Bytes::from_static(b"secret payload");
676        let record = make_plaintext_envelope(headers.clone(), body.clone()).metered();
677
678        let stored = encrypt_record(record, &encryption, &aad).into_inner();
679        let StoredRecord::Encrypted {
680            record: envelope, ..
681        } = &stored
682        else {
683            panic!("expected encrypted envelope record");
684        };
685        assert_eq!(envelope.format, EncryptedRecordFormat::Aegis256V1);
686        assert_eq!(envelope.algorithm(), EncryptionAlgorithm::Aegis256);
687
688        let decrypted = decrypt_stored_record(stored, &encryption, &aad).unwrap();
689        let Record::Envelope(record) = decrypted.into_inner() else {
690            panic!("expected envelope record");
691        };
692        assert_eq!(record.headers(), headers.as_slice());
693        assert_eq!(record.body().as_ref(), body.as_ref());
694    }
695
696    #[test]
697    fn decrypt_stored_record_preserves_plaintext_command_records() {
698        let token: crate::record::FencingToken = "fence-test".parse().unwrap();
699        let record = StoredRecord::Plaintext(Record::Command(CommandRecord::Fence(token.clone())));
700
701        let decrypted = decrypt_stored_record(
702            record,
703            &test_encryption(EncryptionAlgorithm::Aegis256),
704            &aad(),
705        )
706        .unwrap();
707
708        let Record::Command(record) = decrypted.into_inner() else {
709            panic!("expected command record");
710        };
711        assert_eq!(record, CommandRecord::Fence(token));
712    }
713
714    #[test]
715    fn decrypt_stored_record_decrypts_encrypted_records() {
716        let aad = aad();
717        let record = make_encrypted_stored_record(
718            &test_encryption(EncryptionAlgorithm::Aegis256),
719            vec![Header {
720                name: Bytes::from_static(b"x-test"),
721                value: Bytes::from_static(b"hello"),
722            }],
723            Bytes::from_static(b"secret payload"),
724            &aad,
725        );
726
727        let decrypted = decrypt_stored_record(
728            record,
729            &test_encryption(EncryptionAlgorithm::Aegis256),
730            &aad,
731        )
732        .unwrap();
733
734        let Record::Envelope(record) = decrypted.into_inner() else {
735            panic!("expected envelope record");
736        };
737        assert_eq!(record.headers().len(), 1);
738        assert_eq!(record.headers()[0].name.as_ref(), b"x-test");
739        assert_eq!(record.headers()[0].value.as_ref(), b"hello");
740        assert_eq!(record.body().as_ref(), b"secret payload");
741    }
742
743    #[test]
744    fn decrypt_stored_record_plain_rejects_encrypted_records() {
745        let aad = aad();
746        let record = make_encrypted_stored_record(
747            &test_encryption(EncryptionAlgorithm::Aegis256),
748            vec![],
749            Bytes::from_static(b"secret payload"),
750            &aad,
751        );
752
753        let result = decrypt_stored_record(record, &EncryptionSpec::Plain, &aad);
754
755        assert!(matches!(
756            result,
757            Err(RecordDecryptionError::AlgorithmMismatch {
758                expected: None,
759                actual: Some(EncryptionAlgorithm::Aegis256),
760            })
761        ));
762    }
763
764    #[test]
765    fn decode_stored_record_rejects_encrypted_metered_size_mismatch() {
766        let aad = aad();
767        let stored = make_encrypted_stored_record(
768            &test_encryption(EncryptionAlgorithm::Aegis256),
769            vec![Header {
770                name: Bytes::from_static(b"x-test"),
771                value: Bytes::from_static(b"hello"),
772            }],
773            Bytes::from_static(b"secret payload"),
774            &aad,
775        );
776        let StoredRecord::Encrypted {
777            metered_size,
778            record,
779        } = stored
780        else {
781            panic!("expected encrypted stored record");
782        };
783
784        let result = decrypt_stored_record(
785            StoredRecord::encrypted(record, metered_size + 1),
786            &test_encryption(EncryptionAlgorithm::Aegis256),
787            &aad,
788        );
789
790        assert!(matches!(
791            result,
792            Err(RecordDecryptionError::MeteredSizeMismatch {
793                stored,
794                actual
795            }) if stored == metered_size + 1 && actual == metered_size
796        ));
797    }
798}