1use aegis::aegis256::Aegis256;
32use aes_gcm::{Aes256Gcm, KeyInit, aead::AeadInPlace};
33use bytes::{BufMut, Bytes, BytesMut};
34use rand::random;
35use s2_common::{
36 deep_size::DeepSize,
37 encryption::{EncryptionAlgorithm, EncryptionSpec},
38 record::{EnvelopeRecord, Metered, MeteredExt as _, MeteredSize, Record, SeqNum, Sequenced},
39 stream::{
40 AppendInput, AppendRecord, AppendRecordBatch, AppendRecordParts, ReadBatch,
41 ReadSessionOutput,
42 },
43};
44use secrecy::ExposeSecret as _;
45
46use super::{
47 StoredAppendInput, StoredReadBatch, StoredReadSessionOutput, StoredRecord,
48 StoredRecordDecodeError, WireEncode, codec::decode_envelope_record,
49};
50
/// Number of bytes occupied by the format-id prefix of an encoded encrypted record
/// (a single `u8`).
const FORMAT_ID_LEN: usize = std::mem::size_of::<u8>();

/// Format id for records sealed with AEGIS-256 (32-byte nonce, 16-byte tag).
const FORMAT_ID_AEGIS256_V1: u8 = 0x01;
/// Format id for records sealed with AES-256-GCM (12-byte nonce, 16-byte tag).
const FORMAT_ID_AES256GCM_V1: u8 = 0x02;
55
/// On-the-wire layout variant of an encrypted record, selected by its leading
/// format-id byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum EncryptedRecordFormat {
    /// AEGIS-256, version 1 framing.
    Aegis256V1,
    /// AES-256-GCM, version 1 framing.
    Aes256GcmV1,
}
61
62impl EncryptedRecordFormat {
63 const fn try_from_format_id(format_id: u8) -> Result<Self, StoredRecordDecodeError> {
64 match format_id {
65 FORMAT_ID_AEGIS256_V1 => Ok(Self::Aegis256V1),
66 FORMAT_ID_AES256GCM_V1 => Ok(Self::Aes256GcmV1),
67 _ => Err(StoredRecordDecodeError::InvalidValue(
68 "EncryptedRecord",
69 "invalid encrypted record format id",
70 )),
71 }
72 }
73
74 const fn format_id(self) -> u8 {
75 match self {
76 Self::Aegis256V1 => FORMAT_ID_AEGIS256_V1,
77 Self::Aes256GcmV1 => FORMAT_ID_AES256GCM_V1,
78 }
79 }
80
81 const fn algorithm(self) -> EncryptionAlgorithm {
82 match self {
83 Self::Aegis256V1 => EncryptionAlgorithm::Aegis256,
84 Self::Aes256GcmV1 => EncryptionAlgorithm::Aes256Gcm,
85 }
86 }
87
88 const fn nonce_len(self) -> usize {
89 match self {
90 Self::Aegis256V1 => 32,
91 Self::Aes256GcmV1 => 12,
92 }
93 }
94
95 const fn tag_len(self) -> usize {
96 match self {
97 Self::Aegis256V1 => 16,
98 Self::Aes256GcmV1 => 16,
99 }
100 }
101
102 fn put_random_nonce(self, buf: &mut impl BufMut) {
103 match self {
104 Self::Aegis256V1 => buf.put_slice(&random::<[u8; 32]>()),
105 Self::Aes256GcmV1 => buf.put_slice(&random::<[u8; 12]>()),
106 }
107 }
108
109 const fn max_assignable_seq_num(self) -> SeqNum {
110 match self {
111 Self::Aegis256V1 => SeqNum::MAX,
112 Self::Aes256GcmV1 => (1u64 << 32) - 1,
113 }
114 }
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
118pub enum RecordDecryptionError {
119 #[error("record encryption algorithm mismatch")]
120 AlgorithmMismatch {
121 expected: Option<EncryptionAlgorithm>,
122 actual: Option<EncryptionAlgorithm>,
123 },
124 #[error("record decryption failed")]
125 AuthenticationFailed,
126 #[error("malformed encrypted record")]
127 MalformedEncryptedRecord,
128 #[error("decrypted record metered size mismatch: stored {stored}, actual {actual}")]
129 MeteredSizeMismatch { stored: usize, actual: usize },
130 #[error("malformed decrypted record: {0}")]
131 MalformedDecryptedRecord(#[from] StoredRecordDecodeError),
132}
133
134#[derive(PartialEq, Eq, Clone)]
135pub struct EncryptedRecord {
136 encoded: Bytes,
137 format: EncryptedRecordFormat,
138}
139
140impl EncryptedRecord {
141 fn new(encoded: Bytes, format: EncryptedRecordFormat) -> Self {
142 debug_assert!(!encoded.is_empty());
143 debug_assert_eq!(encoded[0], format.format_id());
144 debug_assert!(encoded.len() >= FORMAT_ID_LEN + format.nonce_len() + format.tag_len());
145 Self { encoded, format }
146 }
147
148 pub fn max_assignable_seq_num(&self) -> SeqNum {
149 self.format.max_assignable_seq_num()
150 }
151
152 pub(crate) fn nonce(&self) -> &[u8] {
153 let start = FORMAT_ID_LEN;
154 let end = start + self.format.nonce_len();
155 &self.encoded[start..end]
156 }
157
158 pub(crate) fn ciphertext(&self) -> &[u8] {
159 let start = FORMAT_ID_LEN + self.format.nonce_len();
160 let end = self.encoded.len() - self.format.tag_len();
161 &self.encoded[start..end]
162 }
163
164 pub(crate) fn tag(&self) -> &[u8] {
165 let start = self.encoded.len() - self.format.tag_len();
166 let end = self.encoded.len();
167 &self.encoded[start..end]
168 }
169
170 fn into_mut_encoded(self) -> BytesMut {
171 self.encoded
172 .try_into_mut()
173 .unwrap_or_else(|encoded| BytesMut::from(encoded.as_ref()))
174 }
175}
176
177impl std::fmt::Debug for EncryptedRecord {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("EncryptedRecord")
180 .field("format_id", &self.encoded[0])
181 .field("format", &self.format)
182 .field("algorithm", &self.format.algorithm())
183 .field("nonce.len", &self.nonce().len())
184 .field("ciphertext.len", &self.ciphertext().len())
185 .field("tag.len", &self.tag().len())
186 .finish()
187 }
188}
189
190impl DeepSize for EncryptedRecord {
191 fn deep_size(&self) -> usize {
192 self.encoded.len()
193 }
194}
195
196impl WireEncode for EncryptedRecord {
197 fn encoded_size(&self) -> usize {
198 self.encoded.len()
199 }
200
201 fn encode_into(&self, buf: &mut impl BufMut) {
202 buf.put_slice(self.encoded.as_ref());
203 }
204}
205
206pub fn encrypt_record(
207 record: Metered<Record>,
208 encryption: &EncryptionSpec,
209 aad: &[u8],
210) -> Metered<StoredRecord> {
211 let metered_size = record.metered_size();
212 let record = match (record.into_inner(), encryption) {
213 (record @ Record::Command(_), _) => StoredRecord::Plaintext(record),
214 (record @ Record::Envelope(_), EncryptionSpec::Plain) => StoredRecord::Plaintext(record),
215 (Record::Envelope(envelope), EncryptionSpec::Aegis256(key)) => {
216 let format = EncryptedRecordFormat::Aegis256V1;
217 let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
218 let (prefix, payload) = encoded.split_at_mut(payload_start);
219 let nonce: &[u8; 32] = prefix[FORMAT_ID_LEN..]
220 .try_into()
221 .expect("AEGIS-256 nonce must be 32 bytes");
222 let tag =
223 Aegis256::<16>::new(key.expose_secret(), nonce).encrypt_in_place(payload, aad);
224 encoded.put_slice(tag.as_ref());
225
226 let encrypted = EncryptedRecord::new(encoded.freeze(), format);
227 StoredRecord::encrypted(encrypted, metered_size)
228 }
229 (Record::Envelope(envelope), EncryptionSpec::Aes256Gcm(key)) => {
230 let format = EncryptedRecordFormat::Aes256GcmV1;
231 let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
232 let (prefix, payload) = encoded.split_at_mut(payload_start);
233 let nonce = aes_gcm::Nonce::from_slice(&prefix[FORMAT_ID_LEN..]);
234 let tag = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()))
235 .encrypt_in_place_detached(nonce, aad, payload)
236 .expect("AES-256-GCM encryption should not fail on size validation");
237 encoded.put_slice(tag.as_ref());
238
239 let encrypted = EncryptedRecord::new(encoded.freeze(), format);
240 StoredRecord::encrypted(encrypted, metered_size)
241 }
242 };
243 record.metered()
244}
245
246pub fn encrypt_append_input(
247 input: AppendInput,
248 encryption: &EncryptionSpec,
249 aad: &[u8],
250) -> StoredAppendInput {
251 let AppendInput {
252 records,
253 match_seq_num,
254 fencing_token,
255 } = input;
256 let records = records
257 .into_iter()
258 .map(|record| {
259 let AppendRecordParts { timestamp, record } = record.into_parts();
260 AppendRecord::try_from(AppendRecordParts {
261 timestamp,
262 record: encrypt_record(record, encryption, aad),
263 })
264 .expect("record encryption preserves append record limits")
265 })
266 .collect::<Vec<_>>();
267
268 AppendInput {
269 records: AppendRecordBatch::try_from(records)
270 .expect("record encryption preserves append batch limits"),
271 match_seq_num,
272 fencing_token,
273 }
274}
275
276fn prep_encryption_buffer(
277 envelope: &EnvelopeRecord,
278 format: EncryptedRecordFormat,
279) -> (BytesMut, usize) {
280 let payload_start = FORMAT_ID_LEN + format.nonce_len();
281 let mut encoded =
282 BytesMut::with_capacity(payload_start + envelope.encoded_size() + format.tag_len());
283 encoded.put_u8(format.format_id());
284 format.put_random_nonce(&mut encoded);
285 envelope.encode_into(&mut encoded);
286 (encoded, payload_start)
287}
288
289impl TryFrom<Bytes> for EncryptedRecord {
290 type Error = StoredRecordDecodeError;
291
292 fn try_from(encoded: Bytes) -> Result<Self, Self::Error> {
293 if encoded.len() < FORMAT_ID_LEN {
294 return Err(StoredRecordDecodeError::Truncated(
295 "EncryptedRecordFormatId",
296 ));
297 }
298
299 let format = EncryptedRecordFormat::try_from_format_id(encoded[0])?;
300 let nonce_len = format.nonce_len();
301 let tag_len = format.tag_len();
302 if encoded.len() < FORMAT_ID_LEN + nonce_len + tag_len {
303 return Err(StoredRecordDecodeError::Truncated("EncryptedRecordFrame"));
304 }
305
306 Ok(Self::new(encoded, format))
307 }
308}
309
310pub fn decrypt_stored_record(
311 record: StoredRecord,
312 encryption: &EncryptionSpec,
313 aad: &[u8],
314) -> Result<Metered<Record>, RecordDecryptionError> {
315 match record {
316 StoredRecord::Plaintext(record @ Record::Command(_)) => Ok(record.metered()),
317 StoredRecord::Plaintext(record @ Record::Envelope(_)) => match encryption {
318 EncryptionSpec::Plain => Ok(record.metered()),
319 EncryptionSpec::Aegis256(_) => Err(RecordDecryptionError::AlgorithmMismatch {
320 expected: Some(EncryptionAlgorithm::Aegis256),
321 actual: None,
322 }),
323 EncryptionSpec::Aes256Gcm(_) => Err(RecordDecryptionError::AlgorithmMismatch {
324 expected: Some(EncryptionAlgorithm::Aes256Gcm),
325 actual: None,
326 }),
327 },
328 StoredRecord::Encrypted {
329 metered_size,
330 record: encrypted,
331 } => {
332 let plaintext = decrypt_payload(encrypted, encryption, aad)?;
333 let record = Record::Envelope(decode_envelope_record(plaintext)?);
334 let actual_metered_size = record.metered_size();
335 if metered_size != actual_metered_size {
336 return Err(RecordDecryptionError::MeteredSizeMismatch {
337 stored: metered_size,
338 actual: actual_metered_size,
339 });
340 }
341 Ok(record.metered())
342 }
343 }
344}
345
346pub fn decrypt_read_session_output(
347 output: StoredReadSessionOutput,
348 encryption: &EncryptionSpec,
349 aad: &[u8],
350) -> Result<ReadSessionOutput, RecordDecryptionError> {
351 match output {
352 ReadSessionOutput::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
353 ReadSessionOutput::Batch(batch) => {
354 decrypt_read_batch(batch, encryption, aad).map(ReadSessionOutput::Batch)
355 }
356 }
357}
358
359fn decrypt_read_batch(
360 batch: StoredReadBatch,
361 encryption: &EncryptionSpec,
362 aad: &[u8],
363) -> Result<ReadBatch, RecordDecryptionError> {
364 let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = batch
365 .records
366 .into_inner()
367 .into_iter()
368 .map(|record| {
369 let (position, record) = record.into_parts();
370 decrypt_stored_record(record, encryption, aad).map(|record| record.sequenced(position))
371 })
372 .collect();
373
374 Ok(ReadBatch {
375 records: records?,
376 tail: batch.tail,
377 })
378}
379
380fn decrypt_payload(
381 record: EncryptedRecord,
382 encryption: &EncryptionSpec,
383 aad: &[u8],
384) -> Result<Bytes, RecordDecryptionError> {
385 let format = record.format;
386 let (mut encoded, payload_start, payload_end) = decryption_layout(record, format)?;
387 let plaintext_len = payload_end - payload_start;
388
389 match (format, encryption) {
390 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aegis256(key)) => {
391 let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
392 let nonce: &[u8; 32] = prefix
393 .get(FORMAT_ID_LEN..)
394 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
395 .try_into()
396 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
397 let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
398 let tag: &[u8; 16] = tag
399 .as_ref()
400 .try_into()
401 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
402 Aegis256::<16>::new(key.expose_secret(), nonce)
403 .decrypt_in_place(ciphertext, tag, aad)
404 .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
405 Ok(decryption_finish(encoded, payload_start, plaintext_len))
406 }
407 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Plain) => {
408 Err(RecordDecryptionError::AlgorithmMismatch {
409 expected: None,
410 actual: Some(EncryptionAlgorithm::Aegis256),
411 })
412 }
413 (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aes256Gcm(_)) => {
414 Err(RecordDecryptionError::AlgorithmMismatch {
415 expected: Some(EncryptionAlgorithm::Aes256Gcm),
416 actual: Some(EncryptionAlgorithm::Aegis256),
417 })
418 }
419 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aes256Gcm(key)) => {
420 let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
421 let nonce: &[u8; 12] = prefix
422 .get(FORMAT_ID_LEN..)
423 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
424 .try_into()
425 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
426 let nonce = aes_gcm::Nonce::from_slice(nonce);
427 let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
428 let tag: &[u8; 16] = tag
429 .as_ref()
430 .try_into()
431 .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
432 let tag = aes_gcm::Tag::from_slice(tag);
433 let cipher = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()));
434 cipher
435 .decrypt_in_place_detached(nonce, aad, ciphertext, tag)
436 .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
437 Ok(decryption_finish(encoded, payload_start, plaintext_len))
438 }
439 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Plain) => {
440 Err(RecordDecryptionError::AlgorithmMismatch {
441 expected: None,
442 actual: Some(EncryptionAlgorithm::Aes256Gcm),
443 })
444 }
445 (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aegis256(_)) => {
446 Err(RecordDecryptionError::AlgorithmMismatch {
447 expected: Some(EncryptionAlgorithm::Aegis256),
448 actual: Some(EncryptionAlgorithm::Aes256Gcm),
449 })
450 }
451 }
452}
453
454fn decryption_layout(
455 record: EncryptedRecord,
456 format: EncryptedRecordFormat,
457) -> Result<(BytesMut, usize, usize), RecordDecryptionError> {
458 let payload_start = FORMAT_ID_LEN + format.nonce_len();
459 let payload_end = record
460 .encoded
461 .len()
462 .checked_sub(format.tag_len())
463 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?;
464 if payload_start > payload_end {
465 return Err(RecordDecryptionError::MalformedEncryptedRecord);
466 }
467 Ok((record.into_mut_encoded(), payload_start, payload_end))
468}
469
470fn decryption_finish(mut encoded: BytesMut, payload_start: usize, plaintext_len: usize) -> Bytes {
471 let _ = encoded.split_to(payload_start);
472 encoded.truncate(plaintext_len);
473 encoded.freeze()
474}
475
#[cfg(test)]
mod tests {
    //! Unit tests for record encryption and decryption.
    //!
    //! Covers round-trips for both AEAD formats, wrong key/AAD and tampering
    //! detection, frame parsing, and `StoredRecord`-level handling.

    use bytes::Bytes;
    use rstest::rstest;
    use s2_common::record::{CommandRecord, EnvelopeRecord, FencingToken, Header, MeteredExt};

    use super::*;

    const TEST_KEY: [u8; 32] = [0x42; 32];
    const OTHER_TEST_KEY: [u8; 32] = [0x99; 32];

    fn test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(TEST_KEY),
        }
    }

    fn other_test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(OTHER_TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(OTHER_TEST_KEY),
        }
    }

    /// Encrypts `plaintext` with [`TEST_KEY`] and returns the sealed frame.
    fn encrypt_test_record(
        plaintext: EnvelopeRecord,
        alg: EncryptionAlgorithm,
        aad: &[u8],
    ) -> EncryptedRecord {
        let stored = encrypt_record(
            Record::Envelope(plaintext).metered(),
            &test_encryption(alg),
            aad,
        )
        .into_inner();
        let StoredRecord::Encrypted { record, .. } = stored else {
            panic!("expected encrypted envelope record");
        };
        record
    }

    /// Assembles a frame from raw parts without encrypting anything.
    fn make_encrypted_record(
        format: EncryptedRecordFormat,
        nonce: impl AsRef<[u8]>,
        ciphertext: impl AsRef<[u8]>,
        tag: impl AsRef<[u8]>,
    ) -> EncryptedRecord {
        let nonce = nonce.as_ref();
        let ciphertext = ciphertext.as_ref();
        let tag = tag.as_ref();

        assert_eq!(nonce.len(), format.nonce_len());
        assert_eq!(tag.len(), format.tag_len());

        let mut encoded =
            BytesMut::with_capacity(FORMAT_ID_LEN + nonce.len() + ciphertext.len() + tag.len());
        encoded.put_u8(format.format_id());
        encoded.put_slice(nonce);
        encoded.put_slice(ciphertext);
        encoded.put_slice(tag);

        EncryptedRecord::new(encoded.freeze(), format)
    }

    fn aad() -> [u8; 32] {
        [0xA5; 32]
    }

    fn make_envelope(headers: Vec<Header>, body: Bytes) -> EnvelopeRecord {
        EnvelopeRecord::try_from_parts(headers, body).unwrap()
    }

    fn make_plaintext_envelope(headers: Vec<Header>, body: Bytes) -> Record {
        Record::Envelope(make_envelope(headers, body))
    }

    fn make_encrypted_stored_record(
        encryption: &EncryptionSpec,
        headers: Vec<Header>,
        body: Bytes,
        aad: &[u8],
    ) -> StoredRecord {
        let stored = encrypt_record(
            make_plaintext_envelope(headers, body).metered(),
            encryption,
            aad,
        )
        .into_inner();
        let StoredRecord::Encrypted { .. } = &stored else {
            panic!("plain encryption should not produce an encrypted record");
        };
        stored
    }

    #[rstest]
    #[case::aegis_unique(EncryptionAlgorithm::Aegis256, false)]
    #[case::aegis_shared(EncryptionAlgorithm::Aegis256, true)]
    #[case::aes_unique(EncryptionAlgorithm::Aes256Gcm, false)]
    #[case::aes_shared(EncryptionAlgorithm::Aes256Gcm, true)]
    fn encrypted_payload_roundtrips(
        #[case] algorithm: EncryptionAlgorithm,
        #[case] shared_encoded_record_buffer: bool,
    ) {
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");

        let aad = aad();
        let plaintext = make_envelope(headers.clone(), body.clone());
        let encryption = test_encryption(algorithm);
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        // A second handle to the buffer forces `into_mut_encoded` down its copy path.
        let encrypted_record = if shared_encoded_record_buffer {
            let shared = encrypted_record.encoded.clone();
            EncryptedRecord::try_from(shared).unwrap()
        } else {
            encrypted_record
        };
        let decrypted = decrypt_payload(encrypted_record, &encryption, &aad).unwrap();
        let (out_headers, out_body) = decode_envelope_record(decrypted).unwrap().into_parts();

        assert_eq!(out_headers, headers);
        assert_eq!(out_body, body);
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn wrong_key_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let result = decrypt_payload(encrypted_record, &other_test_encryption(algorithm), &aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn empty_body_fails() {
        let result = EncryptedRecord::try_from(Bytes::new());
        assert!(matches!(
            result,
            Err(StoredRecordDecodeError::Truncated(
                "EncryptedRecordFormatId"
            ))
        ));
    }

    #[test]
    fn format_id_byte_present() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad);
        let encoded = encrypted_record.to_bytes();
        assert_eq!(encrypted_record.format, EncryptedRecordFormat::Aegis256V1);
        assert_eq!(encoded[0], 0x01);
    }

    #[test]
    fn format_id_byte_present_aes256gcm() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record =
            encrypt_test_record(plaintext, EncryptionAlgorithm::Aes256Gcm, &aad);
        let encoded = encrypted_record.to_bytes();
        assert_eq!(encrypted_record.format, EncryptedRecordFormat::Aes256GcmV1);
        assert_eq!(encoded[0], FORMAT_ID_AES256GCM_V1);
        assert_eq!(encrypted_record.nonce().len(), 12);
        assert_eq!(encrypted_record.tag().len(), 16);
    }

    #[test]
    fn format_id_flip_detected() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let mut encoded_record =
            encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad)
                .to_bytes()
                .to_vec();
        assert_eq!(encoded_record[0], 0x01);
        encoded_record[0] = 0x02;
        let encrypted_record = EncryptedRecord::try_from(Bytes::from(encoded_record)).unwrap();
        let result = decrypt_payload(
            encrypted_record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(EncryptionAlgorithm::Aegis256),
                actual: Some(EncryptionAlgorithm::Aes256Gcm),
            })
        ));
    }

    #[test]
    fn wrong_aad_fails() {
        let aad = aad();
        let other_aad = [0x5A; 32];
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad);
        let result = decrypt_payload(
            encrypted_record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &other_aad,
        );
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn wrong_aad_fails_aes256gcm() {
        let aad = aad();
        let other_aad = [0x5A; 32];
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record =
            encrypt_test_record(plaintext, EncryptionAlgorithm::Aes256Gcm, &aad);
        let result = decrypt_payload(
            encrypted_record,
            &test_encryption(EncryptionAlgorithm::Aes256Gcm),
            &other_aad,
        );
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn tampered_ciphertext_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        // The envelope body is non-empty, so there is at least one ciphertext byte here.
        let ciphertext_start = FORMAT_ID_LEN + encrypted_record.format.nonce_len();
        let mut encoded = encrypted_record.to_bytes().to_vec();
        encoded[ciphertext_start] ^= 0x01;
        let tampered = EncryptedRecord::try_from(Bytes::from(encoded)).unwrap();
        let result = decrypt_payload(tampered, &test_encryption(algorithm), &aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn malformed_encrypted_record_layout_returns_error_instead_of_panicking() {
        let aad = aad();
        // Bypasses `TryFrom` validation to hand `decrypt_payload` a too-short frame.
        let record = EncryptedRecord {
            encoded: Bytes::from_static(b"\x01short"),
            format: EncryptedRecordFormat::Aegis256V1,
        };

        let result = decrypt_payload(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MalformedEncryptedRecord)
        ));
    }

    #[test]
    fn encrypted_record_roundtrips_aes256gcm() {
        let record = make_encrypted_record(
            EncryptedRecordFormat::Aes256GcmV1,
            Bytes::from_static(b"0123456789ab"),
            Bytes::from_static(b"ciphertext"),
            Bytes::from_static(b"0123456789abcdef"),
        );

        let bytes = record.to_bytes();
        let decoded = EncryptedRecord::try_from(bytes).unwrap();

        assert_eq!(decoded, record);
        assert_eq!(decoded.format, EncryptedRecordFormat::Aes256GcmV1);
        assert_eq!(decoded.encoded[0], FORMAT_ID_AES256GCM_V1);
        assert_eq!(decoded.nonce(), b"0123456789ab");
        assert_eq!(decoded.ciphertext(), b"ciphertext");
        assert_eq!(decoded.tag(), b"0123456789abcdef");
    }

    #[test]
    fn rejects_invalid_format_id() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\xFFpayload")).unwrap_err();
        assert_eq!(
            err,
            StoredRecordDecodeError::InvalidValue(
                "EncryptedRecord",
                "invalid encrypted record format id"
            )
        );
    }

    #[test]
    fn rejects_truncated_layout() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\x01tiny")).unwrap_err();
        assert_eq!(
            err,
            StoredRecordDecodeError::Truncated("EncryptedRecordFrame")
        );
    }

    #[test]
    fn encrypt_record_encrypts_envelope_records() {
        let aad = aad();
        let encryption = test_encryption(EncryptionAlgorithm::Aegis256);
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");
        let record = make_plaintext_envelope(headers.clone(), body.clone()).metered();

        let stored = encrypt_record(record, &encryption, &aad).into_inner();
        let StoredRecord::Encrypted {
            record: envelope, ..
        } = &stored
        else {
            panic!("expected encrypted envelope record");
        };
        assert_eq!(envelope.format, EncryptedRecordFormat::Aegis256V1);

        let decrypted = decrypt_stored_record(stored, &encryption, &aad).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers(), headers.as_slice());
        assert_eq!(record.body().as_ref(), body.as_ref());
    }

    #[test]
    fn decrypt_stored_record_preserves_plaintext_command_records() {
        let token: FencingToken = "fence-test".parse().unwrap();
        let record = StoredRecord::Plaintext(Record::Command(CommandRecord::Fence(token.clone())));

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad(),
        )
        .unwrap();

        let Record::Command(record) = decrypted.into_inner() else {
            panic!("expected command record");
        };
        assert_eq!(record, CommandRecord::Fence(token));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn decrypt_stored_record_rejects_plaintext_envelope_when_encryption_configured(
        #[case] algorithm: EncryptionAlgorithm,
    ) {
        let record = StoredRecord::Plaintext(make_plaintext_envelope(
            vec![],
            Bytes::from_static(b"not encrypted"),
        ));

        let result = decrypt_stored_record(record, &test_encryption(algorithm), &aad());

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(expected),
                actual: None,
            }) if expected == algorithm
        ));
    }

    #[test]
    fn decrypt_stored_record_decrypts_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        )
        .unwrap();

        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers().len(), 1);
        assert_eq!(record.headers()[0].name.as_ref(), b"x-test");
        assert_eq!(record.headers()[0].value.as_ref(), b"hello");
        assert_eq!(record.body().as_ref(), b"secret payload");
    }

    #[test]
    fn decrypt_stored_record_plain_rejects_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let result = decrypt_stored_record(record, &EncryptionSpec::Plain, &aad);

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: None,
                actual: Some(EncryptionAlgorithm::Aegis256),
            })
        ));
    }

    #[test]
    fn decode_stored_record_rejects_encrypted_metered_size_mismatch() {
        let aad = aad();
        let stored = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );
        let StoredRecord::Encrypted {
            metered_size,
            record,
        } = stored
        else {
            panic!("expected encrypted stored record");
        };

        let result = decrypt_stored_record(
            StoredRecord::encrypted(record, metered_size + 1),
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MeteredSizeMismatch {
                stored,
                actual
            }) if stored == metered_size + 1 && actual == metered_size
        ));
    }
}