1#[cfg(test)]
2use bytes::BytesMut;
3use bytes::{Buf, BufMut, Bytes};
4use s2_common::{
5 deep_size::DeepSize,
6 record::{CommandRecord, Metered, MeteredSize, Record, SeqNum, Sequenced},
7};
8
9use super::{
10 codec::{StoredRecordDecodeError, WireEncode, decode_command_record, decode_envelope_record},
11 encryption::EncryptedRecord,
12};
13
/// Kind of record that follows a magic byte.
///
/// The explicit discriminants are the ordinals produced by `as u8` casts, so
/// renumbering a variant changes the encoded representation.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
enum RecordType {
    /// Ordinal `1`.
    Command = 1,
    /// Ordinal `2`.
    Envelope = 2,
    /// Ordinal `3`.
    EncryptedEnvelope = 3,
}
21
22impl TryFrom<u8> for RecordType {
23 type Error = &'static str;
24
25 fn try_from(value: u8) -> Result<Self, Self::Error> {
26 match value {
27 1 => Ok(Self::Command),
28 2 => Ok(Self::Envelope),
29 3 => Ok(Self::EncryptedEnvelope),
30 _ => Err("invalid record type ordinal"),
31 }
32 }
33}
34
35#[derive(Copy, Clone, Debug, PartialEq)]
36struct MagicByte {
37 record_type: RecordType,
38 metered_size_varlen: u8,
39}
40
/// Interprets `bytes` as a big-endian unsigned integer of 1 to 4 bytes.
///
/// # Panics
///
/// Panics if `bytes` is empty or longer than a `u32`.
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let len = bytes.len();
    if len == 0 || len > size_of::<u32>() {
        panic!("invalid variable int bytes = {} len", len)
    }
    // Most significant byte first: shift the accumulator and append each byte.
    bytes
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
52
53pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
54 let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
55 let magic_byte = MagicByte::try_from(magic_byte_u8)?;
56 Ok(read_vint_u32_be(
57 record_bytes
58 .get(1..1 + magic_byte.metered_size_varlen as usize)
59 .ok_or("byte range doesn't include bytes for metered size")?,
60 ))
61}
62
63impl TryFrom<u8> for MagicByte {
64 type Error = &'static str;
65
66 fn try_from(value: u8) -> Result<Self, Self::Error> {
67 let record_type = RecordType::try_from(value & 0b111)?;
68 Ok(Self {
69 record_type,
70 metered_size_varlen: match (value >> 3) & 0b11 {
71 0 => 1u8,
72 1 => 2u8,
73 2 => 3u8,
74 _ => Err("invalid metered_size_varlen")?,
75 },
76 })
77 }
78}
79
80impl From<MagicByte> for u8 {
81 fn from(value: MagicByte) -> Self {
82 ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
83 }
84}
85
86#[derive(Debug, PartialEq, Eq, Clone)]
87pub enum StoredRecord {
88 Plaintext(Record),
89 Encrypted {
95 metered_size: usize,
96 record: EncryptedRecord,
97 },
98}
99
100impl StoredRecord {
101 pub(crate) fn encrypted(record: EncryptedRecord, metered_size: usize) -> Self {
102 Self::Encrypted {
103 metered_size,
104 record,
105 }
106 }
107
108 fn record_type(&self) -> RecordType {
109 match self {
110 Self::Plaintext(Record::Command(_)) => RecordType::Command,
111 Self::Plaintext(Record::Envelope(_)) => RecordType::Envelope,
112 Self::Encrypted { .. } => RecordType::EncryptedEnvelope,
113 }
114 }
115
116 fn encoded_body_size(&self) -> usize {
117 match self {
118 Self::Plaintext(Record::Command(record)) => record.encoded_size(),
119 Self::Plaintext(Record::Envelope(record)) => record.encoded_size(),
120 Self::Encrypted { record, .. } => record.encoded_size(),
121 }
122 }
123
124 fn encode_body_into(&self, buf: &mut impl BufMut) {
125 match self {
126 Self::Plaintext(Record::Command(record)) => record.encode_into(buf),
127 Self::Plaintext(Record::Envelope(record)) => record.encode_into(buf),
128 Self::Encrypted { record, .. } => record.encode_into(buf),
129 }
130 }
131
132 pub fn max_assignable_seq_num(&self) -> SeqNum {
133 match self {
134 Self::Plaintext(_) => SeqNum::MAX,
135 Self::Encrypted { record, .. } => record.max_assignable_seq_num(),
136 }
137 }
138}
139
140impl DeepSize for StoredRecord {
141 fn deep_size(&self) -> usize {
142 match self {
143 Self::Plaintext(record) => record.deep_size(),
144 Self::Encrypted {
145 metered_size,
146 record,
147 } => metered_size.deep_size() + record.deep_size(),
148 }
149 }
150}
151
152impl MeteredSize for StoredRecord {
153 fn metered_size(&self) -> usize {
154 match self {
155 Self::Plaintext(record) => record.metered_size(),
156 Self::Encrypted { metered_size, .. } => *metered_size,
157 }
158 }
159}
160
161impl From<Record> for StoredRecord {
162 fn from(value: Record) -> Self {
163 Self::Plaintext(value)
164 }
165}
166
167pub fn decode_if_command_record(
168 record: &[u8],
169) -> Result<Option<CommandRecord>, StoredRecordDecodeError> {
170 if record.is_empty() {
171 return Err(StoredRecordDecodeError::Truncated("MagicByte"));
172 }
173 let magic_byte = MagicByte::try_from(record[0])
174 .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
175 match magic_byte.record_type {
176 RecordType::Command => {
177 let offset = 1 + magic_byte.metered_size_varlen as usize;
178 if record.len() < offset {
179 return Err(StoredRecordDecodeError::Truncated("MeteredSize"));
180 }
181 Ok(Some(decode_command_record(&record[offset..])?))
182 }
183 RecordType::Envelope | RecordType::EncryptedEnvelope => Ok(None),
184 }
185}
186
187pub fn encode_stored_record(record: Metered<&StoredRecord>) -> Bytes {
188 record.to_bytes()
189}
190
191pub fn stored_record_encoded_size(record: Metered<&StoredRecord>) -> usize {
192 record.encoded_size()
193}
194
195pub fn encode_stored_record_into(record: Metered<&StoredRecord>, buf: &mut impl BufMut) {
196 record.encode_into(buf);
197}
198
199impl WireEncode for Metered<&StoredRecord> {
200 fn encoded_size(&self) -> usize {
201 1 + magic_byte(self).metered_size_varlen as usize + self.encoded_body_size()
202 }
203
204 fn encode_into(&self, buf: &mut impl BufMut) {
205 let magic_byte = magic_byte(self);
206 buf.put_u8(magic_byte.into());
207 buf.put_uint(
208 self.metered_size() as u64,
209 magic_byte.metered_size_varlen as usize,
210 );
211 self.encode_body_into(buf);
212 }
213}
214
215fn magic_byte(record: &Metered<&StoredRecord>) -> MagicByte {
216 let metered_size = record.metered_size();
217 let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
218 if metered_size_varlen > 3 {
219 panic!("illegal metered size varlen {metered_size} for record")
220 }
221 MagicByte {
222 record_type: record.record_type(),
223 metered_size_varlen,
224 }
225}
226
227pub type StoredSequencedBytes = Sequenced<Bytes>;
228pub type StoredSequencedRecord = Sequenced<StoredRecord>;
229
230pub fn decode_stored_record(
231 mut buf: Bytes,
232) -> Result<Metered<StoredRecord>, StoredRecordDecodeError> {
233 if buf.is_empty() {
234 return Err(StoredRecordDecodeError::Truncated("MagicByte"));
235 }
236 let magic_byte = MagicByte::try_from(buf.get_u8())
237 .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
238
239 let metered_size =
240 buf.try_get_uint(magic_byte.metered_size_varlen as usize)
241 .map_err(|_| StoredRecordDecodeError::Truncated("MeteredSize"))? as usize;
242
243 let record = match magic_byte.record_type {
244 RecordType::Command => {
245 StoredRecord::Plaintext(Record::Command(decode_command_record(buf.as_ref())?))
246 }
247 RecordType::Envelope => {
248 StoredRecord::Plaintext(Record::Envelope(decode_envelope_record(buf)?))
249 }
250 RecordType::EncryptedEnvelope => {
251 StoredRecord::encrypted(EncryptedRecord::try_from(buf)?, metered_size)
252 }
253 };
254 Ok(Metered::with_size(metered_size, record))
255}
256
257pub fn decode_record(buf: Bytes) -> Result<Metered<Record>, StoredRecordDecodeError> {
258 let stored = decode_stored_record(buf)?;
259 let metered_size = stored.metered_size();
260 match stored.into_inner() {
261 StoredRecord::Plaintext(record) => Ok(record),
262 StoredRecord::Encrypted { .. } => Err(StoredRecordDecodeError::InvalidValue(
263 "RecordType",
264 "encrypted envelope requires decryption",
265 )),
266 }
267 .map(|record| Metered::with_size(metered_size, record))
268}
269
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use rstest::rstest;
    use s2_common::record::{
        EnvelopeRecord, Header, MAX_FENCING_TOKEN_LENGTH, MeteredExt, StreamPosition, Timestamp,
    };

    use super::*;

    /// Independent encoder for plaintext records, used as a reference to
    /// check that `StoredRecord` encoding produces the same bytes.
    struct LegacyPlaintextFrame<'a> {
        record: &'a Record,
    }

    impl LegacyPlaintextFrame<'_> {
        /// Magic byte for the wrapped record, derived from its metered size.
        fn magic_byte(&self) -> MagicByte {
            let metered_size = self.record.metered_size();
            let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
            assert!(metered_size_varlen <= 3);

            MagicByte {
                record_type: match self.record {
                    Record::Command(_) => RecordType::Command,
                    Record::Envelope(_) => RecordType::Envelope,
                },
                metered_size_varlen,
            }
        }
    }

    impl WireEncode for LegacyPlaintextFrame<'_> {
        fn encoded_size(&self) -> usize {
            let body_size = match self.record {
                Record::Command(record) => record.encoded_size(),
                Record::Envelope(record) => record.encoded_size(),
            };
            1 + self.magic_byte().metered_size_varlen as usize + body_size
        }

        fn encode_into(&self, buf: &mut impl BufMut) {
            let magic_byte = self.magic_byte();
            buf.put_u8(magic_byte.into());
            buf.put_uint(
                self.record.metered_size() as u64,
                magic_byte.metered_size_varlen as usize,
            );
            match self.record {
                Record::Command(record) => record.encode_into(buf),
                Record::Envelope(record) => record.encode_into(buf),
            }
        }
    }

    /// Encodes `record` with the reference plaintext encoder.
    fn legacy_plaintext_bytes(record: &Record) -> Bytes {
        LegacyPlaintextFrame { record }.to_bytes()
    }

    /// Metered size recomputed from the record's parts: a fixed 8, plus 2 per
    /// header, plus the header name/value lengths, plus the body length.
    fn semantic_metered_size(record: &Record) -> usize {
        let (headers, body) = record.clone().into_parts();
        8 + (2 * headers.len())
            + headers
                .iter()
                .map(|header| header.name.len() + header.value.len())
                .sum::<usize>()
            + body.len()
    }

    /// Generates either a short (under 10 bytes) or a long (100 to 999 bytes)
    /// byte string; empty strings are produced only when `allow_empty` is set.
    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
        prop_oneof![
            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
                .prop_map(Bytes::from),
            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
        ]
    }

    /// Generates a header with a non-empty name and a possibly empty value.
    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }

    /// Generates either a small (under 10) or a large (200 to 299) header list.
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }

    /// Generates a fence command with a printable-ASCII token, or a trim
    /// command with an arbitrary sequence number.
    fn command_strategy() -> impl Strategy<Value = CommandRecord> {
        prop_oneof![
            proptest::string::string_regex(&format!("[ -~]{{0,{MAX_FENCING_TOKEN_LENGTH}}}"))
                .unwrap()
                .prop_map(|token| CommandRecord::Fence(token.parse().unwrap())),
            any::<SeqNum>().prop_map(CommandRecord::Trim),
        ]
    }

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_envelope(
            seq_num in any::<SeqNum>(),
            timestamp in any::<Timestamp>(),
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let metered_record: Metered<Record> = record.clone().into();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let legacy_record = legacy_plaintext_bytes(&record);
            prop_assert_eq!(encoded_record.as_ref(), legacy_record.as_ref());
            let decoded_record = decode_record(encoded_record).unwrap();
            prop_assert_eq!(&decoded_record, &metered_record);
            let sequenced = decoded_record.sequenced(StreamPosition { seq_num, timestamp });
            let (position, sequenced_record) = sequenced.into_parts();
            prop_assert_eq!(position, StreamPosition { seq_num, timestamp });
            prop_assert_eq!(sequenced_record.into_inner(), record);
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_metered(
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            prop_assert_eq!(record.metered_size(), semantic_metered_size(&record));
            prop_assert_eq!(
                record.metered_size(),
                try_metered_size(encoded_record.as_ref()).unwrap() as usize
            );
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_command_metered(command in command_strategy()) {
            let record = Record::Command(command);
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let expected_metered = semantic_metered_size(&record);
            let wire_metered = try_metered_size(encoded_record.as_ref()).unwrap() as usize;
            let decoded_record = decode_record(encoded_record).unwrap();

            prop_assert_eq!(record.metered_size(), expected_metered);
            prop_assert_eq!(record.metered_size(), wire_metered);
            prop_assert_eq!(decoded_record, Metered::<Record>::from(record));
        }
    );

    #[test]
    fn roundtrip_encrypted_stored_record() {
        let mut encoded = BytesMut::with_capacity(1 + 12 + 10 + 16);
        encoded.put_u8(0x02);
        encoded.put_slice(b"0123456789ab");
        encoded.put_slice(b"ciphertext");
        encoded.put_slice(b"0123456789abcdef");
        let record =
            StoredRecord::encrypted(EncryptedRecord::try_from(encoded.freeze()).unwrap(), 123);
        let metered_record = record.clone().metered();
        let encoded_record = encode_stored_record(metered_record.as_ref());
        let decoded_record = decode_stored_record(encoded_record).unwrap();
        assert_eq!(decoded_record, metered_record);
    }

    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_0011, MagicByte { record_type: RecordType::EncryptedEnvelope, metered_size_varlen: 1})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn valid_magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }

    #[rstest]
    #[case(0b0000_1101, "invalid record type ordinal")]
    #[case(0b0001_1001, "invalid metered_size_varlen")]
    fn invalid_magic_byte_parsing(#[case] as_u8: u8, #[case] expected: &'static str) {
        assert_eq!(MagicByte::try_from(as_u8), Err(expected));
    }

    #[test]
    fn metered_record_truncated_after_magic_byte_returns_error() {
        // Only the magic byte is present; the metered-size prefix is missing.
        let truncated = Bytes::from_static(&[0b0000_0010]);
        let result = decode_record(truncated);
        assert_eq!(
            result,
            Err(StoredRecordDecodeError::Truncated("MeteredSize"))
        );
    }

    // In each expected byte string below, the first byte is the magic byte,
    // the second is the one-byte metered size, and the rest is the body.
    #[rstest]
    #[case::envelope_empty_headers(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap()
        )),
        &[
            0x02, 0x0d, 0x00, b'h', b'e', b'l', b'l', b'o',
        ],
    )]
    #[case::envelope_with_header(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(
                vec![Header {
                    name: Bytes::from_static(b"k"),
                    value: Bytes::from_static(b"v"),
                }],
                Bytes::from_static(b"b"),
            ).unwrap()
        )),
        &[
            0x02, 0x0d, 0x10, 0x01, 0x01, b'k',
            0x01, b'v',
            b'b',
        ],
    )]
    #[case::command_trim(
        StoredRecord::from(Record::Command(CommandRecord::Trim(42))),
        &[
            0x01, 0x16, 0x01, 0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x2a,
        ],
    )]
    fn stored_record_encoding_matches_existing_wire_format(
        #[case] record: StoredRecord,
        #[case] expected: &[u8],
    ) {
        let metered_record = record.clone().metered();
        let encoded_size = stored_record_encoded_size(metered_record.as_ref());
        let encoded = encode_stored_record(metered_record.as_ref());
        let mut encoded_into = BytesMut::with_capacity(encoded_size);
        encode_stored_record_into(metered_record.as_ref(), &mut encoded_into);

        assert_eq!(encoded.len(), encoded_size);
        assert_eq!(encoded.as_ref(), expected);
        assert_eq!(encoded_into.as_ref(), expected);
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    #[test]
    fn encrypted_stored_record_encoding_matches_existing_wire_format() {
        let encrypted_payload = Bytes::from_static(b"\x020123456789abciphertext0123456789abcdef");
        let record = StoredRecord::encrypted(
            EncryptedRecord::try_from(encrypted_payload.clone()).unwrap(),
            123,
        );

        let encoded = encode_stored_record(record.clone().metered().as_ref());

        // 0x03 is the encrypted-envelope magic byte with a one-byte prefix;
        // 0x7b is the metered size 123.
        assert_eq!(
            encoded.as_ref(),
            [&[0x03, 0x7b], encrypted_payload.as_ref()].concat()
        );
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    #[test]
    fn decode_stored_record_preserves_encoded_metered_size_prefix() {
        let record = StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        ));
        let mut encoded = encode_stored_record(record.clone().metered().as_ref()).to_vec();
        // Overwrite the metered-size prefix; decoding must report this value.
        encoded[1] = 99;

        let decoded = decode_stored_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    #[test]
    fn decode_record_preserves_encoded_metered_size_prefix() {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        let mut encoded =
            encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref()).to_vec();
        // Overwrite the metered-size prefix; decoding must report this value.
        encoded[1] = 99;

        let decoded = decode_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
}