Skip to main content

s2_storage/record/
codec.rs

1//! Stored record body encoding.
2//!
3//! Outer stored-record framing lives in `framing.rs`; this module encodes and
4//! decodes the bytes that follow the magic byte and metered-size prefix.
5//!
6//! Command body layout:
7//!
8//! ```text
9//! +-----------------+-----------------+
10//! | command_op: u8  | payload: bytes  |
11//! +-----------------+-----------------+
12//! ```
13//!
14//! Envelope body layout:
15//!
16//! ```text
17//! +-----------------+------------------------------+-------------------+
18//! | header_flag: u8 | num_headers: N bytes, if any | headers...        |
19//! +-----------------+------------------------------+-------------------+
20//! | body: remaining bytes                                              |
21//! +--------------------------------------------------------------------+
22//!
23//! header_flag:
24//!   bits 7..6: reserved, must be 0
25//!   bits 5..4: num_headers byte width, where 0 means no headers
26//!   bits 3..2: header name length byte width minus 1
27//!   bits 1..0: header value length byte width minus 1
28//!
29//! each header:
30//! +------------------------+------------+-------------------------+-------------+
31//! | name_len: name_width   | name bytes | value_len: value_width  | value bytes |
32//! +------------------------+------------+-------------------------+-------------+
33//! ```
34//!
//! Variable-width integers are big-endian. The header count uses 1-3 bytes
//! when present; header name/value lengths use 1-4 bytes. When the
//! header-count width is 0, no count or headers follow. The decoder then
//! ignores the name/value width bits, and all remaining bytes are the body.
39
40use std::num::NonZeroU8;
41
42use bytes::{Buf, BufMut, Bytes, BytesMut};
43use s2_common::record::{
44    CommandOp, CommandPayloadError, CommandRecord, EnvelopeRecord, Header, HeaderValidationError,
45    RecordPartsError,
46};
47
48#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
49pub enum StoredRecordDecodeError {
50    #[error("truncated: {0}")]
51    Truncated(&'static str),
52    #[error("invalid value [{0}]: {1}")]
53    InvalidValue(&'static str, &'static str),
54}
55
56pub(crate) trait WireEncode {
57    fn to_bytes(&self) -> Bytes {
58        let expected_size = self.encoded_size();
59        let mut buf = BytesMut::with_capacity(expected_size);
60        self.encode_into(&mut buf);
61        assert_eq!(buf.len(), expected_size, "no reallocation");
62        buf.freeze()
63    }
64
65    fn encoded_size(&self) -> usize;
66
67    fn encode_into(&self, buf: &mut impl BufMut);
68}
69
70const COMMAND_ORDINAL_FENCE: u8 = 0;
71const COMMAND_ORDINAL_TRIM: u8 = 1;
72
73fn command_op_ordinal(op: CommandOp) -> u8 {
74    match op {
75        CommandOp::Fence => COMMAND_ORDINAL_FENCE,
76        CommandOp::Trim => COMMAND_ORDINAL_TRIM,
77    }
78}
79
80fn command_op_from_ordinal(ordinal: u8) -> Option<CommandOp> {
81    match ordinal {
82        COMMAND_ORDINAL_FENCE => Some(CommandOp::Fence),
83        COMMAND_ORDINAL_TRIM => Some(CommandOp::Trim),
84        _ => None,
85    }
86}
87
88impl From<CommandPayloadError> for StoredRecordDecodeError {
89    fn from(e: CommandPayloadError) -> Self {
90        match e {
91            CommandPayloadError::InvalidUtf8(_) => StoredRecordDecodeError::InvalidValue(
92                "CommandPayload",
93                "fencing token not valid utf8",
94            ),
95            CommandPayloadError::FencingTokenTooLong(_) => {
96                StoredRecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
97            }
98            CommandPayloadError::TrimPointSize(_) => {
99                StoredRecordDecodeError::InvalidValue("CommandPayload", "trim point size")
100            }
101        }
102    }
103}
104
105impl WireEncode for CommandRecord {
106    fn encoded_size(&self) -> usize {
107        1 + match self {
108            CommandRecord::Fence(token) => token.len(),
109            CommandRecord::Trim(trim_point) => size_of_val(trim_point),
110        }
111    }
112
113    fn encode_into(&self, buf: &mut impl BufMut) {
114        buf.put_u8(command_op_ordinal(self.op()));
115        match self {
116            CommandRecord::Fence(token) => {
117                buf.put_slice(token.as_bytes());
118            }
119            CommandRecord::Trim(trim_point) => {
120                buf.put_u64(*trim_point);
121            }
122        }
123    }
124}
125
126pub(super) fn decode_command_record(
127    record: &[u8],
128) -> Result<CommandRecord, StoredRecordDecodeError> {
129    if record.is_empty() {
130        return Err(StoredRecordDecodeError::Truncated("CommandOrdinal"));
131    }
132    let op = command_op_from_ordinal(record[0]).ok_or(StoredRecordDecodeError::InvalidValue(
133        "CommandOrdinal",
134        "unknown",
135    ))?;
136    CommandRecord::try_from_parts(op, &record[1..]).map_err(Into::into)
137}
138
/// A compact per-envelope header layout byte.
///
/// The header count width can be zero, which means there are no headers and
/// no encoded count follows. Name and value length widths are stored on the
/// wire as `width - 1`, because valid header length fields are always 1-4
/// bytes.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HeaderFlag {
    /// Width in bytes of the big-endian header count; 0 means no headers.
    num_headers_length_bytes: u8,
    /// Width in bytes of each header name length prefix.
    name_length_bytes: NonZeroU8,
    /// Width in bytes of each header value length prefix.
    value_length_bytes: NonZeroU8,
}

/// Bit layout of the flag byte: each field is two bits wide.
impl HeaderFlag {
    const NUM_HEADERS_LENGTH_SHIFT: u8 = 4;
    const NAME_LENGTH_SHIFT: u8 = 2;

    /// Bits 7..6, which must be zero.
    const RESERVED_MASK: u8 = 0b11 << 6;
    /// Bits 5..4: header count width.
    const NUM_HEADERS_LENGTH_MASK: u8 = 0b11 << Self::NUM_HEADERS_LENGTH_SHIFT;
    /// Bits 3..2: header name length width minus one.
    const NAME_LENGTH_MASK: u8 = 0b11 << Self::NAME_LENGTH_SHIFT;
    /// Bits 1..0: header value length width minus one.
    const VALUE_LENGTH_MASK: u8 = 0b11;
}

/// Flag for an envelope without headers. The count width is 0 and the
/// name/value widths are the minimal 1 byte, so the flag encodes as `0x00`.
const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
    num_headers_length_bytes: 0,
    name_length_bytes: NonZeroU8::MIN,
    value_length_bytes: NonZeroU8::MIN,
};
165
166impl From<HeaderFlag> for u8 {
167    fn from(value: HeaderFlag) -> Self {
168        (value.num_headers_length_bytes << HeaderFlag::NUM_HEADERS_LENGTH_SHIFT)
169            | ((value.name_length_bytes.get() - 1) << HeaderFlag::NAME_LENGTH_SHIFT)
170            | (value.value_length_bytes.get() - 1)
171    }
172}
173
174impl TryFrom<u8> for HeaderFlag {
175    type Error = &'static str;
176
177    fn try_from(value: u8) -> Result<Self, Self::Error> {
178        if (value & HeaderFlag::RESERVED_MASK) != 0 {
179            return Err("reserved bit set");
180        }
181        Ok(Self {
182            num_headers_length_bytes: (value & HeaderFlag::NUM_HEADERS_LENGTH_MASK)
183                >> HeaderFlag::NUM_HEADERS_LENGTH_SHIFT,
184            name_length_bytes: NonZeroU8::new(
185                ((value & HeaderFlag::NAME_LENGTH_MASK) >> HeaderFlag::NAME_LENGTH_SHIFT) + 1,
186            )
187            .unwrap(),
188            value_length_bytes: NonZeroU8::new((value & HeaderFlag::VALUE_LENGTH_MASK) + 1)
189                .unwrap(),
190        })
191    }
192}
193
194const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
195    headers_total_bytes: 0,
196    flag: EMPTY_HEADER_FLAG,
197};
198
199#[derive(Debug, PartialEq, Eq, Clone, Copy)]
200struct EncodingInfo {
201    headers_total_bytes: usize,
202    flag: HeaderFlag,
203}
204
205impl EncodingInfo {
206    fn for_record(record: &EnvelopeRecord) -> Self {
207        Self::from_header_sizing(
208            record.headers().len(),
209            record.headers_total_bytes(),
210            record.header_name_length_width_bytes(),
211            record.header_value_length_width_bytes(),
212        )
213        .expect("envelope record headers should be validated")
214    }
215
216    fn from_header_sizing(
217        header_count: usize,
218        headers_total_bytes: usize,
219        name_length_width_bytes: usize,
220        value_length_width_bytes: usize,
221    ) -> Result<Self, HeaderValidationError> {
222        fn size_bytes_header_count(count: u64) -> Result<u8, HeaderValidationError> {
223            let size = 8 - count.leading_zeros() / 8;
224            if size <= 3 {
225                Ok(size as u8)
226            } else {
227                Err(HeaderValidationError::TooMany)
228            }
229        }
230
231        fn header_part_width(width: usize) -> Result<NonZeroU8, HeaderValidationError> {
232            let width = u8::try_from(width).map_err(|_| HeaderValidationError::TooLong)?;
233            if (1..=4).contains(&width) {
234                Ok(NonZeroU8::new(width).expect("header part width should be non-zero"))
235            } else {
236                Err(HeaderValidationError::TooLong)
237            }
238        }
239
240        if header_count == 0 {
241            return Ok(EMPTY_HEADERS_ENCODING_INFO);
242        }
243
244        let num_headers_length_bytes = size_bytes_header_count(header_count as u64)?;
245        let name_length_bytes = header_part_width(name_length_width_bytes)?;
246        let value_length_bytes = header_part_width(value_length_width_bytes)?;
247
248        Ok(Self {
249            headers_total_bytes,
250            flag: HeaderFlag {
251                num_headers_length_bytes,
252                name_length_bytes,
253                value_length_bytes,
254            },
255        })
256    }
257}
258
259impl WireEncode for EnvelopeRecord {
260    fn encoded_size(&self) -> usize {
261        let encoding_info = EncodingInfo::for_record(self);
262        1 + encoding_info.flag.num_headers_length_bytes as usize
263            + self.headers().len()
264                * (encoding_info.flag.name_length_bytes.get() as usize
265                    + encoding_info.flag.value_length_bytes.get() as usize)
266            + encoding_info.headers_total_bytes
267            + self.body().len()
268    }
269
270    fn encode_into(&self, buf: &mut impl BufMut) {
271        let encoding_info = EncodingInfo::for_record(self);
272        buf.put_u8(encoding_info.flag.into());
273        buf.put_uint(
274            self.headers().len() as u64,
275            encoding_info.flag.num_headers_length_bytes as usize,
276        );
277        for Header { name, value } in self.headers() {
278            buf.put_uint(
279                name.len() as u64,
280                encoding_info.flag.name_length_bytes.get() as usize,
281            );
282            buf.put_slice(name);
283            buf.put_uint(
284                value.len() as u64,
285                encoding_info.flag.value_length_bytes.get() as usize,
286            );
287            buf.put_slice(value);
288        }
289        buf.put_slice(self.body());
290    }
291}
292
293pub(super) fn decode_envelope_record(
294    mut buf: Bytes,
295) -> Result<EnvelopeRecord, StoredRecordDecodeError> {
296    if buf.is_empty() {
297        return Err(StoredRecordDecodeError::InvalidValue(
298            "HeaderFlag",
299            "missing",
300        ));
301    }
302
303    let flag: HeaderFlag = buf
304        .get_u8()
305        .try_into()
306        .map_err(|info| StoredRecordDecodeError::InvalidValue("HeaderFlag", info))?;
307    if flag.num_headers_length_bytes == 0 {
308        return EnvelopeRecord::try_from_parts(vec![], buf).map_err(record_parts_decode_error);
309    }
310
311    let num_headers = buf
312        .try_get_uint(flag.num_headers_length_bytes as usize)
313        .map_err(|_| StoredRecordDecodeError::Truncated("NumHeaders"))?;
314    let num_headers = usize::try_from(num_headers)
315        .map_err(|_| StoredRecordDecodeError::InvalidValue("NumHeaders", "too many"))?;
316
317    let mut headers: Vec<Header> = Vec::with_capacity(num_headers);
318    for _ in 0..num_headers {
319        let name_len = buf
320            .try_get_uint(flag.name_length_bytes.get() as usize)
321            .map_err(|_| StoredRecordDecodeError::Truncated("HeaderNameLen"))?
322            as usize;
323        if name_len == 0 {
324            return Err(StoredRecordDecodeError::InvalidValue("HeaderName", "empty"));
325        }
326        if buf.remaining() < name_len {
327            return Err(StoredRecordDecodeError::Truncated("HeaderName"));
328        }
329        let name = buf.split_to(name_len);
330
331        let value_len = buf
332            .try_get_uint(flag.value_length_bytes.get() as usize)
333            .map_err(|_| StoredRecordDecodeError::Truncated("HeaderValueLen"))?
334            as usize;
335        if buf.remaining() < value_len {
336            return Err(StoredRecordDecodeError::Truncated("HeaderValue"));
337        }
338        let value = buf.split_to(value_len);
339
340        headers.push(Header { name, value })
341    }
342
343    EnvelopeRecord::try_from_parts(headers, buf).map_err(record_parts_decode_error)
344}
345
346fn record_parts_decode_error(error: RecordPartsError) -> StoredRecordDecodeError {
347    match error {
348        RecordPartsError::Header(HeaderValidationError::NameEmpty) => {
349            StoredRecordDecodeError::InvalidValue("HeaderName", "empty")
350        }
351        RecordPartsError::Header(HeaderValidationError::TooMany) => {
352            StoredRecordDecodeError::InvalidValue("NumHeaders", "too many")
353        }
354        RecordPartsError::Header(HeaderValidationError::TooLong) => {
355            StoredRecordDecodeError::InvalidValue("Header", "too long")
356        }
357        RecordPartsError::UnknownCommand | RecordPartsError::CommandPayload(_, _) => {
358            StoredRecordDecodeError::InvalidValue("EnvelopeRecord", "unexpected command record")
359        }
360    }
361}
362
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};
    use rstest::rstest;
    use s2_common::record::{FencingToken, FencingTokenTooLongError, SeqNum};

    use super::*;

    /// Encodes `cmd`, checks predicted and actual sizes, and decodes it back.
    fn roundtrip_command(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let encoded = cmd.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        assert_eq!(decode_command_record(encoded.as_ref()), Ok(cmd));
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn command_fence_roundtrip(#[case] token: &str) {
        let cmd = CommandRecord::Fence(token.parse::<FencingToken>().unwrap());
        roundtrip_command(cmd, 1 + token.len());
    }

    #[rstest]
    #[case::zero(0)]
    #[case::large(SeqNum::MAX)]
    fn command_trim_roundtrip(#[case] trim_point: SeqNum) {
        roundtrip_command(CommandRecord::Trim(trim_point), 1 + size_of::<SeqNum>());
    }

    #[test]
    fn decode_invalid_command() {
        let try_convert = |raw: &[u8]| decode_command_record(raw);
        assert_eq!(
            try_convert(&[]),
            Err(StoredRecordDecodeError::Truncated("CommandOrdinal"))
        );
        assert_eq!(
            try_convert(&[0xff]),
            Err(StoredRecordDecodeError::InvalidValue(
                "CommandOrdinal",
                "unknown"
            ))
        );
        assert_eq!(
            try_convert(&[command_op_ordinal(CommandOp::Fence), 0xff, 0xff]),
            Err(StoredRecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence ordinal followed by 40 ASCII digits, which is too long for a token.
        let mut too_long_token = vec![command_op_ordinal(CommandOp::Fence)];
        too_long_token.extend_from_slice(&b"0123456789".repeat(4));
        assert_eq!(
            try_convert(too_long_token.as_slice()),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        assert_eq!(
            try_convert(&[command_op_ordinal(CommandOp::Trim), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }

    /// Builds an envelope from parts, encodes it, and checks it decodes back.
    fn roundtrip_envelope_parts(headers: Vec<Header>, body: Bytes) {
        let encoded: Bytes = EnvelopeRecord::try_from_parts(headers.clone(), body.clone())
            .unwrap()
            .to_bytes();
        let decoded = decode_envelope_record(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn envelope_framed_with_headers() {
        roundtrip_envelope_parts(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn envelope_framed_no_headers() {
        roundtrip_envelope_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn envelope_decode_rejects_empty_header_name() {
        let mut encoded = BytesMut::new();
        encoded.put_u8(
            HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }
            .into(),
        );
        encoded.put_u8(1);
        encoded.put_u8(0);
        encoded.put_u8(5);
        encoded.put_slice(b"value");
        encoded.put_slice(b"body");

        assert_eq!(
            decode_envelope_record(encoded.freeze()),
            Err(StoredRecordDecodeError::InvalidValue("HeaderName", "empty"))
        );
    }

    #[test]
    fn envelope_decode_rejects_empty_input() {
        assert_eq!(
            decode_envelope_record(Bytes::new()),
            Err(StoredRecordDecodeError::InvalidValue(
                "HeaderFlag",
                "missing"
            ))
        );
    }

    #[rstest]
    #[case::bit_6(0b0100_0000)]
    #[case::bit_7(0b1000_0000)]
    fn envelope_decode_rejects_reserved_flag_bits(#[case] flag: u8) {
        assert_eq!(
            decode_envelope_record(Bytes::copy_from_slice(&[flag])),
            Err(StoredRecordDecodeError::InvalidValue(
                "HeaderFlag",
                "reserved bit set"
            ))
        );
    }

    #[test]
    fn envelope_decode_without_count_ignores_length_widths() {
        // Flag 0x05 has count width 0 but non-zero name/value width bits.
        // There are no headers, and the remaining bytes are the body.
        let decoded = decode_envelope_record(Bytes::from_static(b"\x05hi")).unwrap();
        assert!(decoded.headers().is_empty());
        assert_eq!(decoded.body(), &Bytes::from_static(b"hi"));
    }

    #[test]
    fn envelope_decode_large_header_count_without_data_is_truncated() {
        let mut encoded = BytesMut::new();
        encoded.put_u8(
            HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }
            .into(),
        );
        encoded.put_u16(u16::MAX);

        assert_eq!(
            decode_envelope_record(encoded.freeze()),
            Err(StoredRecordDecodeError::Truncated("HeaderNameLen"))
        );
    }

    #[test]
    fn envelope_framed_duplicate_keys() {
        roundtrip_envelope_parts(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn flag_roundtrips_every_unreserved_byte() {
        for raw in 0u8..=0b0011_1111 {
            let flag = HeaderFlag::try_from(raw).unwrap();
            assert_eq!(u8::from(flag), raw, "flag byte {raw:#010b}");
        }
    }

    #[rstest]
    #[case::bit_6(0b0100_0000)]
    #[case::bit_7_with_widths(0b1000_0101)]
    #[case::both(0b1111_1111)]
    fn flag_rejects_reserved_bits(#[case] raw: u8) {
        assert_eq!(HeaderFlag::try_from(raw), Err("reserved bit set"));
    }

    #[rstest]
    #[case::one_byte_widths(1, 1)]
    #[case::two_byte_widths(2, 2)]
    #[case::three_byte_widths(3, 3)]
    #[case::four_byte_widths(4, 4)]
    #[case::mixed_widths(2, 4)]
    fn encoding_info_uses_cached_header_length_widths(
        #[case] name_length_width_bytes: usize,
        #[case] value_length_width_bytes: usize,
    ) {
        let encoding_info = EncodingInfo::from_header_sizing(
            1,
            42,
            name_length_width_bytes,
            value_length_width_bytes,
        )
        .unwrap();

        assert_eq!(encoding_info.headers_total_bytes, 42);
        assert_eq!(
            encoding_info.flag,
            HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(name_length_width_bytes as u8).unwrap(),
                value_length_bytes: NonZeroU8::new(value_length_width_bytes as u8).unwrap(),
            }
        );
    }

    #[rstest]
    #[case::zero_name_width(0, 1)]
    #[case::too_large_name_width(5, 1)]
    #[case::zero_value_width(1, 0)]
    #[case::too_large_value_width(1, 5)]
    fn encoding_info_rejects_invalid_cached_header_length_widths(
        #[case] name_length_width_bytes: usize,
        #[case] value_length_width_bytes: usize,
    ) {
        assert_eq!(
            EncodingInfo::from_header_sizing(
                1,
                42,
                name_length_width_bytes,
                value_length_width_bytes,
            ),
            Err(HeaderValidationError::TooLong)
        );
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn truncated_envelope_returns_error() {
        let record = EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .unwrap();
        let encoded = record.to_bytes();

        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    decode_envelope_record(truncated),
                    Err(StoredRecordDecodeError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }
}