1use aegis::aegis256::Aegis256;
32use aes_gcm::{Aes256Gcm, KeyInit, aead::AeadInPlace};
33use bytes::{BufMut, Bytes, BytesMut};
34use rand::random;
35
36use super::{Encodable, Metered, MeteredSize, Record, RecordDecodeError, SeqNum, StoredRecord};
37use crate::{
38 deep_size::DeepSize,
39 encryption::{EncryptionAlgorithm, EncryptionSpec},
40 record::MeteredExt as _,
41};
42
/// Leading frame byte identifying an AEGIS-256 encrypted record (format version 1).
const FORMAT_ID_AEGIS256_V1: u8 = 0x01;
/// Leading frame byte identifying an AES-256-GCM encrypted record (format version 1).
const FORMAT_ID_AES256GCM_V1: u8 = 0x02;

/// Width in bytes of the format-id prefix that starts every encrypted record frame.
const FORMAT_ID_LEN: usize = 1;
47
/// Versioned layout of an encrypted record frame.
///
/// Each variant fixes the cipher together with its nonce and tag widths.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum EncryptedRecordFormat {
    /// AEGIS-256 with a 32-byte nonce and a 16-byte tag.
    Aegis256V1,
    /// AES-256-GCM with a 12-byte nonce and a 16-byte tag.
    Aes256GcmV1,
}
53
54impl EncryptedRecordFormat {
55 const fn try_from_format_id(format_id: u8) -> Result<Self, RecordDecodeError> {
56 match format_id {
57 FORMAT_ID_AEGIS256_V1 => Ok(Self::Aegis256V1),
58 FORMAT_ID_AES256GCM_V1 => Ok(Self::Aes256GcmV1),
59 _ => Err(RecordDecodeError::InvalidValue(
60 "EncryptedRecord",
61 "invalid encrypted record format id",
62 )),
63 }
64 }
65
66 const fn format_id(self) -> u8 {
67 match self {
68 Self::Aegis256V1 => FORMAT_ID_AEGIS256_V1,
69 Self::Aes256GcmV1 => FORMAT_ID_AES256GCM_V1,
70 }
71 }
72
73 const fn algorithm(self) -> EncryptionAlgorithm {
74 match self {
75 Self::Aegis256V1 => EncryptionAlgorithm::Aegis256,
76 Self::Aes256GcmV1 => EncryptionAlgorithm::Aes256Gcm,
77 }
78 }
79
80 const fn nonce_len(self) -> usize {
81 match self {
82 Self::Aegis256V1 => 32,
83 Self::Aes256GcmV1 => 12,
84 }
85 }
86
87 const fn tag_len(self) -> usize {
88 match self {
89 Self::Aegis256V1 => 16,
90 Self::Aes256GcmV1 => 16,
91 }
92 }
93
94 fn put_random_nonce(self, buf: &mut impl BufMut) {
95 match self {
96 Self::Aegis256V1 => buf.put_slice(&random::<[u8; 32]>()),
97 Self::Aes256GcmV1 => buf.put_slice(&random::<[u8; 12]>()),
98 }
99 }
100
101 const fn max_assignable_seq_num(self) -> SeqNum {
102 match self {
103 Self::Aegis256V1 => SeqNum::MAX,
104 Self::Aes256GcmV1 => (1u64 << 32) - 1,
105 }
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
110pub enum RecordDecryptionError {
111 #[error("record encryption algorithm mismatch")]
112 AlgorithmMismatch {
113 expected: Option<EncryptionAlgorithm>,
114 actual: Option<EncryptionAlgorithm>,
115 },
116 #[error("record decryption failed")]
117 AuthenticationFailed,
118 #[error("malformed encrypted record")]
119 MalformedEncryptedRecord,
120 #[error("decrypted record metered size mismatch: stored {stored}, actual {actual}")]
121 MeteredSizeMismatch { stored: usize, actual: usize },
122 #[error("malformed decrypted record: {0}")]
123 MalformedDecryptedRecord(#[from] RecordDecodeError),
124}
125
126#[derive(PartialEq, Eq, Clone)]
127pub struct EncryptedRecord {
128 encoded: Bytes,
129 format: EncryptedRecordFormat,
130}
131
132impl EncryptedRecord {
133 fn new(encoded: Bytes, format: EncryptedRecordFormat) -> Self {
134 debug_assert!(!encoded.is_empty());
135 debug_assert_eq!(encoded[0], format.format_id());
136 debug_assert!(encoded.len() >= FORMAT_ID_LEN + format.nonce_len() + format.tag_len());
137 Self { encoded, format }
138 }
139
140 pub fn algorithm(&self) -> EncryptionAlgorithm {
141 self.format.algorithm()
142 }
143
144 pub fn max_assignable_seq_num(&self) -> SeqNum {
145 self.format.max_assignable_seq_num()
146 }
147
148 pub(crate) fn nonce(&self) -> &[u8] {
149 let start = FORMAT_ID_LEN;
150 let end = start + self.format.nonce_len();
151 &self.encoded[start..end]
152 }
153
154 pub(crate) fn ciphertext(&self) -> &[u8] {
155 let start = FORMAT_ID_LEN + self.format.nonce_len();
156 let end = self.encoded.len() - self.format.tag_len();
157 &self.encoded[start..end]
158 }
159
160 pub(crate) fn tag(&self) -> &[u8] {
161 let start = self.encoded.len() - self.format.tag_len();
162 let end = self.encoded.len();
163 &self.encoded[start..end]
164 }
165
166 fn into_mut_encoded(self) -> BytesMut {
167 self.encoded
168 .try_into_mut()
169 .unwrap_or_else(|encoded| BytesMut::from(encoded.as_ref()))
170 }
171}
172
173impl std::fmt::Debug for EncryptedRecord {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("EncryptedRecord")
176 .field("format_id", &self.encoded[0])
177 .field("format", &self.format)
178 .field("algorithm", &self.format.algorithm())
179 .field("nonce.len", &self.nonce().len())
180 .field("ciphertext.len", &self.ciphertext().len())
181 .field("tag.len", &self.tag().len())
182 .finish()
183 }
184}
185
186impl DeepSize for EncryptedRecord {
187 fn deep_size(&self) -> usize {
188 self.encoded.len()
189 }
190}
191
192impl Encodable for EncryptedRecord {
193 fn encoded_size(&self) -> usize {
194 self.encoded.len()
195 }
196
197 fn encode_into(&self, buf: &mut impl BufMut) {
198 buf.put_slice(self.encoded.as_ref());
199 }
200}
201
202pub fn encrypt_record(
203 record: Metered<Record>,
204 encryption: &EncryptionSpec,
205 aad: &[u8],
206) -> Metered<StoredRecord> {
207 let metered_size = record.metered_size();
208 let record = match (record.into_inner(), encryption) {
209 (record @ Record::Command(_), _) => StoredRecord::Plaintext(record),
210 (record @ Record::Envelope(_), EncryptionSpec::Plain) => StoredRecord::Plaintext(record),
211 (Record::Envelope(envelope), EncryptionSpec::Aegis256(key)) => {
212 let format = EncryptedRecordFormat::Aegis256V1;
213 let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
214 let (prefix, payload) = encoded.split_at_mut(payload_start);
215 let nonce: &[u8; 32] = prefix[FORMAT_ID_LEN..]
216 .try_into()
217 .expect("AEGIS-256 nonce must be 32 bytes");
218 let tag =
219 Aegis256::<16>::new(key.expose_secret(), nonce).encrypt_in_place(payload, aad);
220 encoded.put_slice(tag.as_ref());
221
222 let encrypted = EncryptedRecord::new(encoded.freeze(), format);
223 StoredRecord::encrypted(encrypted, metered_size)
224 }
225 (Record::Envelope(envelope), EncryptionSpec::Aes256Gcm(key)) => {
226 let format = EncryptedRecordFormat::Aes256GcmV1;
227 let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
228 let (prefix, payload) = encoded.split_at_mut(payload_start);
229 let nonce = aes_gcm::Nonce::from_slice(&prefix[FORMAT_ID_LEN..]);
230 let tag = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()))
231 .encrypt_in_place_detached(nonce, aad, payload)
232 .expect("AES-256-GCM encryption should not fail on size validation");
233 encoded.put_slice(tag.as_ref());
234
235 let encrypted = EncryptedRecord::new(encoded.freeze(), format);
236 StoredRecord::encrypted(encrypted, metered_size)
237 }
238 };
239 Metered::with_size(metered_size, record)
240}
241
242fn prep_encryption_buffer(
243 envelope: &super::EnvelopeRecord,
244 format: EncryptedRecordFormat,
245) -> (BytesMut, usize) {
246 let payload_start = FORMAT_ID_LEN + format.nonce_len();
247 let mut encoded =
248 BytesMut::with_capacity(payload_start + envelope.encoded_size() + format.tag_len());
249 encoded.put_u8(format.format_id());
250 format.put_random_nonce(&mut encoded);
251 envelope.encode_into(&mut encoded);
252 (encoded, payload_start)
253}
254
255impl TryFrom<Bytes> for EncryptedRecord {
256 type Error = RecordDecodeError;
257
258 fn try_from(encoded: Bytes) -> Result<Self, Self::Error> {
259 if encoded.len() < FORMAT_ID_LEN {
260 return Err(RecordDecodeError::Truncated("EncryptedRecordFormatId"));
261 }
262
263 let format = EncryptedRecordFormat::try_from_format_id(encoded[0])?;
264 let nonce_len = format.nonce_len();
265 let tag_len = format.tag_len();
266 if encoded.len() < FORMAT_ID_LEN + nonce_len + tag_len {
267 return Err(RecordDecodeError::Truncated("EncryptedRecordFrame"));
268 }
269
270 Ok(Self::new(encoded, format))
271 }
272}
273
274pub fn decrypt_stored_record(
275 record: StoredRecord,
276 encryption: &EncryptionSpec,
277 aad: &[u8],
278) -> Result<Metered<Record>, RecordDecryptionError> {
279 match record {
280 StoredRecord::Plaintext(record @ Record::Command(_)) => Ok(record.metered()),
281 StoredRecord::Plaintext(record @ Record::Envelope(_)) => match encryption {
282 EncryptionSpec::Plain => Ok(record.metered()),
283 EncryptionSpec::Aegis256(_) => Err(RecordDecryptionError::AlgorithmMismatch {
284 expected: Some(EncryptionAlgorithm::Aegis256),
285 actual: None,
286 }),
287 EncryptionSpec::Aes256Gcm(_) => Err(RecordDecryptionError::AlgorithmMismatch {
288 expected: Some(EncryptionAlgorithm::Aes256Gcm),
289 actual: None,
290 }),
291 },
292 StoredRecord::Encrypted {
293 metered_size,
294 record: encrypted,
295 } => {
296 let plaintext = decrypt_payload(encrypted, encryption, aad)?;
297 let record = Record::Envelope(plaintext.try_into()?);
298 let actual_metered_size = record.metered_size();
299 if metered_size != actual_metered_size {
300 return Err(RecordDecryptionError::MeteredSizeMismatch {
301 stored: metered_size,
302 actual: actual_metered_size,
303 });
304 }
305 Ok(Metered::with_size(metered_size, record))
306 }
307 }
308}
309
310fn decrypt_payload(
311 record: EncryptedRecord,
312 encryption: &EncryptionSpec,
313 aad: &[u8],
314) -> Result<Bytes, RecordDecryptionError> {
315 let format = record.format;
316 let (mut encoded, payload_start, payload_end) = decryption_layout(record, format)?;
317 let plaintext_len = payload_end - payload_start;
318
319 match (format, encryption) {
320 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aegis256(key)) => {
321 let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
322 let nonce: &[u8; 32] = prefix
323 .get(FORMAT_ID_LEN..)
324 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
325 .try_into()
326 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
327 let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
328 let tag: &[u8; 16] = tag
329 .as_ref()
330 .try_into()
331 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
332 Aegis256::<16>::new(key.expose_secret(), nonce)
333 .decrypt_in_place(ciphertext, tag, aad)
334 .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
335 Ok(decryption_finish(encoded, payload_start, plaintext_len))
336 }
337 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Plain) => {
338 Err(RecordDecryptionError::AlgorithmMismatch {
339 expected: None,
340 actual: Some(EncryptionAlgorithm::Aegis256),
341 })
342 }
343 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aes256Gcm(_)) => {
344 Err(RecordDecryptionError::AlgorithmMismatch {
345 expected: Some(EncryptionAlgorithm::Aes256Gcm),
346 actual: Some(EncryptionAlgorithm::Aegis256),
347 })
348 }
349 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aes256Gcm(key)) => {
350 let cipher = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()));
351 let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
352 let nonce: &[u8; 12] = prefix
353 .get(FORMAT_ID_LEN..)
354 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
355 .try_into()
356 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
357 let nonce = aes_gcm::Nonce::from_slice(nonce);
358 let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
359 let tag: &[u8; 16] = tag
360 .as_ref()
361 .try_into()
362 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
363 let tag = aes_gcm::Tag::from_slice(tag);
364 cipher
365 .decrypt_in_place_detached(nonce, aad, ciphertext, tag)
366 .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
367 Ok(decryption_finish(encoded, payload_start, plaintext_len))
368 }
369 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Plain) => {
370 Err(RecordDecryptionError::AlgorithmMismatch {
371 expected: None,
372 actual: Some(EncryptionAlgorithm::Aes256Gcm),
373 })
374 }
375 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aegis256(_)) => {
376 Err(RecordDecryptionError::AlgorithmMismatch {
377 expected: Some(EncryptionAlgorithm::Aegis256),
378 actual: Some(EncryptionAlgorithm::Aes256Gcm),
379 })
380 }
381 }
382}
383
384fn decryption_layout(
385 record: EncryptedRecord,
386 format: EncryptedRecordFormat,
387) -> Result<(BytesMut, usize, usize), RecordDecryptionError> {
388 let payload_start = FORMAT_ID_LEN + format.nonce_len();
389 let payload_end = record
390 .encoded
391 .len()
392 .checked_sub(format.tag_len())
393 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?;
394 if payload_start > payload_end {
395 return Err(RecordDecryptionError::MalformedEncryptedRecord);
396 }
397 Ok((record.into_mut_encoded(), payload_start, payload_end))
398}
399
400fn decryption_finish(mut encoded: BytesMut, payload_start: usize, plaintext_len: usize) -> Bytes {
401 let _ = encoded.split_to(payload_start);
402 encoded.truncate(plaintext_len);
403 encoded.freeze()
404}
405
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;

    use super::*;
    use crate::record::{CommandRecord, EnvelopeRecord, Header, MeteredExt};

    /// Key used on the "correct key" side of every roundtrip.
    const TEST_KEY: [u8; 32] = [0x42; 32];
    /// A different key, used to provoke authentication failures.
    const OTHER_TEST_KEY: [u8; 32] = [0x99; 32];

    fn test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(TEST_KEY),
        }
    }

    fn other_test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(OTHER_TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(OTHER_TEST_KEY),
        }
    }

    /// Encrypts `plaintext` under `TEST_KEY` and unwraps the resulting `EncryptedRecord`.
    fn encrypt_test_record(
        plaintext: EnvelopeRecord,
        alg: EncryptionAlgorithm,
        aad: &[u8],
    ) -> EncryptedRecord {
        let stored = encrypt_record(
            Record::Envelope(plaintext).metered(),
            &test_encryption(alg),
            aad,
        )
        .into_inner();
        let StoredRecord::Encrypted { record, .. } = stored else {
            panic!("expected encrypted envelope record");
        };
        record
    }

    /// Assembles a frame from raw parts without encrypting anything, so that framing can be
    /// tested on its own.
    fn make_encrypted_record(
        format: EncryptedRecordFormat,
        nonce: impl AsRef<[u8]>,
        ciphertext: impl AsRef<[u8]>,
        tag: impl AsRef<[u8]>,
    ) -> EncryptedRecord {
        let nonce = nonce.as_ref();
        let ciphertext = ciphertext.as_ref();
        let tag = tag.as_ref();

        assert_eq!(nonce.len(), format.nonce_len());
        assert_eq!(tag.len(), format.tag_len());

        let mut encoded =
            BytesMut::with_capacity(FORMAT_ID_LEN + nonce.len() + ciphertext.len() + tag.len());
        encoded.put_u8(format.format_id());
        encoded.put_slice(nonce);
        encoded.put_slice(ciphertext);
        encoded.put_slice(tag);

        EncryptedRecord::new(encoded.freeze(), format)
    }

    /// Fixed associated data shared by the tests.
    fn aad() -> [u8; 32] {
        [0xA5; 32]
    }

    fn make_envelope(headers: Vec<Header>, body: Bytes) -> EnvelopeRecord {
        EnvelopeRecord::try_from_parts(headers, body).unwrap()
    }

    fn make_plaintext_envelope(headers: Vec<Header>, body: Bytes) -> Record {
        Record::Envelope(make_envelope(headers, body))
    }

    /// Encrypts a fresh envelope with `encryption` and asserts that the result is encrypted.
    fn make_encrypted_stored_record(
        encryption: &EncryptionSpec,
        headers: Vec<Header>,
        body: Bytes,
        aad: &[u8],
    ) -> StoredRecord {
        let stored = encrypt_record(
            make_plaintext_envelope(headers, body).metered(),
            encryption,
            aad,
        )
        .into_inner();
        // Reaching this panic means the record came back as plaintext.
        let StoredRecord::Encrypted { .. } = &stored else {
            panic!("expected encrypted stored record");
        };
        stored
    }

    #[rstest]
    #[case::aegis_unique(EncryptionAlgorithm::Aegis256, false)]
    #[case::aegis_shared(EncryptionAlgorithm::Aegis256, true)]
    #[case::aes_unique(EncryptionAlgorithm::Aes256Gcm, false)]
    #[case::aes_shared(EncryptionAlgorithm::Aes256Gcm, true)]
    fn encrypted_payload_roundtrips(
        #[case] algorithm: EncryptionAlgorithm,
        #[case] shared_encoded_record_buffer: bool,
    ) {
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");

        let aad = aad();
        let plaintext = make_envelope(headers.clone(), body.clone());
        let encryption = test_encryption(algorithm);
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        // A shared buffer forces `into_mut_encoded` down its copying path.
        let encrypted_record = if shared_encoded_record_buffer {
            let shared = encrypted_record.encoded.clone();
            EncryptedRecord::try_from(shared).unwrap()
        } else {
            encrypted_record
        };
        let decrypted = decrypt_payload(encrypted_record, &encryption, &aad).unwrap();
        let (out_headers, out_body) = EnvelopeRecord::try_from(decrypted).unwrap().into_parts();

        assert_eq!(out_headers, headers);
        assert_eq!(out_body, body);
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn wrong_key_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let result = decrypt_payload(encrypted_record, &other_test_encryption(algorithm), &aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn empty_body_fails() {
        let result = EncryptedRecord::try_from(Bytes::new());
        assert!(matches!(
            result,
            Err(RecordDecodeError::Truncated("EncryptedRecordFormatId"))
        ));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256, EncryptedRecordFormat::Aegis256V1, 0x01)]
    #[case(EncryptionAlgorithm::Aes256Gcm, EncryptedRecordFormat::Aes256GcmV1, 0x02)]
    fn format_id_byte_present(
        #[case] algorithm: EncryptionAlgorithm,
        #[case] format: EncryptedRecordFormat,
        #[case] expected_id: u8,
    ) {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let encoded = encrypted_record.to_bytes();
        assert_eq!(encrypted_record.format, format);
        assert_eq!(encrypted_record.algorithm(), algorithm);
        assert_eq!(encoded[0], expected_id);
        assert_eq!(encrypted_record.nonce().len(), format.nonce_len());
        assert_eq!(encrypted_record.tag().len(), format.tag_len());
    }

    #[test]
    fn format_id_flip_detected() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let mut encoded_record =
            encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad)
                .to_bytes()
                .to_vec();
        assert_eq!(encoded_record[0], 0x01);
        encoded_record[0] = 0x02;
        let encrypted_record = EncryptedRecord::try_from(Bytes::from(encoded_record)).unwrap();
        let result = decrypt_payload(
            encrypted_record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(EncryptionAlgorithm::Aegis256),
                actual: Some(EncryptionAlgorithm::Aes256Gcm),
            })
        ));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn wrong_aad_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let other_aad = [0x5A; 32];
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let result = decrypt_payload(encrypted_record, &test_encryption(algorithm), &other_aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn malformed_encrypted_record_layout_returns_error_instead_of_panicking() {
        let aad = aad();
        // Built directly to bypass the length validation in `TryFrom<Bytes>`.
        let record = EncryptedRecord {
            encoded: Bytes::from_static(b"\x01short"),
            format: EncryptedRecordFormat::Aegis256V1,
        };

        let result = decrypt_payload(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MalformedEncryptedRecord)
        ));
    }

    #[test]
    fn encrypted_record_roundtrips_aes256gcm() {
        let record = make_encrypted_record(
            EncryptedRecordFormat::Aes256GcmV1,
            Bytes::from_static(b"0123456789ab"),
            Bytes::from_static(b"ciphertext"),
            Bytes::from_static(b"0123456789abcdef"),
        );

        let bytes = record.to_bytes();
        let decoded = EncryptedRecord::try_from(bytes).unwrap();

        assert_eq!(decoded, record);
        assert_eq!(decoded.format, EncryptedRecordFormat::Aes256GcmV1);
        assert_eq!(decoded.encoded[0], FORMAT_ID_AES256GCM_V1);
        assert_eq!(decoded.nonce(), b"0123456789ab");
        assert_eq!(decoded.ciphertext(), b"ciphertext");
        assert_eq!(decoded.tag(), b"0123456789abcdef");
    }

    #[test]
    fn rejects_invalid_format_id() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\xFFpayload")).unwrap_err();
        assert_eq!(
            err,
            RecordDecodeError::InvalidValue(
                "EncryptedRecord",
                "invalid encrypted record format id"
            )
        );
    }

    #[test]
    fn rejects_truncated_layout() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\x01tiny")).unwrap_err();
        assert_eq!(err, RecordDecodeError::Truncated("EncryptedRecordFrame"));
    }

    #[test]
    fn max_assignable_seq_num_per_format() {
        assert_eq!(
            EncryptedRecordFormat::Aegis256V1.max_assignable_seq_num(),
            SeqNum::MAX
        );
        assert_eq!(
            EncryptedRecordFormat::Aes256GcmV1.max_assignable_seq_num(),
            (1u64 << 32) - 1
        );
    }

    #[test]
    fn encrypt_record_encrypts_envelope_records() {
        let aad = aad();
        let encryption = test_encryption(EncryptionAlgorithm::Aegis256);
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");
        let record = make_plaintext_envelope(headers.clone(), body.clone()).metered();

        let stored = encrypt_record(record, &encryption, &aad).into_inner();
        let StoredRecord::Encrypted {
            record: envelope, ..
        } = &stored
        else {
            panic!("expected encrypted envelope record");
        };
        assert_eq!(envelope.format, EncryptedRecordFormat::Aegis256V1);
        assert_eq!(envelope.algorithm(), EncryptionAlgorithm::Aegis256);

        let decrypted = decrypt_stored_record(stored, &encryption, &aad).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers(), headers.as_slice());
        assert_eq!(record.body().as_ref(), body.as_ref());
    }

    #[test]
    fn decrypt_stored_record_preserves_plaintext_command_records() {
        let token: crate::record::FencingToken = "fence-test".parse().unwrap();
        let record = StoredRecord::Plaintext(Record::Command(CommandRecord::Fence(token.clone())));

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad(),
        )
        .unwrap();

        let Record::Command(record) = decrypted.into_inner() else {
            panic!("expected command record");
        };
        assert_eq!(record, CommandRecord::Fence(token));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn decrypt_stored_record_rejects_plaintext_envelopes_under_encryption(
        #[case] algorithm: EncryptionAlgorithm,
    ) {
        let record = StoredRecord::Plaintext(make_plaintext_envelope(
            vec![],
            Bytes::from_static(b"data"),
        ));

        let result = decrypt_stored_record(record, &test_encryption(algorithm), &aad());

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(expected),
                actual: None,
            }) if expected == algorithm
        ));
    }

    #[test]
    fn decrypt_stored_record_decrypts_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        )
        .unwrap();

        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers().len(), 1);
        assert_eq!(record.headers()[0].name.as_ref(), b"x-test");
        assert_eq!(record.headers()[0].value.as_ref(), b"hello");
        assert_eq!(record.body().as_ref(), b"secret payload");
    }

    #[test]
    fn decrypt_stored_record_plain_rejects_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let result = decrypt_stored_record(record, &EncryptionSpec::Plain, &aad);

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: None,
                actual: Some(EncryptionAlgorithm::Aegis256),
            })
        ));
    }

    #[test]
    fn decrypt_stored_record_rejects_encrypted_metered_size_mismatch() {
        let aad = aad();
        let stored = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );
        let StoredRecord::Encrypted {
            metered_size,
            record,
        } = stored
        else {
            panic!("expected encrypted stored record");
        };

        let result = decrypt_stored_record(
            StoredRecord::encrypted(record, metered_size + 1),
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MeteredSizeMismatch {
                stored,
                actual
            }) if stored == metered_size + 1 && actual == metered_size
        ));
    }
}