1use aegis::aegis256::Aegis256;
32use aes_gcm::{Aes256Gcm, KeyInit, aead::AeadInPlace};
33use bytes::{BufMut, Bytes, BytesMut};
34use rand::random;
35use s2_common::{
36 deep_size::DeepSize,
37 encryption::{EncryptionAlgorithm, EncryptionSpec},
38 record::{EnvelopeRecord, Metered, MeteredExt as _, MeteredSize, Record, SeqNum, Sequenced},
39 stream::{
40 AppendInput, AppendRecord, AppendRecordBatch, AppendRecordParts, ReadBatch,
41 ReadSessionOutput,
42 },
43};
44use secrecy::ExposeSecret as _;
45
46use super::{
47 StoredAppendInput, StoredReadBatch, StoredReadSessionOutput, StoredRecord,
48 StoredRecordDecodeError, WireEncode, codec::decode_envelope_record,
49};
50
/// Width in bytes of the leading format-id field of an encrypted frame (a single `u8`).
const FORMAT_ID_LEN: usize = core::mem::size_of::<u8>();

/// Format id byte identifying an AEGIS-256, version 1 frame.
const FORMAT_ID_AEGIS256_V1: u8 = 0x01;
/// Format id byte identifying an AES-256-GCM, version 1 frame.
const FORMAT_ID_AES256GCM_V1: u8 = 0x02;
55
/// Layout variants of an encrypted envelope frame.
///
/// The variant of a stored frame is selected by its first (format id) byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum EncryptedRecordFormat {
    /// Version 1 frame sealed with AEGIS-256.
    Aegis256V1,
    /// Version 1 frame sealed with AES-256-GCM.
    Aes256GcmV1,
}
61
62impl EncryptedRecordFormat {
63 const fn try_from_format_id(format_id: u8) -> Result<Self, StoredRecordDecodeError> {
64 match format_id {
65 FORMAT_ID_AEGIS256_V1 => Ok(Self::Aegis256V1),
66 FORMAT_ID_AES256GCM_V1 => Ok(Self::Aes256GcmV1),
67 _ => Err(StoredRecordDecodeError::InvalidValue(
68 "EncryptedRecord",
69 "invalid encrypted record format id",
70 )),
71 }
72 }
73
74 const fn format_id(self) -> u8 {
75 match self {
76 Self::Aegis256V1 => FORMAT_ID_AEGIS256_V1,
77 Self::Aes256GcmV1 => FORMAT_ID_AES256GCM_V1,
78 }
79 }
80
81 const fn algorithm(self) -> EncryptionAlgorithm {
82 match self {
83 Self::Aegis256V1 => EncryptionAlgorithm::Aegis256,
84 Self::Aes256GcmV1 => EncryptionAlgorithm::Aes256Gcm,
85 }
86 }
87
88 const fn nonce_len(self) -> usize {
89 match self {
90 Self::Aegis256V1 => 32,
91 Self::Aes256GcmV1 => 12,
92 }
93 }
94
95 const fn tag_len(self) -> usize {
96 match self {
97 Self::Aegis256V1 => 16,
98 Self::Aes256GcmV1 => 16,
99 }
100 }
101
102 fn put_random_nonce(self, buf: &mut impl BufMut) {
103 match self {
104 Self::Aegis256V1 => buf.put_slice(&random::<[u8; 32]>()),
105 Self::Aes256GcmV1 => buf.put_slice(&random::<[u8; 12]>()),
106 }
107 }
108
109 const fn max_assignable_seq_num(self) -> SeqNum {
110 match self {
111 Self::Aegis256V1 => SeqNum::MAX,
112 Self::Aes256GcmV1 => (1u64 << 32) - 1,
113 }
114 }
115}
116
117#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
118pub enum RecordDecryptionError {
119 #[error("record encryption algorithm mismatch")]
120 AlgorithmMismatch {
121 expected: Option<EncryptionAlgorithm>,
122 actual: Option<EncryptionAlgorithm>,
123 },
124 #[error("record decryption failed")]
125 AuthenticationFailed,
126 #[error("malformed encrypted record")]
127 MalformedEncryptedRecord,
128 #[error("decrypted record metered size mismatch: stored {stored}, actual {actual}")]
129 MeteredSizeMismatch { stored: usize, actual: usize },
130 #[error("malformed decrypted record: {0}")]
131 MalformedDecryptedRecord(#[from] StoredRecordDecodeError),
132}
133
134#[derive(PartialEq, Eq, Clone)]
135pub struct EncryptedRecord {
136 encoded: Bytes,
137 format: EncryptedRecordFormat,
138}
139
140impl EncryptedRecord {
141 fn new(encoded: Bytes, format: EncryptedRecordFormat) -> Self {
142 debug_assert!(!encoded.is_empty());
143 debug_assert_eq!(encoded[0], format.format_id());
144 debug_assert!(encoded.len() >= FORMAT_ID_LEN + format.nonce_len() + format.tag_len());
145 Self { encoded, format }
146 }
147
148 pub fn algorithm(&self) -> EncryptionAlgorithm {
149 self.format.algorithm()
150 }
151
152 pub fn max_assignable_seq_num(&self) -> SeqNum {
153 self.format.max_assignable_seq_num()
154 }
155
156 pub(crate) fn nonce(&self) -> &[u8] {
157 let start = FORMAT_ID_LEN;
158 let end = start + self.format.nonce_len();
159 &self.encoded[start..end]
160 }
161
162 pub(crate) fn ciphertext(&self) -> &[u8] {
163 let start = FORMAT_ID_LEN + self.format.nonce_len();
164 let end = self.encoded.len() - self.format.tag_len();
165 &self.encoded[start..end]
166 }
167
168 pub(crate) fn tag(&self) -> &[u8] {
169 let start = self.encoded.len() - self.format.tag_len();
170 let end = self.encoded.len();
171 &self.encoded[start..end]
172 }
173
174 fn into_mut_encoded(self) -> BytesMut {
175 self.encoded
176 .try_into_mut()
177 .unwrap_or_else(|encoded| BytesMut::from(encoded.as_ref()))
178 }
179}
180
181impl std::fmt::Debug for EncryptedRecord {
182 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183 f.debug_struct("EncryptedRecord")
184 .field("format_id", &self.encoded[0])
185 .field("format", &self.format)
186 .field("algorithm", &self.format.algorithm())
187 .field("nonce.len", &self.nonce().len())
188 .field("ciphertext.len", &self.ciphertext().len())
189 .field("tag.len", &self.tag().len())
190 .finish()
191 }
192}
193
194impl DeepSize for EncryptedRecord {
195 fn deep_size(&self) -> usize {
196 self.encoded.len()
197 }
198}
199
200impl WireEncode for EncryptedRecord {
201 fn encoded_size(&self) -> usize {
202 self.encoded.len()
203 }
204
205 fn encode_into(&self, buf: &mut impl BufMut) {
206 buf.put_slice(self.encoded.as_ref());
207 }
208}
209
/// Encrypts an envelope record according to `encryption`, binding it to `aad`.
///
/// Two kinds of record are passed through unchanged as `StoredRecord::Plaintext`:
/// - command records, under any spec;
/// - envelope records under `EncryptionSpec::Plain`.
///
/// Envelope records under an AEAD spec are encoded and sealed in place. They are framed as
/// `format_id || nonce || ciphertext || tag` with a freshly generated random nonce. The
/// encrypted record is built with the metered size of the original plaintext record.
///
/// # Panics
/// Panics only on internal invariant violations:
/// - a nonce slice of the wrong length;
/// - AES-256-GCM rejecting the payload, which is expected only for oversized input.
pub fn encrypt_record(
    record: Metered<Record>,
    encryption: &EncryptionSpec,
    aad: &[u8],
) -> Metered<StoredRecord> {
    // Captured before the record is consumed; encrypted records keep the plaintext's size.
    let metered_size = record.metered_size();
    let record = match (record.into_inner(), encryption) {
        // Command records are always stored in the clear.
        (record @ Record::Command(_), _) => StoredRecord::Plaintext(record),
        (record @ Record::Envelope(_), EncryptionSpec::Plain) => StoredRecord::Plaintext(record),
        (Record::Envelope(envelope), EncryptionSpec::Aegis256(key)) => {
            let format = EncryptedRecordFormat::Aegis256V1;
            // Buffer holds `format_id || nonce || plaintext`; the plaintext is sealed in place.
            let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
            // Split so the nonce can be borrowed while the payload is mutated.
            let (prefix, payload) = encoded.split_at_mut(payload_start);
            let nonce: &[u8; 32] = prefix[FORMAT_ID_LEN..]
                .try_into()
                .expect("AEGIS-256 nonce must be 32 bytes");
            let tag =
                Aegis256::<16>::new(key.expose_secret(), nonce).encrypt_in_place(payload, aad);
            // Append the detached tag; `prep_encryption_buffer` reserved capacity for it.
            encoded.put_slice(tag.as_ref());

            let encrypted = EncryptedRecord::new(encoded.freeze(), format);
            StoredRecord::encrypted(encrypted, metered_size)
        }
        (Record::Envelope(envelope), EncryptionSpec::Aes256Gcm(key)) => {
            let format = EncryptedRecordFormat::Aes256GcmV1;
            // Same frame construction as above, with a 12-byte nonce.
            let (mut encoded, payload_start) = prep_encryption_buffer(&envelope, format);
            let (prefix, payload) = encoded.split_at_mut(payload_start);
            let nonce = aes_gcm::Nonce::from_slice(&prefix[FORMAT_ID_LEN..]);
            let tag = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()))
                .encrypt_in_place_detached(nonce, aad, payload)
                .expect("AES-256-GCM encryption should not fail on size validation");
            encoded.put_slice(tag.as_ref());

            let encrypted = EncryptedRecord::new(encoded.freeze(), format);
            StoredRecord::encrypted(encrypted, metered_size)
        }
    };
    record.metered()
}
249
250pub fn encrypt_append_input(
251 input: AppendInput,
252 encryption: &EncryptionSpec,
253 aad: &[u8],
254) -> StoredAppendInput {
255 let AppendInput {
256 records,
257 match_seq_num,
258 fencing_token,
259 } = input;
260 let records = records
261 .into_iter()
262 .map(|record| {
263 let AppendRecordParts { timestamp, record } = record.into_parts();
264 AppendRecord::try_from(AppendRecordParts {
265 timestamp,
266 record: encrypt_record(record, encryption, aad),
267 })
268 .expect("record encryption preserves append record limits")
269 })
270 .collect::<Vec<_>>();
271
272 AppendInput {
273 records: AppendRecordBatch::try_from(records)
274 .expect("record encryption preserves append batch limits"),
275 match_seq_num,
276 fencing_token,
277 }
278}
279
280fn prep_encryption_buffer(
281 envelope: &EnvelopeRecord,
282 format: EncryptedRecordFormat,
283) -> (BytesMut, usize) {
284 let payload_start = FORMAT_ID_LEN + format.nonce_len();
285 let mut encoded =
286 BytesMut::with_capacity(payload_start + envelope.encoded_size() + format.tag_len());
287 encoded.put_u8(format.format_id());
288 format.put_random_nonce(&mut encoded);
289 envelope.encode_into(&mut encoded);
290 (encoded, payload_start)
291}
292
293impl TryFrom<Bytes> for EncryptedRecord {
294 type Error = StoredRecordDecodeError;
295
296 fn try_from(encoded: Bytes) -> Result<Self, Self::Error> {
297 if encoded.len() < FORMAT_ID_LEN {
298 return Err(StoredRecordDecodeError::Truncated(
299 "EncryptedRecordFormatId",
300 ));
301 }
302
303 let format = EncryptedRecordFormat::try_from_format_id(encoded[0])?;
304 let nonce_len = format.nonce_len();
305 let tag_len = format.tag_len();
306 if encoded.len() < FORMAT_ID_LEN + nonce_len + tag_len {
307 return Err(StoredRecordDecodeError::Truncated("EncryptedRecordFrame"));
308 }
309
310 Ok(Self::new(encoded, format))
311 }
312}
313
314pub fn decrypt_stored_record(
315 record: StoredRecord,
316 encryption: &EncryptionSpec,
317 aad: &[u8],
318) -> Result<Metered<Record>, RecordDecryptionError> {
319 match record {
320 StoredRecord::Plaintext(record @ Record::Command(_)) => Ok(record.metered()),
321 StoredRecord::Plaintext(record @ Record::Envelope(_)) => match encryption {
322 EncryptionSpec::Plain => Ok(record.metered()),
323 EncryptionSpec::Aegis256(_) => Err(RecordDecryptionError::AlgorithmMismatch {
324 expected: Some(EncryptionAlgorithm::Aegis256),
325 actual: None,
326 }),
327 EncryptionSpec::Aes256Gcm(_) => Err(RecordDecryptionError::AlgorithmMismatch {
328 expected: Some(EncryptionAlgorithm::Aes256Gcm),
329 actual: None,
330 }),
331 },
332 StoredRecord::Encrypted {
333 metered_size,
334 record: encrypted,
335 } => {
336 let plaintext = decrypt_payload(encrypted, encryption, aad)?;
337 let record = Record::Envelope(decode_envelope_record(plaintext)?);
338 let actual_metered_size = record.metered_size();
339 if metered_size != actual_metered_size {
340 return Err(RecordDecryptionError::MeteredSizeMismatch {
341 stored: metered_size,
342 actual: actual_metered_size,
343 });
344 }
345 Ok(record.metered())
346 }
347 }
348}
349
350pub fn decrypt_read_session_output(
351 output: StoredReadSessionOutput,
352 encryption: &EncryptionSpec,
353 aad: &[u8],
354) -> Result<ReadSessionOutput, RecordDecryptionError> {
355 match output {
356 ReadSessionOutput::Heartbeat(tail) => Ok(ReadSessionOutput::Heartbeat(tail)),
357 ReadSessionOutput::Batch(batch) => {
358 decrypt_read_batch(batch, encryption, aad).map(ReadSessionOutput::Batch)
359 }
360 }
361}
362
363fn decrypt_read_batch(
364 batch: StoredReadBatch,
365 encryption: &EncryptionSpec,
366 aad: &[u8],
367) -> Result<ReadBatch, RecordDecryptionError> {
368 let records: Result<Metered<Vec<Sequenced<Record>>>, RecordDecryptionError> = batch
369 .records
370 .into_inner()
371 .into_iter()
372 .map(|record| {
373 let (position, record) = record.into_parts();
374 decrypt_stored_record(record, encryption, aad).map(|record| record.sequenced(position))
375 })
376 .collect();
377
378 Ok(ReadBatch {
379 records: records?,
380 tail: batch.tail,
381 })
382}
383
/// Authenticates and decrypts an encrypted record frame.
///
/// Returns the plaintext envelope encoding that `encrypt_record` sealed. `encryption` must
/// use the same algorithm as the frame, and `aad` must equal the associated data supplied
/// at encryption time.
///
/// # Errors
/// - `MalformedEncryptedRecord` if the frame is too short for its format's nonce and tag.
///   This check runs before the algorithm comparison.
/// - `AlgorithmMismatch` if the frame's algorithm differs from the one in `encryption`.
///   `expected` is `None` when `encryption` is `Plain`.
/// - `AuthenticationFailed` if the tag does not verify, e.g. because of a wrong key,
///   wrong AAD, or a modified frame.
fn decrypt_payload(
    record: EncryptedRecord,
    encryption: &EncryptionSpec,
    aad: &[u8],
) -> Result<Bytes, RecordDecryptionError> {
    let format = record.format;
    // Validates the layout and obtains an owned mutable buffer so decryption can run in
    // place. The bytes are copied only if they are shared (see `into_mut_encoded`).
    let (mut encoded, payload_start, payload_end) = decryption_layout(record, format)?;
    let plaintext_len = payload_end - payload_start;

    match (format, encryption) {
        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aegis256(key)) => {
            // Split into `format_id || nonce` and `ciphertext || tag`. This lets the nonce be
            // borrowed immutably while the ciphertext is decrypted in place.
            let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
            let nonce: &[u8; 32] = prefix
                .get(FORMAT_ID_LEN..)
                .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
                .try_into()
                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
            let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
            let tag: &[u8; 16] = tag
                .as_ref()
                .try_into()
                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
            Aegis256::<16>::new(key.expose_secret(), nonce)
                .decrypt_in_place(ciphertext, tag, aad)
                .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
            // The ciphertext region now holds the plaintext; strip the prefix and the tag.
            Ok(decryption_finish(encoded, payload_start, plaintext_len))
        }
        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Plain) => {
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: None,
                actual: Some(EncryptionAlgorithm::Aegis256),
            })
        }
        (EncryptedRecordFormat::Aegis256V1, EncryptionSpec::Aes256Gcm(_)) => {
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(EncryptionAlgorithm::Aes256Gcm),
                actual: Some(EncryptionAlgorithm::Aegis256),
            })
        }
        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aes256Gcm(key)) => {
            // Same in-place scheme as the AEGIS arm, with a 12-byte nonce.
            let (prefix, payload_and_tag) = encoded.split_at_mut(payload_start);
            let nonce: &[u8; 12] = prefix
                .get(FORMAT_ID_LEN..)
                .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?
                .try_into()
                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
            let nonce = aes_gcm::Nonce::from_slice(nonce);
            let (ciphertext, tag) = payload_and_tag.split_at_mut(plaintext_len);
            let tag: &[u8; 16] = tag
                .as_ref()
                .try_into()
                .map_err(|_| RecordDecryptionError::MalformedEncryptedRecord)?;
            let tag = aes_gcm::Tag::from_slice(tag);
            let cipher = Aes256Gcm::new(aes_gcm::Key::<Aes256Gcm>::from_slice(key.expose_secret()));
            cipher
                .decrypt_in_place_detached(nonce, aad, ciphertext, tag)
                .map_err(|_| RecordDecryptionError::AuthenticationFailed)?;
            Ok(decryption_finish(encoded, payload_start, plaintext_len))
        }
        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Plain) => {
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: None,
                actual: Some(EncryptionAlgorithm::Aes256Gcm),
            })
        }
        (EncryptedRecordFormat::Aes256GcmV1, EncryptionSpec::Aegis256(_)) => {
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(EncryptionAlgorithm::Aegis256),
                actual: Some(EncryptionAlgorithm::Aes256Gcm),
            })
        }
    }
}
457
458fn decryption_layout(
459 record: EncryptedRecord,
460 format: EncryptedRecordFormat,
461) -> Result<(BytesMut, usize, usize), RecordDecryptionError> {
462 let payload_start = FORMAT_ID_LEN + format.nonce_len();
463 let payload_end = record
464 .encoded
465 .len()
466 .checked_sub(format.tag_len())
467 .ok_or(RecordDecryptionError::MalformedEncryptedRecord)?;
468 if payload_start > payload_end {
469 return Err(RecordDecryptionError::MalformedEncryptedRecord);
470 }
471 Ok((record.into_mut_encoded(), payload_start, payload_end))
472}
473
474fn decryption_finish(mut encoded: BytesMut, payload_start: usize, plaintext_len: usize) -> Bytes {
475 let _ = encoded.split_to(payload_start);
476 encoded.truncate(plaintext_len);
477 encoded.freeze()
478}
479
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use rstest::rstest;
    use s2_common::record::{CommandRecord, EnvelopeRecord, FencingToken, Header, MeteredExt};

    use super::*;

    const TEST_KEY: [u8; 32] = [0x42; 32];
    const OTHER_TEST_KEY: [u8; 32] = [0x99; 32];

    /// Encryption spec for `alg` keyed with `TEST_KEY`.
    fn test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(TEST_KEY),
        }
    }

    /// Encryption spec for `alg` keyed with `OTHER_TEST_KEY`, which differs from `TEST_KEY`.
    fn other_test_encryption(alg: EncryptionAlgorithm) -> EncryptionSpec {
        match alg {
            EncryptionAlgorithm::Aegis256 => EncryptionSpec::aegis256(OTHER_TEST_KEY),
            EncryptionAlgorithm::Aes256Gcm => EncryptionSpec::aes256_gcm(OTHER_TEST_KEY),
        }
    }

    /// Encrypts `plaintext` under `test_encryption(alg)` and returns the encrypted frame.
    fn encrypt_test_record(
        plaintext: EnvelopeRecord,
        alg: EncryptionAlgorithm,
        aad: &[u8],
    ) -> EncryptedRecord {
        let stored = encrypt_record(
            Record::Envelope(plaintext).metered(),
            &test_encryption(alg),
            aad,
        )
        .into_inner();
        let StoredRecord::Encrypted { record, .. } = stored else {
            panic!("expected encrypted envelope record");
        };
        record
    }

    /// Assembles a frame from raw parts without performing any encryption.
    fn make_encrypted_record(
        format: EncryptedRecordFormat,
        nonce: impl AsRef<[u8]>,
        ciphertext: impl AsRef<[u8]>,
        tag: impl AsRef<[u8]>,
    ) -> EncryptedRecord {
        let nonce = nonce.as_ref();
        let ciphertext = ciphertext.as_ref();
        let tag = tag.as_ref();

        assert_eq!(nonce.len(), format.nonce_len());
        assert_eq!(tag.len(), format.tag_len());

        let mut encoded =
            BytesMut::with_capacity(FORMAT_ID_LEN + nonce.len() + ciphertext.len() + tag.len());
        encoded.put_u8(format.format_id());
        encoded.put_slice(nonce);
        encoded.put_slice(ciphertext);
        encoded.put_slice(tag);

        EncryptedRecord::new(encoded.freeze(), format)
    }

    fn aad() -> [u8; 32] {
        [0xA5; 32]
    }

    fn make_envelope(headers: Vec<Header>, body: Bytes) -> EnvelopeRecord {
        EnvelopeRecord::try_from_parts(headers, body).unwrap()
    }

    fn make_plaintext_envelope(headers: Vec<Header>, body: Bytes) -> Record {
        Record::Envelope(make_envelope(headers, body))
    }

    /// Encrypts an envelope with `encryption`, which must be a non-plain spec.
    fn make_encrypted_stored_record(
        encryption: &EncryptionSpec,
        headers: Vec<Header>,
        body: Bytes,
        aad: &[u8],
    ) -> StoredRecord {
        let stored = encrypt_record(
            make_plaintext_envelope(headers, body).metered(),
            encryption,
            aad,
        )
        .into_inner();
        let StoredRecord::Encrypted { .. } = &stored else {
            panic!("expected encrypted stored record (encryption spec must not be plain)");
        };
        stored
    }

    #[rstest]
    #[case::aegis_unique(EncryptionAlgorithm::Aegis256, false)]
    #[case::aegis_shared(EncryptionAlgorithm::Aegis256, true)]
    #[case::aes_unique(EncryptionAlgorithm::Aes256Gcm, false)]
    #[case::aes_shared(EncryptionAlgorithm::Aes256Gcm, true)]
    fn encrypted_payload_roundtrips(
        #[case] algorithm: EncryptionAlgorithm,
        #[case] shared_encoded_record_buffer: bool,
    ) {
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");

        let aad = aad();
        let plaintext = make_envelope(headers.clone(), body.clone());
        let encryption = test_encryption(algorithm);
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let encrypted_record = if shared_encoded_record_buffer {
            let shared = encrypted_record.encoded.clone();
            EncryptedRecord::try_from(shared).unwrap()
        } else {
            encrypted_record
        };
        let decrypted = decrypt_payload(encrypted_record, &encryption, &aad).unwrap();
        let (out_headers, out_body) = decode_envelope_record(decrypted).unwrap().into_parts();

        assert_eq!(out_headers, headers);
        assert_eq!(out_body, body);
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn wrong_key_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let result = decrypt_payload(encrypted_record, &other_test_encryption(algorithm), &aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn empty_body_fails() {
        let result = EncryptedRecord::try_from(Bytes::new());
        assert!(matches!(
            result,
            Err(StoredRecordDecodeError::Truncated(
                "EncryptedRecordFormatId"
            ))
        ));
    }

    #[test]
    fn format_id_byte_present() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad);
        let encoded = encrypted_record.to_bytes();
        assert_eq!(encrypted_record.format, EncryptedRecordFormat::Aegis256V1);
        assert_eq!(encrypted_record.algorithm(), EncryptionAlgorithm::Aegis256);
        assert_eq!(encoded[0], 0x01);
    }

    #[test]
    fn format_id_flip_detected() {
        let aad = aad();
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let mut encoded_record =
            encrypt_test_record(plaintext, EncryptionAlgorithm::Aegis256, &aad)
                .to_bytes()
                .to_vec();
        assert_eq!(encoded_record[0], 0x01);
        encoded_record[0] = 0x02;
        let encrypted_record = EncryptedRecord::try_from(Bytes::from(encoded_record)).unwrap();
        let result = decrypt_payload(
            encrypted_record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(EncryptionAlgorithm::Aegis256),
                actual: Some(EncryptionAlgorithm::Aes256Gcm),
            })
        ));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn wrong_aad_fails(#[case] algorithm: EncryptionAlgorithm) {
        let aad = aad();
        let other_aad = [0x5A; 32];
        let plaintext = make_envelope(vec![], Bytes::from_static(b"data"));
        let encrypted_record = encrypt_test_record(plaintext, algorithm, &aad);
        let result = decrypt_payload(encrypted_record, &test_encryption(algorithm), &other_aad);
        assert!(matches!(
            result,
            Err(RecordDecryptionError::AuthenticationFailed)
        ));
    }

    #[test]
    fn malformed_encrypted_record_layout_returns_error_instead_of_panicking() {
        let aad = aad();
        let record = EncryptedRecord {
            encoded: Bytes::from_static(b"\x01short"),
            format: EncryptedRecordFormat::Aegis256V1,
        };

        let result = decrypt_payload(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MalformedEncryptedRecord)
        ));
    }

    #[test]
    fn encrypted_record_roundtrips_aes256gcm() {
        let record = make_encrypted_record(
            EncryptedRecordFormat::Aes256GcmV1,
            Bytes::from_static(b"0123456789ab"),
            Bytes::from_static(b"ciphertext"),
            Bytes::from_static(b"0123456789abcdef"),
        );

        let bytes = record.to_bytes();
        let decoded = EncryptedRecord::try_from(bytes).unwrap();

        assert_eq!(decoded, record);
        assert_eq!(decoded.format, EncryptedRecordFormat::Aes256GcmV1);
        assert_eq!(decoded.encoded[0], FORMAT_ID_AES256GCM_V1);
        assert_eq!(decoded.nonce(), b"0123456789ab");
        assert_eq!(decoded.ciphertext(), b"ciphertext");
        assert_eq!(decoded.tag(), b"0123456789abcdef");
    }

    #[test]
    fn rejects_invalid_format_id() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\xFFpayload")).unwrap_err();
        assert_eq!(
            err,
            StoredRecordDecodeError::InvalidValue(
                "EncryptedRecord",
                "invalid encrypted record format id"
            )
        );
    }

    #[test]
    fn rejects_truncated_layout() {
        let err = EncryptedRecord::try_from(Bytes::from_static(b"\x01tiny")).unwrap_err();
        assert_eq!(
            err,
            StoredRecordDecodeError::Truncated("EncryptedRecordFrame")
        );
    }

    #[test]
    fn encrypt_record_encrypts_envelope_records() {
        let aad = aad();
        let encryption = test_encryption(EncryptionAlgorithm::Aegis256);
        let headers = vec![Header {
            name: Bytes::from_static(b"x-test"),
            value: Bytes::from_static(b"hello"),
        }];
        let body = Bytes::from_static(b"secret payload");
        let record = make_plaintext_envelope(headers.clone(), body.clone()).metered();

        let stored = encrypt_record(record, &encryption, &aad).into_inner();
        let StoredRecord::Encrypted {
            record: envelope, ..
        } = &stored
        else {
            panic!("expected encrypted envelope record");
        };
        assert_eq!(envelope.format, EncryptedRecordFormat::Aegis256V1);
        assert_eq!(envelope.algorithm(), EncryptionAlgorithm::Aegis256);

        let decrypted = decrypt_stored_record(stored, &encryption, &aad).unwrap();
        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers(), headers.as_slice());
        assert_eq!(record.body().as_ref(), body.as_ref());
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn encrypt_record_leaves_command_records_plaintext(#[case] algorithm: EncryptionAlgorithm) {
        let token: FencingToken = "fence-test".parse().unwrap();
        let record = Record::Command(CommandRecord::Fence(token.clone())).metered();

        let stored = encrypt_record(record, &test_encryption(algorithm), &aad()).into_inner();

        let StoredRecord::Plaintext(Record::Command(command)) = stored else {
            panic!("command records must be stored as plaintext");
        };
        assert_eq!(command, CommandRecord::Fence(token));
    }

    #[test]
    fn encrypt_record_plain_leaves_envelope_records_plaintext() {
        let record = make_plaintext_envelope(vec![], Bytes::from_static(b"data")).metered();

        let stored = encrypt_record(record, &EncryptionSpec::Plain, &aad()).into_inner();

        assert!(matches!(
            stored,
            StoredRecord::Plaintext(Record::Envelope(_))
        ));
    }

    #[test]
    fn decrypt_stored_record_preserves_plaintext_command_records() {
        let token: FencingToken = "fence-test".parse().unwrap();
        let record = StoredRecord::Plaintext(Record::Command(CommandRecord::Fence(token.clone())));

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad(),
        )
        .unwrap();

        let Record::Command(record) = decrypted.into_inner() else {
            panic!("expected command record");
        };
        assert_eq!(record, CommandRecord::Fence(token));
    }

    #[rstest]
    #[case(EncryptionAlgorithm::Aegis256)]
    #[case(EncryptionAlgorithm::Aes256Gcm)]
    fn decrypt_stored_record_rejects_plaintext_envelope_when_encryption_configured(
        #[case] algorithm: EncryptionAlgorithm,
    ) {
        let record =
            StoredRecord::Plaintext(make_plaintext_envelope(vec![], Bytes::from_static(b"data")));

        let result = decrypt_stored_record(record, &test_encryption(algorithm), &aad());

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: Some(expected),
                actual: None,
            }) if expected == algorithm
        ));
    }

    #[test]
    fn decrypt_stored_record_decrypts_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let decrypted = decrypt_stored_record(
            record,
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        )
        .unwrap();

        let Record::Envelope(record) = decrypted.into_inner() else {
            panic!("expected envelope record");
        };
        assert_eq!(record.headers().len(), 1);
        assert_eq!(record.headers()[0].name.as_ref(), b"x-test");
        assert_eq!(record.headers()[0].value.as_ref(), b"hello");
        assert_eq!(record.body().as_ref(), b"secret payload");
    }

    #[test]
    fn decrypt_stored_record_plain_rejects_encrypted_records() {
        let aad = aad();
        let record = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![],
            Bytes::from_static(b"secret payload"),
            &aad,
        );

        let result = decrypt_stored_record(record, &EncryptionSpec::Plain, &aad);

        assert!(matches!(
            result,
            Err(RecordDecryptionError::AlgorithmMismatch {
                expected: None,
                actual: Some(EncryptionAlgorithm::Aegis256),
            })
        ));
    }

    #[test]
    fn decode_stored_record_rejects_encrypted_metered_size_mismatch() {
        let aad = aad();
        let stored = make_encrypted_stored_record(
            &test_encryption(EncryptionAlgorithm::Aegis256),
            vec![Header {
                name: Bytes::from_static(b"x-test"),
                value: Bytes::from_static(b"hello"),
            }],
            Bytes::from_static(b"secret payload"),
            &aad,
        );
        let StoredRecord::Encrypted {
            metered_size,
            record,
        } = stored
        else {
            panic!("expected encrypted stored record");
        };

        let result = decrypt_stored_record(
            StoredRecord::encrypted(record, metered_size + 1),
            &test_encryption(EncryptionAlgorithm::Aegis256),
            &aad,
        );

        assert!(matches!(
            result,
            Err(RecordDecryptionError::MeteredSizeMismatch {
                stored,
                actual
            }) if stored == metered_size + 1 && actual == metered_size
        ));
    }
}