Skip to main content

s2_common/record/
envelope.rs

1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, MeteredSize, RecordDecodeError, RecordPartsError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10    #[error("too many")]
11    TooMany,
12    #[error("too long")]
13    TooLong,
14    #[error("empty name")]
15    NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20    headers: Vec<Header>,
21    body: Bytes,
22    encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("EnvelopeRecord")
28            .field("headers.len", &self.headers.len())
29            .field("body.len", &self.body.len())
30            .finish()
31    }
32}
33
34impl DeepSize for EnvelopeRecord {
35    fn deep_size(&self) -> usize {
36        self.headers.deep_size() + self.body.deep_size()
37    }
38}
39
40impl MeteredSize for EnvelopeRecord {
41    fn metered_size(&self) -> usize {
42        8 + (2 * self.headers.len()) + self.encoding_info.headers_total_bytes + self.body.len()
43    }
44}
45
46impl EnvelopeRecord {
47    pub fn headers(&self) -> &[Header] {
48        &self.headers
49    }
50
51    pub fn body(&self) -> &Bytes {
52        &self.body
53    }
54
55    pub fn into_parts(self) -> (Vec<Header>, Bytes) {
56        (self.headers, self.body)
57    }
58
59    pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
60        let encoding_info = headers.as_slice().try_into()?;
61        Ok(Self {
62            headers,
63            body,
64            encoding_info,
65        })
66    }
67}
68
69impl Encodable for EnvelopeRecord {
70    fn encoded_size(&self) -> usize {
71        1 + self.encoding_info.flag.num_headers_length_bytes as usize
72            + self.headers.len()
73                * (self.encoding_info.flag.name_length_bytes.get() as usize
74                    + self.encoding_info.flag.value_length_bytes.get() as usize)
75            + self.encoding_info.headers_total_bytes
76            + self.body.len()
77    }
78
79    fn encode_into(&self, buf: &mut impl BufMut) {
80        // Write prefix: flag and number of headers.
81        buf.put_u8(self.encoding_info.flag.into());
82        buf.put_uint(
83            self.headers.len() as u64,
84            self.encoding_info.flag.num_headers_length_bytes as usize,
85        );
86        // Write headers.
87        for Header { name, value } in &self.headers {
88            buf.put_uint(
89                name.len() as u64,
90                self.encoding_info.flag.name_length_bytes.get() as usize,
91            );
92            buf.put_slice(name);
93            buf.put_uint(
94                value.len() as u64,
95                self.encoding_info.flag.value_length_bytes.get() as usize,
96            );
97            buf.put_slice(value);
98        }
99        buf.put_slice(&self.body);
100    }
101}
102
103impl TryFrom<Bytes> for EnvelopeRecord {
104    type Error = RecordDecodeError;
105
106    fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
107        if buf.is_empty() {
108            return Err(RecordDecodeError::InvalidValue("HeaderFlag", "missing"));
109        }
110
111        let flag: HeaderFlag = buf
112            .get_u8()
113            .try_into()
114            .map_err(|info| RecordDecodeError::InvalidValue("HeaderFlag", info))?;
115        if flag.num_headers_length_bytes == 0 {
116            return Ok(Self {
117                encoding_info: EMPTY_HEADERS_ENCODING_INFO,
118                headers: vec![],
119                body: buf,
120            });
121        }
122
123        let num_headers = buf
124            .try_get_uint(flag.num_headers_length_bytes as usize)
125            .map_err(|_| RecordDecodeError::Truncated("NumHeaders"))?;
126
127        let mut headers_total_bytes = 0;
128        let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
129        for _ in 0..num_headers {
130            let name_len = buf
131                .try_get_uint(flag.name_length_bytes.get() as usize)
132                .map_err(|_| RecordDecodeError::Truncated("HeaderNameLen"))?
133                as usize;
134            if name_len == 0 {
135                return Err(RecordDecodeError::InvalidValue("HeaderName", "empty"));
136            }
137            if buf.remaining() < name_len {
138                return Err(RecordDecodeError::Truncated("HeaderName"));
139            }
140            let name = buf.split_to(name_len);
141
142            let value_len = buf
143                .try_get_uint(flag.value_length_bytes.get() as usize)
144                .map_err(|_| RecordDecodeError::Truncated("HeaderValueLen"))?
145                as usize;
146            if buf.remaining() < value_len {
147                return Err(RecordDecodeError::Truncated("HeaderValue"));
148            }
149            let value = buf.split_to(value_len);
150
151            headers_total_bytes += name.len() + value.len();
152            headers.push(Header { name, value })
153        }
154
155        Ok(Self {
156            encoding_info: EncodingInfo {
157                headers_total_bytes,
158                flag,
159            },
160            headers,
161            body: buf,
162        })
163    }
164}
165
/// Flag used when a record has no headers: zero-width count field and the
/// smallest (one byte) name and value length prefixes. Encodes to `0`.
const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
    num_headers_length_bytes: 0,
    name_length_bytes: NonZeroU8::MIN,
    value_length_bytes: NonZeroU8::MIN,
};

/// Widths, in bytes, of the variable-size fields of an encoded envelope.
///
/// Packed into one byte as `00ccnnvv`: the two high bits are reserved,
/// `cc` is `num_headers_length_bytes`, `nn` is `name_length_bytes - 1` and
/// `vv` is `value_length_bytes - 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct HeaderFlag {
    /// Width of the header count field (0 to 3 when decoded from a byte).
    num_headers_length_bytes: u8,
    /// Width of each header name length prefix (1 to 4 when decoded from a byte).
    name_length_bytes: NonZeroU8,
    /// Width of each header value length prefix (1 to 4 when decoded from a byte).
    value_length_bytes: NonZeroU8,
}

impl From<HeaderFlag> for u8 {
    /// Packs the three widths into the `00ccnnvv` layout.
    fn from(value: HeaderFlag) -> Self {
        let count_bits = value.num_headers_length_bytes << 4;
        let name_bits = (value.name_length_bytes.get() - 1) << 2;
        let value_bits = value.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a flag byte, rejecting it when either reserved high bit is set.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        const RESERVED_MASK: u8 = 0b1100_0000;
        const FIELD_MASK: u8 = 0b11;

        if value & RESERVED_MASK != 0 {
            return Err("reserved bit set");
        }

        // Each field is two bits wide; `shift` selects which one.
        let field = |shift: u8| (value >> shift) & FIELD_MASK;

        Ok(Self {
            num_headers_length_bytes: field(4),
            // Stored as width - 1, so add the field (at most 3) onto 1.
            name_length_bytes: NonZeroU8::MIN.saturating_add(field(2)),
            value_length_bytes: NonZeroU8::MIN.saturating_add(field(0)),
        })
    }
}
201
202const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
203    headers_total_bytes: 0,
204    flag: EMPTY_HEADER_FLAG,
205};
206
207#[derive(Debug, PartialEq, Eq, Clone, Copy)]
208struct EncodingInfo {
209    headers_total_bytes: usize,
210    flag: HeaderFlag,
211}
212
213impl TryFrom<&[Header]> for EncodingInfo {
214    type Error = HeaderValidationError;
215
216    fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
217        // Given number of KV pairs, determine how many bytes are required for storing
218        // the length number.
219        fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
220            let size = 8 - elems.leading_zeros() / 8;
221            if size > 3 {
222                Err(HeaderValidationError::TooMany)
223            } else {
224                Ok(size as u8)
225            }
226        }
227
228        // Given max length of a name (key) or value, determine how many bytes are required for
229        // storing this number.
230        fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
231            if elems == 0 {
232                return Ok(NonZeroU8::new(1u8).unwrap());
233            }
234            let size = 8 - (elems.leading_zeros() / 8);
235            if size > 4 {
236                Err(HeaderValidationError::TooLong)
237            } else {
238                Ok(NonZeroU8::new(size as u8).unwrap())
239            }
240        }
241
242        if headers.is_empty() {
243            return Ok(EMPTY_HEADERS_ENCODING_INFO);
244        }
245
246        let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
247            (0usize, 0usize, 0usize),
248            |(size_bytes_acc, name_max, value_max), Header { name, value }| {
249                if name.is_empty() {
250                    return Err(HeaderValidationError::NameEmpty);
251                }
252                let name_len = name.len();
253                let value_len = value.len();
254                Ok((
255                    size_bytes_acc + name_len + value_len,
256                    name_max.max(name_len),
257                    value_max.max(value_len),
258                ))
259            },
260        )?;
261
262        let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
263        let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
264        let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
265
266        Ok(Self {
267            headers_total_bytes,
268            flag: HeaderFlag {
269                num_headers_length_bytes,
270                name_length_bytes,
271                value_length_bytes,
272            },
273        })
274    }
275}
276
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::{BufMut, Bytes, BytesMut};

    use super::{
        Encodable as _, EnvelopeRecord, Header, HeaderFlag, MeteredSize, RecordDecodeError,
    };

    /// Builds a header from two string literals.
    fn header(name: &'static str, value: &'static str) -> Header {
        Header {
            name: Bytes::from(name),
            value: Bytes::from(value),
        }
    }

    /// Flag with the given header-count width and one-byte name/value prefixes.
    fn flag_with_count_width(num_headers_length_bytes: u8) -> HeaderFlag {
        let one = NonZeroU8::new(1).unwrap();
        HeaderFlag {
            num_headers_length_bytes,
            name_length_bytes: one,
            value_length_bytes: one,
        }
    }

    /// Encodes the parts, decodes the result, and checks nothing changed.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let record = EnvelopeRecord::try_from_parts(headers.clone(), body.clone()).unwrap();
        let encoded: Bytes = record.to_bytes();
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn framed_with_headers() {
        let headers = vec![
            header("key_1", "val_1"),
            header("key_2", "val_2"),
            header("key_3", "val_3"),
            header("key_4", "val_4"),
        ];
        roundtrip_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn decode_rejects_empty_header_name() {
        let mut encoded = BytesMut::new();
        encoded.put_u8(flag_with_count_width(1).into());
        encoded.put_u8(1); // number of headers
        encoded.put_u8(0); // header name length
        encoded.put_u8(5); // header value length
        encoded.put_slice(b"value");
        encoded.put_slice(b"body");

        let decoded = EnvelopeRecord::try_from(encoded.freeze());
        assert_eq!(
            decoded,
            Err(RecordDecodeError::InvalidValue("HeaderName", "empty"))
        );
    }

    #[test]
    fn framed_duplicate_keys() {
        // Duplicate keys preserved in original order.
        let headers = vec![
            header("b", "val_1"),
            header("b", "val_2"),
            header("a", "val_3"),
        ];
        roundtrip_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let record = EnvelopeRecord::try_from_parts(
            vec![header("alpha", "1"), header("beta", "two")],
            Bytes::from("body"),
        )
        .unwrap();

        let header_bytes = "alpha".len() + "1".len() + "beta".len() + "two".len();
        assert_eq!(
            record.metered_size(),
            8 + (2 * record.headers().len()) + header_bytes + "body".len()
        );
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(flag_with_count_width(2)),
            HeaderFlag::try_from(0b0010_0000u8)
        );

        let u8_repr = u8::from(flag_with_count_width(2));
        assert_eq!(u8_repr, 0b0010_0000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(flag_with_count_width(1)),
            HeaderFlag::try_from(0b0001_0000u8)
        );

        let u8_repr = u8::from(flag_with_count_width(1));
        assert_eq!(u8_repr, 0b0001_0000);
    }

    #[test]
    fn empty_envelope_size() {
        let encoded = EnvelopeRecord::try_from_parts(vec![], Bytes::new())
            .unwrap()
            .to_bytes();
        assert_eq!(1, encoded.len());
    }

    #[test]
    fn truncated_returns_error() {
        let encoded = EnvelopeRecord::try_from_parts(vec![header("key", "value")], Bytes::new())
            .unwrap()
            .to_bytes();

        // Truncation anywhere before the end should error
        // (with empty body, there's no trailing data that can be truncated safely).
        for len in 1..encoded.len() {
            let result = EnvelopeRecord::try_from(encoded.slice(..len));
            assert!(
                matches!(result, Err(RecordDecodeError::Truncated(_))),
                "expected Truncated error for len {len}"
            );
        }
    }
}