1#[cfg(test)]
2use bytes::BytesMut;
3use bytes::{Buf, BufMut, Bytes};
4use s2_common::{
5 deep_size::DeepSize,
6 encryption::EncryptionAlgorithm,
7 record::{CommandRecord, Metered, MeteredSize, Record, SeqNum, Sequenced},
8};
9
10use super::{
11 codec::{StoredRecordDecodeError, WireEncode, decode_command_record, decode_envelope_record},
12 encryption::EncryptedRecord,
13};
14
/// Kind of a stored record, carried in the low three bits of its magic byte.
///
/// The explicit `#[repr(u8)]` discriminants are the on-wire ordinals.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum RecordType {
    /// Plaintext [`CommandRecord`].
    Command = 1,
    /// Plaintext envelope record.
    Envelope = 2,
    /// Encrypted envelope record.
    EncryptedEnvelope = 3,
}
22
23impl TryFrom<u8> for RecordType {
24 type Error = &'static str;
25
26 fn try_from(value: u8) -> Result<Self, Self::Error> {
27 match value {
28 1 => Ok(Self::Command),
29 2 => Ok(Self::Envelope),
30 3 => Ok(Self::EncryptedEnvelope),
31 _ => Err("invalid record type ordinal"),
32 }
33 }
34}
35
/// Decoded first byte of a stored-record frame.
///
/// Bit layout, as read by `MagicByte::try_from` and written by `u8::from`:
/// bits 0-2 hold the [`RecordType`] ordinal and bits 3-4 hold
/// `metered_size_varlen - 1`.
#[derive(Copy, Clone, Debug, PartialEq)]
struct MagicByte {
    /// Kind of record that follows the header.
    record_type: RecordType,
    /// Width in bytes of the big-endian metered-size prefix that follows the magic
    /// byte. `MagicByte::try_from` only produces values in 1..=3.
    metered_size_varlen: u8,
}
41
/// Interprets `bytes` as an unsigned big-endian integer.
///
/// # Panics
/// If `bytes` is empty or longer than four bytes (the width of a `u32`).
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let width = bytes.len();
    assert!(
        (1..=size_of::<u32>()).contains(&width),
        "invalid variable int bytes = {} len",
        width
    );
    // Shift each byte in from the right, most significant first.
    bytes
        .iter()
        .fold(0u32, |value, &byte| (value << 8) | u32::from(byte))
}
53
54pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
55 let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
56 let magic_byte = MagicByte::try_from(magic_byte_u8)?;
57 Ok(read_vint_u32_be(
58 record_bytes
59 .get(1..1 + magic_byte.metered_size_varlen as usize)
60 .ok_or("byte range doesn't include bytes for metered size")?,
61 ))
62}
63
64impl TryFrom<u8> for MagicByte {
65 type Error = &'static str;
66
67 fn try_from(value: u8) -> Result<Self, Self::Error> {
68 let record_type = RecordType::try_from(value & 0b111)?;
69 Ok(Self {
70 record_type,
71 metered_size_varlen: match (value >> 3) & 0b11 {
72 0 => 1u8,
73 1 => 2u8,
74 2 => 3u8,
75 _ => Err("invalid metered_size_varlen")?,
76 },
77 })
78 }
79}
80
81impl From<MagicByte> for u8 {
82 fn from(value: MagicByte) -> Self {
83 ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
84 }
85}
86
/// A record in the form it is persisted: either a plaintext [`Record`] or an
/// encrypted envelope.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum StoredRecord {
    /// An unencrypted command or envelope record.
    Plaintext(Record),
    /// An encrypted envelope record.
    Encrypted {
        /// Metered size reported for this record by its `MeteredSize` impl. It is kept
        /// explicitly rather than derived from `record`. Presumably this is the
        /// plaintext's metered size; confirm with callers of `StoredRecord::encrypted`.
        metered_size: usize,
        /// The encrypted payload, framed by `EncryptedRecord`'s own encoding.
        record: EncryptedRecord,
    },
}
100
101impl StoredRecord {
102 pub(crate) fn encrypted(record: EncryptedRecord, metered_size: usize) -> Self {
103 Self::Encrypted {
104 metered_size,
105 record,
106 }
107 }
108
109 fn record_type(&self) -> RecordType {
110 match self {
111 Self::Plaintext(Record::Command(_)) => RecordType::Command,
112 Self::Plaintext(Record::Envelope(_)) => RecordType::Envelope,
113 Self::Encrypted { .. } => RecordType::EncryptedEnvelope,
114 }
115 }
116
117 fn encoded_body_size(&self) -> usize {
118 match self {
119 Self::Plaintext(Record::Command(record)) => record.encoded_size(),
120 Self::Plaintext(Record::Envelope(record)) => record.encoded_size(),
121 Self::Encrypted { record, .. } => record.encoded_size(),
122 }
123 }
124
125 fn encode_body_into(&self, buf: &mut impl BufMut) {
126 match self {
127 Self::Plaintext(Record::Command(record)) => record.encode_into(buf),
128 Self::Plaintext(Record::Envelope(record)) => record.encode_into(buf),
129 Self::Encrypted { record, .. } => record.encode_into(buf),
130 }
131 }
132
133 pub fn encryption_algorithm(&self) -> Option<EncryptionAlgorithm> {
134 match self {
135 Self::Plaintext(_) => None,
136 Self::Encrypted { record, .. } => Some(record.algorithm()),
137 }
138 }
139
140 pub fn max_assignable_seq_num(&self) -> SeqNum {
141 match self {
142 Self::Plaintext(_) => SeqNum::MAX,
143 Self::Encrypted { record, .. } => record.max_assignable_seq_num(),
144 }
145 }
146}
147
148impl DeepSize for StoredRecord {
149 fn deep_size(&self) -> usize {
150 match self {
151 Self::Plaintext(record) => record.deep_size(),
152 Self::Encrypted {
153 metered_size,
154 record,
155 } => metered_size.deep_size() + record.deep_size(),
156 }
157 }
158}
159
160impl MeteredSize for StoredRecord {
161 fn metered_size(&self) -> usize {
162 match self {
163 Self::Plaintext(record) => record.metered_size(),
164 Self::Encrypted { metered_size, .. } => *metered_size,
165 }
166 }
167}
168
169impl From<Record> for StoredRecord {
170 fn from(value: Record) -> Self {
171 Self::Plaintext(value)
172 }
173}
174
175pub fn decode_if_command_record(
176 record: &[u8],
177) -> Result<Option<CommandRecord>, StoredRecordDecodeError> {
178 if record.is_empty() {
179 return Err(StoredRecordDecodeError::Truncated("MagicByte"));
180 }
181 let magic_byte = MagicByte::try_from(record[0])
182 .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
183 match magic_byte.record_type {
184 RecordType::Command => {
185 let offset = 1 + magic_byte.metered_size_varlen as usize;
186 if record.len() < offset {
187 return Err(StoredRecordDecodeError::Truncated("MeteredSize"));
188 }
189 Ok(Some(decode_command_record(&record[offset..])?))
190 }
191 RecordType::Envelope | RecordType::EncryptedEnvelope => Ok(None),
192 }
193}
194
/// Encodes `record` into a standalone frame.
///
/// The frame is the magic byte, then the big-endian metered-size prefix, then the
/// record body. See the `WireEncode` impl for `Metered<&StoredRecord>`.
pub fn encode_stored_record(record: Metered<&StoredRecord>) -> Bytes {
    record.to_bytes()
}
198
199pub fn stored_record_encoded_size(record: Metered<&StoredRecord>) -> usize {
200 record.encoded_size()
201}
202
203pub fn encode_stored_record_into(record: Metered<&StoredRecord>, buf: &mut impl BufMut) {
204 record.encode_into(buf);
205}
206
207impl WireEncode for Metered<&StoredRecord> {
208 fn encoded_size(&self) -> usize {
209 1 + magic_byte(self).metered_size_varlen as usize + self.encoded_body_size()
210 }
211
212 fn encode_into(&self, buf: &mut impl BufMut) {
213 let magic_byte = magic_byte(self);
214 buf.put_u8(magic_byte.into());
215 buf.put_uint(
216 self.metered_size() as u64,
217 magic_byte.metered_size_varlen as usize,
218 );
219 self.encode_body_into(buf);
220 }
221}
222
223fn magic_byte(record: &Metered<&StoredRecord>) -> MagicByte {
224 let metered_size = record.metered_size();
225 let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
226 if metered_size_varlen > 3 {
227 panic!("illegal metered size varlen {metered_size} for record")
228 }
229 MagicByte {
230 record_type: record.record_type(),
231 metered_size_varlen,
232 }
233}
234
/// A sequenced record held as raw bytes. These are presumably the stored-record frame
/// produced by `encode_stored_record`; confirm with producers.
pub type StoredSequencedBytes = Sequenced<Bytes>;
/// A sequenced, decoded [`StoredRecord`].
pub type StoredSequencedRecord = Sequenced<StoredRecord>;
237
238pub fn decode_stored_record(
239 mut buf: Bytes,
240) -> Result<Metered<StoredRecord>, StoredRecordDecodeError> {
241 if buf.is_empty() {
242 return Err(StoredRecordDecodeError::Truncated("MagicByte"));
243 }
244 let magic_byte = MagicByte::try_from(buf.get_u8())
245 .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
246
247 let metered_size =
248 buf.try_get_uint(magic_byte.metered_size_varlen as usize)
249 .map_err(|_| StoredRecordDecodeError::Truncated("MeteredSize"))? as usize;
250
251 let record = match magic_byte.record_type {
252 RecordType::Command => {
253 StoredRecord::Plaintext(Record::Command(decode_command_record(buf.as_ref())?))
254 }
255 RecordType::Envelope => {
256 StoredRecord::Plaintext(Record::Envelope(decode_envelope_record(buf)?))
257 }
258 RecordType::EncryptedEnvelope => {
259 StoredRecord::encrypted(EncryptedRecord::try_from(buf)?, metered_size)
260 }
261 };
262 Ok(Metered::with_size(metered_size, record))
263}
264
265pub fn decode_record(buf: Bytes) -> Result<Metered<Record>, StoredRecordDecodeError> {
266 let stored = decode_stored_record(buf)?;
267 let metered_size = stored.metered_size();
268 match stored.into_inner() {
269 StoredRecord::Plaintext(record) => Ok(record),
270 StoredRecord::Encrypted { .. } => Err(StoredRecordDecodeError::InvalidValue(
271 "RecordType",
272 "encrypted envelope requires decryption",
273 )),
274 }
275 .map(|record| Metered::with_size(metered_size, record))
276}
277
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use rstest::rstest;
    use s2_common::record::{
        EnvelopeRecord, Header, MAX_FENCING_TOKEN_LENGTH, MeteredExt, StreamPosition, Timestamp,
    };

    use super::*;

    /// Standalone reference encoder for plaintext records.
    ///
    /// It builds the magic byte / metered-size prefix / body frame directly from a
    /// [`Record`], without going through [`StoredRecord`], so the production encoder
    /// can be compared against it byte for byte.
    struct LegacyPlaintextFrame<'a> {
        record: &'a Record,
    }

    impl LegacyPlaintextFrame<'_> {
        /// Magic byte for the wrapped record: its kind plus the byte width of its
        /// metered size. The width is computed as `8 - leading_zeros / 8`, which
        /// assumes a 64-bit `usize`, and is asserted to be at most 3.
        fn magic_byte(&self) -> MagicByte {
            let metered_size = self.record.metered_size();
            let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
            assert!(metered_size_varlen <= 3);

            MagicByte {
                record_type: match self.record {
                    Record::Command(_) => RecordType::Command,
                    Record::Envelope(_) => RecordType::Envelope,
                },
                metered_size_varlen,
            }
        }
    }

    impl WireEncode for LegacyPlaintextFrame<'_> {
        fn encoded_size(&self) -> usize {
            let body_size = match self.record {
                Record::Command(record) => record.encoded_size(),
                Record::Envelope(record) => record.encoded_size(),
            };
            // magic byte + metered-size prefix + body
            1 + self.magic_byte().metered_size_varlen as usize + body_size
        }

        fn encode_into(&self, buf: &mut impl BufMut) {
            let magic_byte = self.magic_byte();
            buf.put_u8(magic_byte.into());
            // Metered size as a big-endian integer of `metered_size_varlen` bytes.
            buf.put_uint(
                self.record.metered_size() as u64,
                magic_byte.metered_size_varlen as usize,
            );
            match self.record {
                Record::Command(record) => record.encode_into(buf),
                Record::Envelope(record) => record.encode_into(buf),
            }
        }
    }

    /// Encodes `record` with the reference [`LegacyPlaintextFrame`] encoder.
    fn legacy_plaintext_bytes(record: &Record) -> Bytes {
        LegacyPlaintextFrame { record }.to_bytes()
    }

    /// Independently recomputes a record's metered size from its parts.
    ///
    /// The formula is a constant 8, plus 2 per header, plus every header name and value
    /// byte, plus the body length.
    fn semantic_metered_size(record: &Record) -> usize {
        let (headers, body) = record.clone().into_parts();
        8 + (2 * headers.len())
            + headers
                .iter()
                .map(|header| header.name.len() + header.value.len())
                .sum::<usize>()
            + body.len()
    }

    /// Byte strings that are either short or long.
    ///
    /// Short strings have fewer than 10 bytes and are non-empty unless `allow_empty`
    /// is set. Long strings have 100..1000 bytes.
    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
        prop_oneof![
            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
                .prop_map(Bytes::from),
            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
        ]
    }

    /// A header with a non-empty name and a possibly empty value.
    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }

    /// Either a few headers (fewer than 10) or many headers (200..300).
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }

    /// One of two command records:
    /// - a fence command whose token is printable ASCII (`[ -~]`) of at most
    ///   `MAX_FENCING_TOKEN_LENGTH` characters;
    /// - a trim command with an arbitrary sequence number.
    fn command_strategy() -> impl Strategy<Value = CommandRecord> {
        prop_oneof![
            proptest::string::string_regex(&format!("[ -~]{{0,{MAX_FENCING_TOKEN_LENGTH}}}"))
                .unwrap()
                .prop_map(|token| CommandRecord::Fence(token.parse().unwrap())),
            any::<SeqNum>().prop_map(CommandRecord::Trim),
        ]
    }

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        // Records built from arbitrary headers/body encode to the same bytes as the
        // legacy reference encoder, decode back unchanged with the same metered size,
        // and survive sequencing.
        #[test]
        fn roundtrip_envelope(
            seq_num in any::<SeqNum>(),
            timestamp in any::<Timestamp>(),
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let metered_record: Metered<Record> = record.clone().into();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let legacy_record = legacy_plaintext_bytes(&record);
            prop_assert_eq!(encoded_record.as_ref(), legacy_record.as_ref());
            let decoded_record = decode_record(encoded_record).unwrap();
            prop_assert_eq!(&decoded_record, &metered_record);
            let sequenced = decoded_record.sequenced(StreamPosition { seq_num, timestamp });
            let (position, sequenced_record) = sequenced.into_parts();
            assert_eq!(position, StreamPosition { seq_num, timestamp });
            assert_eq!(sequenced_record.into_inner(), record);
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        // A record's metered size matches the independent formula and is exactly what
        // `try_metered_size` reads back from the encoded prefix.
        #[test]
        fn roundtrip_metered(
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers.clone(), body.clone()).unwrap();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            assert_eq!(record.metered_size(), semantic_metered_size(&record));
            assert_eq!(record.metered_size(), try_metered_size(encoded_record.as_ref()).unwrap() as usize);
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        // Same metered-size checks for command records, plus a full decode roundtrip.
        #[test]
        fn roundtrip_command_metered(command in command_strategy()) {
            let record = Record::Command(command);
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let expected_metered = semantic_metered_size(&record);
            let wire_metered = try_metered_size(encoded_record.as_ref()).unwrap() as usize;
            let decoded_record = decode_record(encoded_record).unwrap();

            assert_eq!(record.metered_size(), expected_metered);
            assert_eq!(record.metered_size(), wire_metered);
            prop_assert_eq!(decoded_record, Metered::<Record>::from(record));
        }
    );

    // An encrypted stored record keeps both its payload and its explicit metered size
    // (123) through encode/decode.
    #[test]
    fn roundtrip_encrypted_stored_record() {
        // Payload accepted by `EncryptedRecord::try_from`. It is a leading 0x02 byte,
        // then 12 bytes, then a ciphertext, then 16 bytes. Presumably these are the
        // algorithm tag, nonce, ciphertext and auth tag; the layout is owned by
        // `EncryptedRecord`.
        let mut encoded = BytesMut::with_capacity(1 + 12 + 10 + 16);
        encoded.put_u8(0x02);
        encoded.put_slice(b"0123456789ab");
        encoded.put_slice(b"ciphertext");
        encoded.put_slice(b"0123456789abcdef");
        let record =
            StoredRecord::encrypted(EncryptedRecord::try_from(encoded.freeze()).unwrap(), 123);
        let metered_record = record.clone().metered();
        let encoded_record = encode_stored_record(metered_record.as_ref());
        let decoded_record = decode_stored_record(encoded_record).unwrap();
        assert_eq!(decoded_record, metered_record);
    }

    // Valid magic bytes parse to the expected fields and re-encode to the same byte.
    // Bits 0-2 are the record type; bits 3-4 are `metered_size_varlen - 1`.
    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_0011, MagicByte { record_type: RecordType::EncryptedEnvelope, metered_size_varlen: 1})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn valid_magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }

    // Record-type bits 0b101 and metered-size bits 0b11 are both rejected.
    #[rstest]
    #[case(0b0000_1101, "invalid record type ordinal")]
    #[case(0b0001_1001, "invalid metered_size_varlen")]
    fn invalid_magic_byte_parsing(#[case] as_u8: u8, #[case] expected: &'static str) {
        assert_eq!(MagicByte::try_from(as_u8), Err(expected));
    }

    #[test]
    fn metered_record_truncated_after_magic_byte_returns_error() {
        let truncated = Bytes::from_static(&[0b0000_0010]);
        // Envelope magic byte declaring a 1-byte metered size, but no size byte follows.
        let result = decode_record(truncated);
        assert_eq!(
            result,
            Err(StoredRecordDecodeError::Truncated("MeteredSize"))
        );
    }

    // Pins the exact encoded bytes for representative plaintext records and checks
    // that all three encoding entry points agree and decode back to the input.
    #[rstest]
    #[case::envelope_empty_headers(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap()
        )),
        &[
            // 0x02: magic byte (Envelope, 1-byte metered size)
            // 0x0d: metered size 13 (8 + 5-byte body); the rest is the envelope body
            0x02, 0x0d, 0x00, b'h', b'e', b'l', b'l', b'o',
        ],
    )]
    #[case::envelope_with_header(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(
                vec![Header {
                    name: Bytes::from_static(b"k"),
                    value: Bytes::from_static(b"v"),
                }],
                Bytes::from_static(b"b"),
            ).unwrap()
        )),
        &[
            // 0x02: magic byte (Envelope, 1-byte metered size)
            // 0x0d: metered size 13 (8 + 2 per header + 2 header bytes + 1-byte body)
            0x02, 0x0d, 0x10, 0x01, 0x01, b'k',
            // remaining envelope body bytes
            0x01, b'v',
            b'b',
        ],
    )]
    #[case::command_trim(
        StoredRecord::from(Record::Command(CommandRecord::Trim(42))),
        &[
            // 0x01: magic byte (Command, 1-byte metered size); 0x16: metered size 22
            // then the command body: presumably a one-byte command tag, then the trim
            // seq num 42 as a big-endian u64
            0x01, 0x16, 0x01, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x2a,
        ],
    )]
    fn stored_record_encoding_matches_existing_wire_format(
        #[case] record: StoredRecord,
        #[case] expected: &[u8],
    ) {
        let metered_record = record.clone().metered();
        let encoded_size = stored_record_encoded_size(metered_record.as_ref());
        let encoded = encode_stored_record(metered_record.as_ref());
        let mut encoded_into = BytesMut::with_capacity(encoded_size);
        encode_stored_record_into(metered_record.as_ref(), &mut encoded_into);

        assert_eq!(encoded.len(), encoded_size);
        assert_eq!(encoded.as_ref(), expected);
        assert_eq!(encoded_into.as_ref(), expected);
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    // Encrypted frames are the 0x03 magic byte, then the metered size (123 = 0x7b),
    // then the encrypted payload copied verbatim.
    #[test]
    fn encrypted_stored_record_encoding_matches_existing_wire_format() {
        let encrypted_payload = Bytes::from_static(b"\x020123456789abciphertext0123456789abcdef");
        let record = StoredRecord::encrypted(
            EncryptedRecord::try_from(encrypted_payload.clone()).unwrap(),
            123,
        );

        let encoded = encode_stored_record(record.clone().metered().as_ref());

        assert_eq!(
            encoded.as_ref(),
            [&[0x03, 0x7b], encrypted_payload.as_ref()].concat()
        );
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    // Decoding trusts the wire metered-size prefix instead of recomputing it.
    #[test]
    fn decode_stored_record_preserves_encoded_metered_size_prefix() {
        let record = StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        ));
        let mut encoded = encode_stored_record(record.clone().metered().as_ref()).to_vec();
        // Overwrite the 1-byte metered-size prefix that follows the magic byte.
        encoded[1] = 99;

        let decoded = decode_stored_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    // Same as above, through the plaintext-only `decode_record` entry point.
    #[test]
    fn decode_record_preserves_encoded_metered_size_prefix() {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        let mut encoded =
            encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref()).to_vec();
        // Overwrite the 1-byte metered-size prefix that follows the magic byte.
        encoded[1] = 99;

        let decoded = decode_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    // Big-endian decoding over windows of different widths.
    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
}