Skip to main content

s2_storage/record/
framing.rs

1#[cfg(test)]
2use bytes::BytesMut;
3use bytes::{Buf, BufMut, Bytes};
4use s2_common::{
5    deep_size::DeepSize,
6    record::{CommandRecord, Metered, MeteredSize, Record, SeqNum, Sequenced},
7};
8
9use super::{
10    codec::{StoredRecordDecodeError, WireEncode, decode_command_record, decode_envelope_record},
11    encryption::EncryptedRecord,
12};
13
/// Kind of record stored in a frame.
///
/// The discriminant is the on-wire ordinal: it is what `TryFrom<u8>` accepts
/// and what `From<MagicByte> for u8` writes into the low bits of the magic
/// byte, so the numeric values must not be changed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum RecordType {
    /// Plaintext command record.
    Command = 1,
    /// Plaintext envelope record.
    Envelope = 2,
    /// Encrypted envelope record.
    EncryptedEnvelope = 3,
}
21
22impl TryFrom<u8> for RecordType {
23    type Error = &'static str;
24
25    fn try_from(value: u8) -> Result<Self, Self::Error> {
26        match value {
27            1 => Ok(Self::Command),
28            2 => Ok(Self::Envelope),
29            3 => Ok(Self::EncryptedEnvelope),
30            _ => Err("invalid record type ordinal"),
31        }
32    }
33}
34
35#[derive(Copy, Clone, Debug, PartialEq)]
36struct MagicByte {
37    record_type: RecordType,
38    metered_size_varlen: u8,
39}
40
/// Interprets `bytes` as a big-endian unsigned integer of one to four bytes.
///
/// # Panics
///
/// Panics when `bytes` is empty or longer than a `u32`.
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let len = bytes.len();
    if len == 0 || len > size_of::<u32>() {
        panic!("invalid variable int bytes = {} len", len)
    }
    // Most significant byte comes first: shift what we have and append the next byte.
    bytes
        .iter()
        .fold(0u32, |value, &byte| (value << 8) | u32::from(byte))
}
52
53pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
54    let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
55    let magic_byte = MagicByte::try_from(magic_byte_u8)?;
56    Ok(read_vint_u32_be(
57        record_bytes
58            .get(1..1 + magic_byte.metered_size_varlen as usize)
59            .ok_or("byte range doesn't include bytes for metered size")?,
60    ))
61}
62
63impl TryFrom<u8> for MagicByte {
64    type Error = &'static str;
65
66    fn try_from(value: u8) -> Result<Self, Self::Error> {
67        let record_type = RecordType::try_from(value & 0b111)?;
68        Ok(Self {
69            record_type,
70            metered_size_varlen: match (value >> 3) & 0b11 {
71                0 => 1u8,
72                1 => 2u8,
73                2 => 3u8,
74                _ => Err("invalid metered_size_varlen")?,
75            },
76        })
77    }
78}
79
80impl From<MagicByte> for u8 {
81    fn from(value: MagicByte) -> Self {
82        ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
83    }
84}
85
86#[derive(Debug, PartialEq, Eq, Clone)]
87pub enum StoredRecord {
88    Plaintext(Record),
89    /// Encrypted envelope record bytes plus the logical plaintext metered size.
90    ///
91    /// The stored `metered_size` must match the decrypted envelope record's
92    /// metered size. Decoding preserves the encoded prefix, and decryption
93    /// validates it before returning a logical record.
94    Encrypted {
95        metered_size: usize,
96        record: EncryptedRecord,
97    },
98}
99
100impl StoredRecord {
101    pub(crate) fn encrypted(record: EncryptedRecord, metered_size: usize) -> Self {
102        Self::Encrypted {
103            metered_size,
104            record,
105        }
106    }
107
108    fn record_type(&self) -> RecordType {
109        match self {
110            Self::Plaintext(Record::Command(_)) => RecordType::Command,
111            Self::Plaintext(Record::Envelope(_)) => RecordType::Envelope,
112            Self::Encrypted { .. } => RecordType::EncryptedEnvelope,
113        }
114    }
115
116    fn encoded_body_size(&self) -> usize {
117        match self {
118            Self::Plaintext(Record::Command(record)) => record.encoded_size(),
119            Self::Plaintext(Record::Envelope(record)) => record.encoded_size(),
120            Self::Encrypted { record, .. } => record.encoded_size(),
121        }
122    }
123
124    fn encode_body_into(&self, buf: &mut impl BufMut) {
125        match self {
126            Self::Plaintext(Record::Command(record)) => record.encode_into(buf),
127            Self::Plaintext(Record::Envelope(record)) => record.encode_into(buf),
128            Self::Encrypted { record, .. } => record.encode_into(buf),
129        }
130    }
131
132    pub fn max_assignable_seq_num(&self) -> SeqNum {
133        match self {
134            Self::Plaintext(_) => SeqNum::MAX,
135            Self::Encrypted { record, .. } => record.max_assignable_seq_num(),
136        }
137    }
138}
139
140impl DeepSize for StoredRecord {
141    fn deep_size(&self) -> usize {
142        match self {
143            Self::Plaintext(record) => record.deep_size(),
144            Self::Encrypted {
145                metered_size,
146                record,
147            } => metered_size.deep_size() + record.deep_size(),
148        }
149    }
150}
151
152impl MeteredSize for StoredRecord {
153    fn metered_size(&self) -> usize {
154        match self {
155            Self::Plaintext(record) => record.metered_size(),
156            Self::Encrypted { metered_size, .. } => *metered_size,
157        }
158    }
159}
160
161impl From<Record> for StoredRecord {
162    fn from(value: Record) -> Self {
163        Self::Plaintext(value)
164    }
165}
166
167pub fn decode_if_command_record(
168    record: &[u8],
169) -> Result<Option<CommandRecord>, StoredRecordDecodeError> {
170    if record.is_empty() {
171        return Err(StoredRecordDecodeError::Truncated("MagicByte"));
172    }
173    let magic_byte = MagicByte::try_from(record[0])
174        .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
175    match magic_byte.record_type {
176        RecordType::Command => {
177            let offset = 1 + magic_byte.metered_size_varlen as usize;
178            if record.len() < offset {
179                return Err(StoredRecordDecodeError::Truncated("MeteredSize"));
180            }
181            Ok(Some(decode_command_record(&record[offset..])?))
182        }
183        RecordType::Envelope | RecordType::EncryptedEnvelope => Ok(None),
184    }
185}
186
187pub fn encode_stored_record(record: Metered<&StoredRecord>) -> Bytes {
188    record.to_bytes()
189}
190
191pub fn stored_record_encoded_size(record: Metered<&StoredRecord>) -> usize {
192    record.encoded_size()
193}
194
195pub fn encode_stored_record_into(record: Metered<&StoredRecord>, buf: &mut impl BufMut) {
196    record.encode_into(buf);
197}
198
199impl WireEncode for Metered<&StoredRecord> {
200    fn encoded_size(&self) -> usize {
201        1 + magic_byte(self).metered_size_varlen as usize + self.encoded_body_size()
202    }
203
204    fn encode_into(&self, buf: &mut impl BufMut) {
205        let magic_byte = magic_byte(self);
206        buf.put_u8(magic_byte.into());
207        buf.put_uint(
208            self.metered_size() as u64,
209            magic_byte.metered_size_varlen as usize,
210        );
211        self.encode_body_into(buf);
212    }
213}
214
215fn magic_byte(record: &Metered<&StoredRecord>) -> MagicByte {
216    let metered_size = record.metered_size();
217    let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
218    if metered_size_varlen > 3 {
219        panic!("illegal metered size varlen {metered_size} for record")
220    }
221    MagicByte {
222        record_type: record.record_type(),
223        metered_size_varlen,
224    }
225}
226
227pub type StoredSequencedBytes = Sequenced<Bytes>;
228pub type StoredSequencedRecord = Sequenced<StoredRecord>;
229
230pub fn decode_stored_record(
231    mut buf: Bytes,
232) -> Result<Metered<StoredRecord>, StoredRecordDecodeError> {
233    if buf.is_empty() {
234        return Err(StoredRecordDecodeError::Truncated("MagicByte"));
235    }
236    let magic_byte = MagicByte::try_from(buf.get_u8())
237        .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
238
239    let metered_size =
240        buf.try_get_uint(magic_byte.metered_size_varlen as usize)
241            .map_err(|_| StoredRecordDecodeError::Truncated("MeteredSize"))? as usize;
242
243    let record = match magic_byte.record_type {
244        RecordType::Command => {
245            StoredRecord::Plaintext(Record::Command(decode_command_record(buf.as_ref())?))
246        }
247        RecordType::Envelope => {
248            StoredRecord::Plaintext(Record::Envelope(decode_envelope_record(buf)?))
249        }
250        RecordType::EncryptedEnvelope => {
251            StoredRecord::encrypted(EncryptedRecord::try_from(buf)?, metered_size)
252        }
253    };
254    Ok(Metered::with_size(metered_size, record))
255}
256
257pub fn decode_record(buf: Bytes) -> Result<Metered<Record>, StoredRecordDecodeError> {
258    let stored = decode_stored_record(buf)?;
259    let metered_size = stored.metered_size();
260    match stored.into_inner() {
261        StoredRecord::Plaintext(record) => Ok(record),
262        StoredRecord::Encrypted { .. } => Err(StoredRecordDecodeError::InvalidValue(
263            "RecordType",
264            "encrypted envelope requires decryption",
265        )),
266    }
267    .map(|record| Metered::with_size(metered_size, record))
268}
269
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use rstest::rstest;
    use s2_common::record::{
        EnvelopeRecord, Header, MAX_FENCING_TOKEN_LENGTH, MeteredExt, StreamPosition, Timestamp,
    };

    use super::*;

    /// Stand-alone encoder that frames a bare `Record` without going through
    /// `StoredRecord`. The round-trip tests compare its output byte for byte
    /// with `encode_stored_record`.
    struct LegacyPlaintextFrame<'a> {
        record: &'a Record,
    }

    impl LegacyPlaintextFrame<'_> {
        fn magic_byte(&self) -> MagicByte {
            let metered_size = self.record.metered_size();
            let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
            assert!(metered_size_varlen <= 3);

            MagicByte {
                record_type: match self.record {
                    Record::Command(_) => RecordType::Command,
                    Record::Envelope(_) => RecordType::Envelope,
                },
                metered_size_varlen,
            }
        }
    }

    impl WireEncode for LegacyPlaintextFrame<'_> {
        fn encoded_size(&self) -> usize {
            let body_size = match self.record {
                Record::Command(record) => record.encoded_size(),
                Record::Envelope(record) => record.encoded_size(),
            };
            1 + self.magic_byte().metered_size_varlen as usize + body_size
        }

        fn encode_into(&self, buf: &mut impl BufMut) {
            let magic_byte = self.magic_byte();
            buf.put_u8(magic_byte.into());
            buf.put_uint(
                self.record.metered_size() as u64,
                magic_byte.metered_size_varlen as usize,
            );
            match self.record {
                Record::Command(record) => record.encode_into(buf),
                Record::Envelope(record) => record.encode_into(buf),
            }
        }
    }

    fn legacy_plaintext_bytes(record: &Record) -> Bytes {
        LegacyPlaintextFrame { record }.to_bytes()
    }

    /// Recomputes the metered size from the record's parts: 8 bytes, plus 2
    /// per header, plus every header name and value length, plus the body
    /// length.
    fn semantic_metered_size(record: &Record) -> usize {
        let (headers, body) = record.clone().into_parts();
        8 + (2 * headers.len())
            + headers
                .iter()
                .map(|header| header.name.len() + header.value.len())
                .sum::<usize>()
            + body.len()
    }

    /// Generates either a short (under 10 bytes) or a long (100..1000 bytes)
    /// byte string.
    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
        prop_oneof![
            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
                .prop_map(Bytes::from),
            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
        ]
    }

    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }

    /// Generates either a handful of headers or a few hundred of them.
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }

    fn command_strategy() -> impl Strategy<Value = CommandRecord> {
        prop_oneof![
            proptest::string::string_regex(&format!("[ -~]{{0,{MAX_FENCING_TOKEN_LENGTH}}}"))
                .unwrap()
                .prop_map(|token| CommandRecord::Fence(token.parse().unwrap())),
            any::<SeqNum>().prop_map(CommandRecord::Trim),
        ]
    }

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_envelope(
            seq_num in any::<SeqNum>(),
            timestamp in any::<Timestamp>(),
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let metered_record: Metered<Record> = record.clone().into();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let legacy_record = legacy_plaintext_bytes(&record);
            prop_assert_eq!(encoded_record.as_ref(), legacy_record.as_ref());
            let decoded_record = decode_record(encoded_record).unwrap();
            prop_assert_eq!(&decoded_record, &metered_record);
            let sequenced = decoded_record.sequenced(StreamPosition { seq_num, timestamp });
            let (position, sequenced_record) = sequenced.into_parts();
            prop_assert_eq!(position, StreamPosition { seq_num, timestamp });
            prop_assert_eq!(sequenced_record.into_inner(), record);
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_metered(
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            prop_assert_eq!(record.metered_size(), semantic_metered_size(&record));
            prop_assert_eq!(
                record.metered_size(),
                try_metered_size(encoded_record.as_ref()).unwrap() as usize
            );
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_command_metered(command in command_strategy()) {
            let record = Record::Command(command);
            let encoded_record =
                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
            let expected_metered = semantic_metered_size(&record);
            let wire_metered = try_metered_size(encoded_record.as_ref()).unwrap() as usize;
            let decoded_record = decode_record(encoded_record).unwrap();

            prop_assert_eq!(record.metered_size(), expected_metered);
            prop_assert_eq!(record.metered_size(), wire_metered);
            prop_assert_eq!(decoded_record, Metered::<Record>::from(record));
        }
    );

    #[test]
    fn roundtrip_encrypted_stored_record() {
        let mut encoded = BytesMut::with_capacity(1 + 12 + 10 + 16);
        encoded.put_u8(0x02);
        encoded.put_slice(b"0123456789ab");
        encoded.put_slice(b"ciphertext");
        encoded.put_slice(b"0123456789abcdef");
        let record =
            StoredRecord::encrypted(EncryptedRecord::try_from(encoded.freeze()).unwrap(), 123);
        let metered_record = record.clone().metered();
        let encoded_record = encode_stored_record(metered_record.as_ref());
        let decoded_record = decode_stored_record(encoded_record).unwrap();
        assert_eq!(decoded_record, metered_record);
    }

    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_0011, MagicByte { record_type: RecordType::EncryptedEnvelope, metered_size_varlen: 1})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn valid_magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }

    #[rstest]
    #[case(0b0000_1101, "invalid record type ordinal")]
    #[case(0b0001_1001, "invalid metered_size_varlen")]
    fn invalid_magic_byte_parsing(#[case] as_u8: u8, #[case] expected: &'static str) {
        assert_eq!(MagicByte::try_from(as_u8), Err(expected));
    }

    #[test]
    fn metered_record_truncated_after_magic_byte_returns_error() {
        // Magic byte: Envelope (0b0000_0010), metered_size_varlen = 1 -> expects 1 more byte.
        let truncated = Bytes::from_static(&[0b0000_0010]);
        let result = decode_record(truncated);
        assert_eq!(
            result,
            Err(StoredRecordDecodeError::Truncated("MeteredSize"))
        );
    }

    #[rstest]
    #[case::envelope_empty_headers(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap()
        )),
        &[
            0x02, 0x0d, // envelope record, metered size 13
            0x00, // no headers
            b'h', b'e', b'l', b'l', b'o',
        ],
    )]
    #[case::envelope_with_header(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(
                vec![Header {
                    name: Bytes::from_static(b"k"),
                    value: Bytes::from_static(b"v"),
                }],
                Bytes::from_static(b"b"),
            ).unwrap()
        )),
        &[
            0x02, 0x0d, // envelope record, metered size 13
            0x10, 0x01, // one header, one byte for num headers
            0x01, b'k',
            0x01, b'v',
            b'b',
        ],
    )]
    #[case::command_trim(
        StoredRecord::from(Record::Command(CommandRecord::Trim(42))),
        &[
            0x01, 0x16, // command record, metered size 22
            0x01, // trim command ordinal
            0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x2a,
        ],
    )]
    fn stored_record_encoding_matches_existing_wire_format(
        #[case] record: StoredRecord,
        #[case] expected: &[u8],
    ) {
        let metered_record = record.clone().metered();
        let encoded_size = stored_record_encoded_size(metered_record.as_ref());
        let encoded = encode_stored_record(metered_record.as_ref());
        let mut encoded_into = BytesMut::with_capacity(encoded_size);
        encode_stored_record_into(metered_record.as_ref(), &mut encoded_into);

        assert_eq!(encoded.len(), encoded_size);
        assert_eq!(encoded.as_ref(), expected);
        assert_eq!(encoded_into.as_ref(), expected);
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    #[test]
    fn encrypted_stored_record_encoding_matches_existing_wire_format() {
        let encrypted_payload = Bytes::from_static(b"\x020123456789abciphertext0123456789abcdef");
        let record = StoredRecord::encrypted(
            EncryptedRecord::try_from(encrypted_payload.clone()).unwrap(),
            123,
        );

        let encoded = encode_stored_record(record.clone().metered().as_ref());

        assert_eq!(
            encoded.as_ref(),
            [&[0x03, 0x7b], encrypted_payload.as_ref()].concat()
        );
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }

    #[test]
    fn decode_stored_record_preserves_encoded_metered_size_prefix() {
        let record = StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        ));
        let mut encoded = encode_stored_record(record.clone().metered().as_ref()).to_vec();
        // Overwrite the one-byte metered-size field with a value that does not match the body.
        encoded[1] = 99;

        let decoded = decode_stored_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    #[test]
    fn decode_record_preserves_encoded_metered_size_prefix() {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        let mut encoded =
            encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref()).to_vec();
        // Overwrite the one-byte metered-size field with a value that does not match the body.
        encoded[1] = 99;

        let decoded = decode_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }

    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
}
574}