// s2_common/record/mod.rs

1mod batcher;
2mod command;
3mod envelope;
4mod fencing;
5mod metering;
6
7pub use batcher::{RecordBatch, RecordBatcher};
8use bytes::{Buf, BufMut, Bytes, BytesMut};
9pub use command::CommandRecord;
10use command::{CommandOp, CommandPayloadError};
11use enum_ordinalize::Ordinalize;
12pub use envelope::EnvelopeRecord;
13use envelope::HeaderValidationError;
14pub use fencing::{FencingToken, FencingTokenTooLongError, MAX_FENCING_TOKEN_LENGTH};
15pub use metering::{Metered, MeteredSize};
16
17use crate::deep_size::DeepSize;
18
/// Integer type of a record's sequence number (see [`StreamPosition::seq_num`]).
pub type SeqNum = u64;
/// Non-zero counterpart of [`SeqNum`].
pub type NonZeroSeqNum = std::num::NonZeroU64;
/// Integer type of a record's timestamp (see [`StreamPosition::timestamp`]).
pub type Timestamp = u64;
22
23#[derive(Debug, PartialEq, Eq, Clone, Copy)]
24pub struct StreamPosition {
25    pub seq_num: SeqNum,
26    pub timestamp: Timestamp,
27}
28
29impl StreamPosition {
30    pub const MIN: StreamPosition = StreamPosition {
31        seq_num: SeqNum::MIN,
32        timestamp: Timestamp::MIN,
33    };
34}
35
36impl std::fmt::Display for StreamPosition {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{} @ {}", self.seq_num, self.timestamp)
39    }
40}
41
42impl DeepSize for StreamPosition {
43    fn deep_size(&self) -> usize {
44        self.seq_num.deep_size() + self.timestamp.deep_size()
45    }
46}
47
48#[derive(Debug, Clone, PartialEq, thiserror::Error)]
49pub enum InternalRecordError {
50    #[error("truncated: {0}")]
51    Truncated(&'static str),
52    #[error("invalid value [{0}]: {1}")]
53    InvalidValue(&'static str, &'static str),
54}
55
56/// `impl Display` can be safely returned to the client without leaking internal details.
57#[derive(Debug, PartialEq, thiserror::Error)]
58pub enum PublicRecordError {
59    #[error("unknown command")]
60    UnknownCommand,
61    #[error("invalid `{0}` command: {1}")]
62    CommandPayload(CommandOp, CommandPayloadError),
63    #[error("invalid header: {0}")]
64    Header(#[from] HeaderValidationError),
65}
66
67#[derive(Debug, Clone, PartialEq, Eq)]
68pub struct Header {
69    pub name: Bytes,
70    pub value: Bytes,
71}
72
73impl DeepSize for Header {
74    fn deep_size(&self) -> usize {
75        self.name.len() + self.value.len()
76    }
77}
78
79#[derive(Clone, Copy, Debug, PartialEq, Ordinalize)]
80#[repr(u8)]
81pub enum RecordType {
82    Command = 1,
83    Envelope = 2,
84}
85
86#[derive(Copy, Clone, Debug, PartialEq)]
87pub struct MagicByte {
88    pub record_type: RecordType,
89    pub metered_size_varlen: u8,
90}
91
/// Interprets 1 to 4 bytes as an unsigned big-endian integer.
///
/// # Panics
///
/// Panics when `bytes` is empty or longer than four bytes.
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let len = bytes.len();
    if len == 0 || len > std::mem::size_of::<u32>() {
        panic!("invalid variable int bytes = {} len", len)
    }
    // Shift the accumulated value up one byte and append the next one.
    bytes
        .iter()
        .fold(0u32, |value, &byte| (value << 8) | u32::from(byte))
}
103
104pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
105    let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
106    let magic_byte = MagicByte::try_from(magic_byte_u8)?;
107    Ok(read_vint_u32_be(
108        record_bytes
109            .get(1..1 + magic_byte.metered_size_varlen as usize)
110            .ok_or("byte range doesn't include bytes for metered size")?,
111    ))
112}
113
114impl MeteredSize for Record {
115    fn metered_size(&self) -> usize {
116        8 + (match self {
117            Record::Command(command) => 2 + command.op().to_id().len() + command.payload().len(),
118            Record::Envelope(envelope) => {
119                (2 * envelope.headers().len())
120                    + envelope.headers().deep_size()
121                    + envelope.body().len()
122            }
123        })
124    }
125}
126
127impl TryFrom<u8> for MagicByte {
128    type Error = &'static str;
129
130    fn try_from(value: u8) -> Result<Self, Self::Error> {
131        let record_type =
132            RecordType::from_ordinal(value & 0b111).ok_or("invalid record type ordinal")?;
133        Ok(Self {
134            record_type,
135            metered_size_varlen: match (value >> 3) & 0b11 {
136                0 => 1u8,
137                1 => 2u8,
138                2 => 3u8,
139                _ => Err("invalid metered_size_varlen")?,
140            },
141        })
142    }
143}
144
145impl From<MagicByte> for u8 {
146    fn from(value: MagicByte) -> Self {
147        ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
148    }
149}
150
151#[derive(Debug, PartialEq, Eq, Clone)]
152pub enum Record {
153    Command(CommandRecord),
154    Envelope(EnvelopeRecord),
155}
156
157impl DeepSize for Record {
158    fn deep_size(&self) -> usize {
159        match self {
160            Self::Command(c) => c.deep_size(),
161            Self::Envelope(e) => e.deep_size(),
162        }
163    }
164}
165
166impl Record {
167    pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, PublicRecordError> {
168        if headers.len() == 1 {
169            let header = &headers[0];
170            if header.name.is_empty() {
171                let op = CommandOp::from_id(header.value.as_ref())
172                    .ok_or(PublicRecordError::UnknownCommand)?;
173                let command_record = CommandRecord::try_from_parts(op, body.as_ref())
174                    .map_err(|e| PublicRecordError::CommandPayload(op, e))?;
175                return Ok(Self::Command(command_record));
176            }
177        }
178        let envelope = EnvelopeRecord::try_from_parts(headers, body)?;
179        Ok(Self::Envelope(envelope))
180    }
181
182    pub fn sequenced(self, position: StreamPosition) -> SequencedRecord {
183        SequencedRecord {
184            position,
185            record: self,
186        }
187    }
188
189    pub fn into_parts(self) -> (Vec<Header>, Bytes) {
190        match self {
191            Record::Envelope(e) => e.into_parts(),
192            Record::Command(c) => {
193                let op = c.op();
194                let header = Header {
195                    name: Bytes::new(),
196                    value: Bytes::from_static(op.to_id()),
197                };
198                (vec![header], c.payload())
199            }
200        }
201    }
202}
203
204pub fn decode_if_command_record(
205    record: &[u8],
206) -> Result<Option<CommandRecord>, InternalRecordError> {
207    if record.is_empty() {
208        return Err(InternalRecordError::Truncated("MagicByte"));
209    }
210    let magic_byte = MagicByte::try_from(record[0])
211        .map_err(|msg| InternalRecordError::InvalidValue("MagicByte", msg))?;
212    match magic_byte.record_type {
213        RecordType::Command => {
214            let offset = 1 + magic_byte.metered_size_varlen as usize;
215            if record.len() < offset {
216                return Err(InternalRecordError::Truncated("MeteredSize"));
217            }
218            Ok(Some(CommandRecord::try_from(&record[offset..])?))
219        }
220        RecordType::Envelope => Ok(None),
221    }
222}
223
224pub trait Encodable {
225    fn to_bytes(&self) -> Bytes {
226        let expected_size = self.encoded_size();
227        let mut buf = BytesMut::with_capacity(expected_size);
228        self.encode_into(&mut buf);
229        assert_eq!(buf.len(), expected_size, "no reallocation");
230        buf.freeze()
231    }
232
233    fn encoded_size(&self) -> usize;
234
235    fn encode_into(&self, buf: &mut impl BufMut);
236}
237
238impl Encodable for Metered<&Record> {
239    fn encoded_size(&self) -> usize {
240        1 + self.magic_byte().metered_size_varlen as usize
241            + match &**self {
242                Record::Command(r) => r.encoded_size(),
243                Record::Envelope(r) => r.encoded_size(),
244            }
245    }
246
247    fn encode_into(&self, buf: &mut impl BufMut) {
248        let magic_byte = self.magic_byte();
249        buf.put_u8(magic_byte.into());
250        buf.put_uint(
251            self.metered_size() as u64,
252            magic_byte.metered_size_varlen as usize,
253        );
254        match &**self {
255            Record::Command(r) => r.encode_into(buf),
256            Record::Envelope(r) => r.encode_into(buf),
257        }
258    }
259}
260
261#[derive(Debug, Clone, PartialEq, Eq)]
262pub struct SequencedRecord {
263    pub position: StreamPosition,
264    pub record: Record,
265}
266
267impl MeteredSize for SequencedRecord {
268    fn metered_size(&self) -> usize {
269        self.record.metered_size()
270    }
271}
272
273impl DeepSize for SequencedRecord {
274    fn deep_size(&self) -> usize {
275        self.position.deep_size() + self.record.deep_size()
276    }
277}
278
279impl Metered<Record> {
280    pub fn sequenced(self, position: StreamPosition) -> Metered<SequencedRecord> {
281        Metered {
282            size: self.metered_size(),
283            inner: self.inner.sequenced(position),
284        }
285    }
286}
287
288impl Metered<&Record> {
289    fn magic_byte(&self) -> MagicByte {
290        let metered_size = self.metered_size();
291        let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
292        if metered_size_varlen > 3 {
293            panic!("illegal metered size varlen {metered_size} for record")
294        }
295        let record_type = match self.inner {
296            Record::Command(_) => RecordType::Command,
297            Record::Envelope(_) => RecordType::Envelope,
298        };
299        MagicByte {
300            record_type,
301            metered_size_varlen,
302        }
303    }
304}
305
306impl TryFrom<Bytes> for Metered<Record> {
307    type Error = InternalRecordError;
308
309    fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
310        if buf.is_empty() {
311            return Err(InternalRecordError::Truncated("MagicByte"));
312        }
313        let magic_byte = MagicByte::try_from(buf.get_u8())
314            .map_err(|msg| InternalRecordError::InvalidValue("MagicByte", msg))?;
315
316        let metered_size = buf.get_uint(magic_byte.metered_size_varlen as usize) as usize;
317
318        Ok(Self {
319            size: metered_size,
320            inner: match magic_byte.record_type {
321                RecordType::Command => Record::Command(CommandRecord::try_from(buf.as_ref())?),
322                RecordType::Envelope => Record::Envelope(EnvelopeRecord::try_from(buf)?),
323            },
324        })
325    }
326}
327
328impl Metered<SequencedRecord> {
329    pub fn parts(&self) -> (StreamPosition, Metered<&Record>) {
330        (
331            self.position,
332            Metered {
333                size: self.size,
334                inner: &self.inner.record,
335            },
336        )
337    }
338
339    pub fn into_parts(self) -> (StreamPosition, Metered<Record>) {
340        (
341            self.position,
342            Metered {
343                size: self.size,
344                inner: self.inner.record,
345            },
346        )
347    }
348}
349
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Byte strings that are either short (fewer than 10 bytes, optionally
    /// empty) or long (100 to 999 bytes).
    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
        prop_oneof![
            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
                .prop_map(Bytes::from),
            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
        ]
    }

    /// Headers with a non-empty name and a possibly empty value.
    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }

    /// Either a handful of headers or a few hundred of them.
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_envelope(
            seq_num in any::<SeqNum>(),
            timestamp in any::<Timestamp>(),
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let metered_record: Metered<Record> = record.clone().into();
            let encoded_record = metered_record.as_ref().to_bytes();
            let decoded_record = Metered::try_from(encoded_record).unwrap();
            prop_assert_eq!(&decoded_record, &metered_record);
            let position = StreamPosition { seq_num, timestamp };
            let sequenced = decoded_record.sequenced(position);
            // `prop_assert_eq!` reports failures through proptest instead of panicking.
            prop_assert_eq!(sequenced.position, position);
            prop_assert_eq!(&sequenced.record, &record);
        }
    );

    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_metered(
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let encoded_record = Metered::from(&record).to_bytes();
            prop_assert_eq!(
                record.metered_size(),
                try_metered_size(encoded_record.as_ref()).unwrap() as usize
            );
        }
    );

    #[test]
    fn empty_header_name_solo() {
        let headers = vec![Header {
            name: Bytes::new(),
            value: Bytes::from("hi"),
        }];
        let body = Bytes::from("hello");
        assert_eq!(
            Record::try_from_parts(headers, body),
            Err(PublicRecordError::UnknownCommand)
        );
    }

    #[test]
    fn empty_header_name_among_others() {
        let headers = vec![
            Header {
                name: Bytes::from("boku"),
                value: Bytes::from("hi"),
            },
            Header {
                name: Bytes::new(),
                value: Bytes::from("hi"),
            },
        ];
        let body = Bytes::from("hello");
        assert_eq!(
            Record::try_from_parts(headers, body),
            Err(PublicRecordError::Header(HeaderValidationError::NameEmpty))
        );
    }

    // NOTE: each `#[should_panic]` applies to the `#[case]` directly below it,
    // so the attribute order here is significant.
    #[rstest]
    #[case::fence_empty(b"fence", b"")]
    #[case::fence_uuid(b"fence", b"my-special-uuid")]
    #[should_panic(expected = "FencingTokenTooLongError(49)")]
    #[case::fence_too_long(b"fence", b"toolongtoolongtoolongtoolongtoolongtoolongtoolong")]
    #[case::trim_0(b"trim", b"\x00\x00\x00\x00\x00\x00\x00\x00")]
    #[should_panic(expected = "TrimPointSize(0)")]
    #[case::trim_empty(b"trim", b"")]
    #[should_panic(expected = "TrimPointSize(9)")]
    #[case::trim_overflow(b"trim", b"\x00\x00\x00\x00\x00\x00\x00\x00\x00")]
    fn command_records(#[case] op: &'static [u8], #[case] payload: &'static [u8]) {
        let headers = vec![Header {
            name: Bytes::new(),
            value: Bytes::from_static(op),
        }];
        let body = Bytes::from_static(payload);
        let record = Record::try_from_parts(headers.clone(), body.clone()).unwrap();
        let record_metered = record.metered_size();
        match &record {
            Record::Command(cmd) => {
                assert_eq!(cmd.op().to_id(), op);
                assert_eq!(cmd.payload().as_ref(), payload);
            }
            Record::Envelope(e) => panic!("Command expected, got Envelope: {e:?}"),
        }
        let sequenced_record = record.sequenced(StreamPosition {
            seq_num: 42,
            timestamp: 100_000,
        });
        let sequenced_metered = sequenced_record.metered_size();
        assert_eq!(record_metered, sequenced_metered);
        assert_eq!(
            sequenced_record.position,
            StreamPosition {
                seq_num: 42,
                timestamp: 100_000,
            }
        );
        assert_eq!(
            sequenced_record.record,
            Record::try_from_parts(headers, body).unwrap()
        );
    }

    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    #[should_panic(expected = "invalid record type ordinal")]
    #[case(0b0000_1101, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }

    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
}