Skip to main content

s2_storage/record/
framing.rs

1#[cfg(test)]
2use bytes::BytesMut;
3use bytes::{Buf, BufMut, Bytes};
4use s2_common::{
5    deep_size::DeepSize,
6    encryption::EncryptionAlgorithm,
7    record::{CommandRecord, Metered, MeteredSize, Record, SeqNum, Sequenced},
8};
9
10use super::{
11    codec::{StoredRecordDecodeError, WireEncode, decode_command_record, decode_envelope_record},
12    encryption::EncryptedRecord,
13};
14
/// Kind of record body carried by a stored frame, identified on the wire by
/// its `u8` ordinal.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
enum RecordType {
    Command = 1,
    Envelope = 2,
    EncryptedEnvelope = 3,
}

impl TryFrom<u8> for RecordType {
    type Error = &'static str;

    /// Maps a wire ordinal back to its [`RecordType`].
    ///
    /// Any value other than 1, 2 or 3 is rejected with a static message.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        const KNOWN: [RecordType; 3] = [
            RecordType::Command,
            RecordType::Envelope,
            RecordType::EncryptedEnvelope,
        ];
        KNOWN
            .iter()
            .copied()
            .find(|kind| *kind as u8 == value)
            .ok_or("invalid record type ordinal")
    }
}
35
/// Decoded form of the single header byte written at the start of an encoded
/// stored record.
#[derive(Copy, Clone, Debug, PartialEq)]
struct MagicByte {
    /// Kind of record body that follows the metered-size prefix.
    record_type: RecordType,
    /// Number of bytes used to encode the metered size; parsing only ever
    /// yields 1, 2 or 3.
    metered_size_varlen: u8,
}
41
/// Interprets `bytes` as a big-endian unsigned integer of one to four bytes.
///
/// # Panics
///
/// Panics if `bytes` is empty or longer than a `u32`.
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let width = bytes.len();
    if width == 0 || width > size_of::<u32>() {
        panic!("invalid variable int bytes = {} len", bytes.len())
    }
    // Most significant byte first: shift the running value up and append.
    bytes
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
53
54pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
55    let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
56    let magic_byte = MagicByte::try_from(magic_byte_u8)?;
57    Ok(read_vint_u32_be(
58        record_bytes
59            .get(1..1 + magic_byte.metered_size_varlen as usize)
60            .ok_or("byte range doesn't include bytes for metered size")?,
61    ))
62}
63
64impl TryFrom<u8> for MagicByte {
65    type Error = &'static str;
66
67    fn try_from(value: u8) -> Result<Self, Self::Error> {
68        let record_type = RecordType::try_from(value & 0b111)?;
69        Ok(Self {
70            record_type,
71            metered_size_varlen: match (value >> 3) & 0b11 {
72                0 => 1u8,
73                1 => 2u8,
74                2 => 3u8,
75                _ => Err("invalid metered_size_varlen")?,
76            },
77        })
78    }
79}
80
81impl From<MagicByte> for u8 {
82    fn from(value: MagicByte) -> Self {
83        ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
84    }
85}
86
/// A record in the form it is framed for storage: either a plaintext
/// [`Record`] or an encrypted envelope.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum StoredRecord {
    /// A record stored without encryption.
    Plaintext(Record),
    /// Encrypted envelope record bytes plus the logical plaintext metered size.
    ///
    /// The stored `metered_size` must match the decrypted envelope record's
    /// metered size. Decoding preserves the encoded prefix, and decryption
    /// validates it before returning a logical record.
    Encrypted {
        metered_size: usize,
        record: EncryptedRecord,
    },
}
100
101impl StoredRecord {
102    pub(crate) fn encrypted(record: EncryptedRecord, metered_size: usize) -> Self {
103        Self::Encrypted {
104            metered_size,
105            record,
106        }
107    }
108
109    fn record_type(&self) -> RecordType {
110        match self {
111            Self::Plaintext(Record::Command(_)) => RecordType::Command,
112            Self::Plaintext(Record::Envelope(_)) => RecordType::Envelope,
113            Self::Encrypted { .. } => RecordType::EncryptedEnvelope,
114        }
115    }
116
117    fn encoded_body_size(&self) -> usize {
118        match self {
119            Self::Plaintext(Record::Command(record)) => record.encoded_size(),
120            Self::Plaintext(Record::Envelope(record)) => record.encoded_size(),
121            Self::Encrypted { record, .. } => record.encoded_size(),
122        }
123    }
124
125    fn encode_body_into(&self, buf: &mut impl BufMut) {
126        match self {
127            Self::Plaintext(Record::Command(record)) => record.encode_into(buf),
128            Self::Plaintext(Record::Envelope(record)) => record.encode_into(buf),
129            Self::Encrypted { record, .. } => record.encode_into(buf),
130        }
131    }
132
133    pub fn encryption_algorithm(&self) -> Option<EncryptionAlgorithm> {
134        match self {
135            Self::Plaintext(_) => None,
136            Self::Encrypted { record, .. } => Some(record.algorithm()),
137        }
138    }
139
140    pub fn max_assignable_seq_num(&self) -> SeqNum {
141        match self {
142            Self::Plaintext(_) => SeqNum::MAX,
143            Self::Encrypted { record, .. } => record.max_assignable_seq_num(),
144        }
145    }
146}
147
148impl DeepSize for StoredRecord {
149    fn deep_size(&self) -> usize {
150        match self {
151            Self::Plaintext(record) => record.deep_size(),
152            Self::Encrypted {
153                metered_size,
154                record,
155            } => metered_size.deep_size() + record.deep_size(),
156        }
157    }
158}
159
160impl MeteredSize for StoredRecord {
161    fn metered_size(&self) -> usize {
162        match self {
163            Self::Plaintext(record) => record.metered_size(),
164            Self::Encrypted { metered_size, .. } => *metered_size,
165        }
166    }
167}
168
169impl From<Record> for StoredRecord {
170    fn from(value: Record) -> Self {
171        Self::Plaintext(value)
172    }
173}
174
175pub fn decode_if_command_record(
176    record: &[u8],
177) -> Result<Option<CommandRecord>, StoredRecordDecodeError> {
178    if record.is_empty() {
179        return Err(StoredRecordDecodeError::Truncated("MagicByte"));
180    }
181    let magic_byte = MagicByte::try_from(record[0])
182        .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
183    match magic_byte.record_type {
184        RecordType::Command => {
185            let offset = 1 + magic_byte.metered_size_varlen as usize;
186            if record.len() < offset {
187                return Err(StoredRecordDecodeError::Truncated("MeteredSize"));
188            }
189            Ok(Some(decode_command_record(&record[offset..])?))
190        }
191        RecordType::Envelope | RecordType::EncryptedEnvelope => Ok(None),
192    }
193}
194
/// Encodes a metered stored record and returns the encoded bytes.
///
/// Delegates to the record's `to_bytes` method.
pub fn encode_stored_record(record: Metered<&StoredRecord>) -> Bytes {
    record.to_bytes()
}
198
/// Returns the number of bytes the encoded form of `record` occupies
/// (magic byte, metered-size prefix and body).
pub fn stored_record_encoded_size(record: Metered<&StoredRecord>) -> usize {
    record.encoded_size()
}
202
/// Encodes `record` (magic byte, metered-size prefix, then body) into `buf`.
pub fn encode_stored_record_into(record: Metered<&StoredRecord>, buf: &mut impl BufMut) {
    record.encode_into(buf);
}
206
207impl WireEncode for Metered<&StoredRecord> {
208    fn encoded_size(&self) -> usize {
209        1 + magic_byte(self).metered_size_varlen as usize + self.encoded_body_size()
210    }
211
212    fn encode_into(&self, buf: &mut impl BufMut) {
213        let magic_byte = magic_byte(self);
214        buf.put_u8(magic_byte.into());
215        buf.put_uint(
216            self.metered_size() as u64,
217            magic_byte.metered_size_varlen as usize,
218        );
219        self.encode_body_into(buf);
220    }
221}
222
223fn magic_byte(record: &Metered<&StoredRecord>) -> MagicByte {
224    let metered_size = record.metered_size();
225    let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
226    if metered_size_varlen > 3 {
227        panic!("illegal metered size varlen {metered_size} for record")
228    }
229    MagicByte {
230        record_type: record.record_type(),
231        metered_size_varlen,
232    }
233}
234
/// A [`Sequenced`] wrapper around raw [`Bytes`].
pub type StoredSequencedBytes = Sequenced<Bytes>;
/// A [`Sequenced`] wrapper around a [`StoredRecord`].
pub type StoredSequencedRecord = Sequenced<StoredRecord>;
237
238pub fn decode_stored_record(
239    mut buf: Bytes,
240) -> Result<Metered<StoredRecord>, StoredRecordDecodeError> {
241    if buf.is_empty() {
242        return Err(StoredRecordDecodeError::Truncated("MagicByte"));
243    }
244    let magic_byte = MagicByte::try_from(buf.get_u8())
245        .map_err(|msg| StoredRecordDecodeError::InvalidValue("MagicByte", msg))?;
246
247    let metered_size =
248        buf.try_get_uint(magic_byte.metered_size_varlen as usize)
249            .map_err(|_| StoredRecordDecodeError::Truncated("MeteredSize"))? as usize;
250
251    let record = match magic_byte.record_type {
252        RecordType::Command => {
253            StoredRecord::Plaintext(Record::Command(decode_command_record(buf.as_ref())?))
254        }
255        RecordType::Envelope => {
256            StoredRecord::Plaintext(Record::Envelope(decode_envelope_record(buf)?))
257        }
258        RecordType::EncryptedEnvelope => {
259            StoredRecord::encrypted(EncryptedRecord::try_from(buf)?, metered_size)
260        }
261    };
262    Ok(Metered::with_size(metered_size, record))
263}
264
265pub fn decode_record(buf: Bytes) -> Result<Metered<Record>, StoredRecordDecodeError> {
266    let stored = decode_stored_record(buf)?;
267    let metered_size = stored.metered_size();
268    match stored.into_inner() {
269        StoredRecord::Plaintext(record) => Ok(record),
270        StoredRecord::Encrypted { .. } => Err(StoredRecordDecodeError::InvalidValue(
271            "RecordType",
272            "encrypted envelope requires decryption",
273        )),
274    }
275    .map(|record| Metered::with_size(metered_size, record))
276}
277
278#[cfg(test)]
279mod test {
280    use proptest::prelude::*;
281    use rstest::rstest;
282    use s2_common::record::{
283        EnvelopeRecord, Header, MAX_FENCING_TOKEN_LENGTH, MeteredExt, StreamPosition, Timestamp,
284    };
285
286    use super::*;
287
288    struct LegacyPlaintextFrame<'a> {
289        record: &'a Record,
290    }
291
292    impl LegacyPlaintextFrame<'_> {
293        fn magic_byte(&self) -> MagicByte {
294            let metered_size = self.record.metered_size();
295            let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
296            assert!(metered_size_varlen <= 3);
297
298            MagicByte {
299                record_type: match self.record {
300                    Record::Command(_) => RecordType::Command,
301                    Record::Envelope(_) => RecordType::Envelope,
302                },
303                metered_size_varlen,
304            }
305        }
306    }
307
308    impl WireEncode for LegacyPlaintextFrame<'_> {
309        fn encoded_size(&self) -> usize {
310            let body_size = match self.record {
311                Record::Command(record) => record.encoded_size(),
312                Record::Envelope(record) => record.encoded_size(),
313            };
314            1 + self.magic_byte().metered_size_varlen as usize + body_size
315        }
316
317        fn encode_into(&self, buf: &mut impl BufMut) {
318            let magic_byte = self.magic_byte();
319            buf.put_u8(magic_byte.into());
320            buf.put_uint(
321                self.record.metered_size() as u64,
322                magic_byte.metered_size_varlen as usize,
323            );
324            match self.record {
325                Record::Command(record) => record.encode_into(buf),
326                Record::Envelope(record) => record.encode_into(buf),
327            }
328        }
329    }
330
    // Encodes `record` with the reference `LegacyPlaintextFrame` encoder.
    fn legacy_plaintext_bytes(record: &Record) -> Bytes {
        LegacyPlaintextFrame { record }.to_bytes()
    }
334
335    fn semantic_metered_size(record: &Record) -> usize {
336        let (headers, body) = record.clone().into_parts();
337        8 + (2 * headers.len())
338            + headers
339                .iter()
340                .map(|header| header.name.len() + header.value.len())
341                .sum::<usize>()
342            + body.len()
343    }
344
345    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
346        prop_oneof![
347            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
348                .prop_map(Bytes::from),
349            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
350        ]
351    }
352
    // Generates a header with a non-empty name and a possibly empty value.
    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }
357
    // Generates either a small header list (fewer than 10) or a large one
    // (200..300 headers).
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }
364
    // Generates either a fence command with a printable-ASCII token of up to
    // MAX_FENCING_TOKEN_LENGTH characters, or a trim command with any SeqNum.
    fn command_strategy() -> impl Strategy<Value = CommandRecord> {
        prop_oneof![
            proptest::string::string_regex(&format!("[ -~]{{0,{MAX_FENCING_TOKEN_LENGTH}}}"))
                .unwrap()
                .prop_map(|token| CommandRecord::Fence(token.parse().unwrap())),
            any::<SeqNum>().prop_map(CommandRecord::Trim),
        ]
    }
373
374    proptest!(
375        #![proptest_config(ProptestConfig::with_cases(10))]
376        #[test]
377        fn roundtrip_envelope(
378            seq_num in any::<SeqNum>(),
379            timestamp in any::<Timestamp>(),
380            headers in headers_strategy(),
381            body in bytes_strategy(true),
382        ) {
383            let record = Record::try_from_parts(headers, body).unwrap();
384            let metered_record: Metered<Record> = record.clone().into();
385            let encoded_record =
386                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
387            let legacy_record = legacy_plaintext_bytes(&record);
388            prop_assert_eq!(encoded_record.as_ref(), legacy_record.as_ref());
389            let decoded_record = decode_record(encoded_record).unwrap();
390            prop_assert_eq!(&decoded_record, &metered_record);
391            let sequenced = decoded_record.sequenced(StreamPosition { seq_num, timestamp });
392            let (position, sequenced_record) = sequenced.into_parts();
393            assert_eq!(position, StreamPosition { seq_num, timestamp });
394            assert_eq!(sequenced_record.into_inner(), record);
395        }
396    );
397
398    proptest!(
399        #![proptest_config(ProptestConfig::with_cases(10))]
400        #[test]
401        fn roundtrip_metered(
402            headers in headers_strategy(),
403            body in bytes_strategy(true),
404        ) {
405            let record = Record::try_from_parts(headers.clone(), body.clone()).unwrap();
406            let encoded_record =
407                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
408            assert_eq!(record.metered_size(), semantic_metered_size(&record));
409            assert_eq!(record.metered_size(), try_metered_size(encoded_record.as_ref()).unwrap() as usize);
410        }
411    );
412
413    proptest!(
414        #![proptest_config(ProptestConfig::with_cases(10))]
415        #[test]
416        fn roundtrip_command_metered(command in command_strategy()) {
417            let record = Record::Command(command);
418            let encoded_record =
419                encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref());
420            let expected_metered = semantic_metered_size(&record);
421            let wire_metered = try_metered_size(encoded_record.as_ref()).unwrap() as usize;
422            let decoded_record = decode_record(encoded_record).unwrap();
423
424            assert_eq!(record.metered_size(), expected_metered);
425            assert_eq!(record.metered_size(), wire_metered);
426            prop_assert_eq!(decoded_record, Metered::<Record>::from(record));
427        }
428    );
429
    // An encrypted stored record survives an encode/decode round trip,
    // including its stored metered size.
    #[test]
    fn roundtrip_encrypted_stored_record() {
        // Hand-built payload: one 0x02 byte followed by 12, 10 and 16 bytes.
        // NOTE(review): presumably algorithm id, nonce, ciphertext and tag —
        // confirm against `EncryptedRecord`.
        let mut encoded = BytesMut::with_capacity(1 + 12 + 10 + 16);
        encoded.put_u8(0x02);
        encoded.put_slice(b"0123456789ab");
        encoded.put_slice(b"ciphertext");
        encoded.put_slice(b"0123456789abcdef");
        let record =
            StoredRecord::encrypted(EncryptedRecord::try_from(encoded.freeze()).unwrap(), 123);
        let metered_record = record.clone().metered();
        let encoded_record = encode_stored_record(metered_record.as_ref());
        let decoded_record = decode_stored_record(encoded_record).unwrap();
        assert_eq!(decoded_record, metered_record);
    }
444
    // Each case pairs a raw header byte with its parsed form; the test checks
    // the conversion in both directions.
    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_0011, MagicByte { record_type: RecordType::EncryptedEnvelope, metered_size_varlen: 1})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn valid_magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }
454
    // First case has an unknown record type in bits 0-2; second has the
    // undefined width pattern in bits 3-4.
    #[rstest]
    #[case(0b0000_1101, "invalid record type ordinal")]
    #[case(0b0001_1001, "invalid metered_size_varlen")]
    fn invalid_magic_byte_parsing(#[case] as_u8: u8, #[case] expected: &'static str) {
        assert_eq!(MagicByte::try_from(as_u8), Err(expected));
    }
461
    // A buffer that ends right after the magic byte must fail with a
    // truncation error for the metered-size prefix instead of panicking.
    #[test]
    fn metered_record_truncated_after_magic_byte_returns_error() {
        // Magic byte: Envelope (0b0000_0010), metered_size_varlen = 1 -> expects 1 more byte.
        let truncated = Bytes::from_static(&[0b0000_0010]);
        let result = decode_record(truncated);
        assert_eq!(
            result,
            Err(StoredRecordDecodeError::Truncated("MeteredSize"))
        );
    }
472
    // Golden-byte cases: each record must encode to exactly the listed bytes
    // through both the allocating and the in-place encoder, report a matching
    // encoded size, and decode back to the original record.
    #[rstest]
    #[case::envelope_empty_headers(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap()
        )),
        &[
            0x02, 0x0d, // envelope record, metered size 13
            0x00, // no headers
            b'h', b'e', b'l', b'l', b'o',
        ],
    )]
    #[case::envelope_with_header(
        StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(
                vec![Header {
                    name: Bytes::from_static(b"k"),
                    value: Bytes::from_static(b"v"),
                }],
                Bytes::from_static(b"b"),
            ).unwrap()
        )),
        &[
            0x02, 0x0d, // envelope record, metered size 13
            0x10, 0x01, // one header, one byte for num headers
            0x01, b'k',
            0x01, b'v',
            b'b',
        ],
    )]
    #[case::command_trim(
        StoredRecord::from(Record::Command(CommandRecord::Trim(42))),
        &[
            0x01, 0x16, // command record, metered size 22
            0x01, // trim command ordinal
            0x00, 0x00, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x2a,
        ],
    )]
    fn stored_record_encoding_matches_existing_wire_format(
        #[case] record: StoredRecord,
        #[case] expected: &[u8],
    ) {
        let metered_record = record.clone().metered();
        let encoded_size = stored_record_encoded_size(metered_record.as_ref());
        let encoded = encode_stored_record(metered_record.as_ref());
        let mut encoded_into = BytesMut::with_capacity(encoded_size);
        encode_stored_record_into(metered_record.as_ref(), &mut encoded_into);

        assert_eq!(encoded.len(), encoded_size);
        assert_eq!(encoded.as_ref(), expected);
        assert_eq!(encoded_into.as_ref(), expected);
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }
526
    // An encrypted record encodes as the 0x03 magic byte, a one-byte metered
    // size (123 = 0x7b), and then the encrypted payload bytes unchanged.
    #[test]
    fn encrypted_stored_record_encoding_matches_existing_wire_format() {
        let encrypted_payload = Bytes::from_static(b"\x020123456789abciphertext0123456789abcdef");
        let record = StoredRecord::encrypted(
            EncryptedRecord::try_from(encrypted_payload.clone()).unwrap(),
            123,
        );

        let encoded = encode_stored_record(record.clone().metered().as_ref());

        assert_eq!(
            encoded.as_ref(),
            [&[0x03, 0x7b], encrypted_payload.as_ref()].concat()
        );
        assert_eq!(decode_stored_record(encoded).unwrap().into_inner(), record);
    }
543
    // The decoder must report the metered size found in the encoded prefix,
    // even when it disagrees with the size the record itself would compute.
    #[test]
    fn decode_stored_record_preserves_encoded_metered_size_prefix() {
        let record = StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        ));
        let mut encoded = encode_stored_record(record.clone().metered().as_ref()).to_vec();
        // Overwrite the one-byte size prefix with a value the record would not produce.
        encoded[1] = 99;

        let decoded = decode_stored_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }
557
    // Same check as for `decode_stored_record`, through the plaintext-only
    // `decode_record` entry point.
    #[test]
    fn decode_record_preserves_encoded_metered_size_prefix() {
        let record = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap(),
        );
        let mut encoded =
            encode_stored_record(StoredRecord::from(record.clone()).metered().as_ref()).to_vec();
        // Overwrite the one-byte size prefix with a value the record would not produce.
        encoded[1] = 99;

        let decoded = decode_record(Bytes::from(encoded)).unwrap();

        assert_eq!(decoded.metered_size(), 99);
        assert_eq!(decoded.into_inner(), record);
    }
572
    // Slides windows of 3 and 4 bytes over a buffer with a single 1 byte so
    // that the 1 lands at each big-endian byte position.
    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
582}