Skip to main content

s2_common/record/
envelope.rs

1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, MeteredSize, RecordDecodeError, RecordPartsError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10    #[error("too many")]
11    TooMany,
12    #[error("too long")]
13    TooLong,
14    #[error("empty name")]
15    NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20    headers: Vec<Header>,
21    body: Bytes,
22    encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("EnvelopeRecord")
28            .field("headers.len", &self.headers.len())
29            .field("body.len", &self.body.len())
30            .finish()
31    }
32}
33
34impl DeepSize for EnvelopeRecord {
35    fn deep_size(&self) -> usize {
36        self.headers.deep_size() + self.body.deep_size()
37    }
38}
39
40impl MeteredSize for EnvelopeRecord {
41    fn metered_size(&self) -> usize {
42        8 + (2 * self.headers.len()) + self.encoding_info.headers_total_bytes + self.body.len()
43    }
44}
45
46impl EnvelopeRecord {
47    pub fn headers(&self) -> &[Header] {
48        &self.headers
49    }
50
51    pub fn body(&self) -> &Bytes {
52        &self.body
53    }
54
55    pub fn into_parts(self) -> (Vec<Header>, Bytes) {
56        (self.headers, self.body)
57    }
58
59    pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
60        let encoding_info = headers.as_slice().try_into()?;
61        Ok(Self {
62            headers,
63            body,
64            encoding_info,
65        })
66    }
67}
68
69impl Encodable for EnvelopeRecord {
70    fn encoded_size(&self) -> usize {
71        1 + self.encoding_info.flag.num_headers_length_bytes as usize
72            + self.headers.len()
73                * (self.encoding_info.flag.name_length_bytes.get() as usize
74                    + self.encoding_info.flag.value_length_bytes.get() as usize)
75            + self.encoding_info.headers_total_bytes
76            + self.body.len()
77    }
78
79    fn encode_into(&self, buf: &mut impl BufMut) {
80        // Write prefix: flag and number of headers.
81        buf.put_u8(self.encoding_info.flag.into());
82        buf.put_uint(
83            self.headers.len() as u64,
84            self.encoding_info.flag.num_headers_length_bytes as usize,
85        );
86        // Write headers.
87        for Header { name, value } in &self.headers {
88            buf.put_uint(
89                name.len() as u64,
90                self.encoding_info.flag.name_length_bytes.get() as usize,
91            );
92            buf.put_slice(name);
93            buf.put_uint(
94                value.len() as u64,
95                self.encoding_info.flag.value_length_bytes.get() as usize,
96            );
97            buf.put_slice(value);
98        }
99        buf.put_slice(&self.body);
100    }
101}
102
103impl TryFrom<Bytes> for EnvelopeRecord {
104    type Error = RecordDecodeError;
105
106    fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
107        if buf.is_empty() {
108            return Err(RecordDecodeError::InvalidValue("HeaderFlag", "missing"));
109        }
110
111        let flag: HeaderFlag = buf
112            .get_u8()
113            .try_into()
114            .map_err(|info| RecordDecodeError::InvalidValue("HeaderFlag", info))?;
115        if flag.num_headers_length_bytes == 0 {
116            return Ok(Self {
117                encoding_info: EMPTY_HEADERS_ENCODING_INFO,
118                headers: vec![],
119                body: buf,
120            });
121        }
122
123        let num_headers = buf
124            .try_get_uint(flag.num_headers_length_bytes as usize)
125            .map_err(|_| RecordDecodeError::Truncated("NumHeaders"))?;
126
127        let mut headers_total_bytes = 0;
128        let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
129        for _ in 0..num_headers {
130            let name_len = buf
131                .try_get_uint(flag.name_length_bytes.get() as usize)
132                .map_err(|_| RecordDecodeError::Truncated("HeaderNameLen"))?
133                as usize;
134            if buf.remaining() < name_len {
135                return Err(RecordDecodeError::Truncated("HeaderName"));
136            }
137            let name = buf.split_to(name_len);
138
139            let value_len = buf
140                .try_get_uint(flag.value_length_bytes.get() as usize)
141                .map_err(|_| RecordDecodeError::Truncated("HeaderValueLen"))?
142                as usize;
143            if buf.remaining() < value_len {
144                return Err(RecordDecodeError::Truncated("HeaderValue"));
145            }
146            let value = buf.split_to(value_len);
147
148            headers_total_bytes += name.len() + value.len();
149            headers.push(Header { name, value })
150        }
151
152        Ok(Self {
153            encoding_info: EncodingInfo {
154                headers_total_bytes,
155                flag,
156            },
157            headers,
158            body: buf,
159        })
160    }
161}
162
163const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
164    num_headers_length_bytes: 0,
165    name_length_bytes: NonZeroU8::new(1).unwrap(),
166    value_length_bytes: NonZeroU8::new(1).unwrap(),
167};
168
/// Decoded form of the single flag byte that starts every encoded record.
///
/// Bit layout, most significant first:
/// - bits 7-6: reserved, must be zero;
/// - bits 5-4: width in bytes of the header-count field (0-3);
/// - bits 3-2: width in bytes of each name length prefix, minus one;
/// - bits 1-0: width in bytes of each value length prefix, minus one.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HeaderFlag {
    /// Width of the header-count field; 0 means the record carries no headers.
    num_headers_length_bytes: u8,
    /// Width of each header-name length prefix (1-4 bytes).
    name_length_bytes: NonZeroU8,
    /// Width of each header-value length prefix (1-4 bytes).
    value_length_bytes: NonZeroU8,
}

impl From<HeaderFlag> for u8 {
    /// Packs the flag into its one-byte wire form.
    fn from(flag: HeaderFlag) -> Self {
        let count_bits = flag.num_headers_length_bytes << 4;
        let name_bits = (flag.name_length_bytes.get() - 1) << 2;
        let value_bits = flag.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a flag byte, rejecting any byte with a reserved (top two) bit set.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        const RESERVED_MASK: u8 = 0b1100_0000;
        const FIELD_MASK: u8 = 0b11;

        if (byte & RESERVED_MASK) != 0 {
            return Err("reserved bit set");
        }

        // Width fields store (width - 1), so adding one always yields a value in 1..=4.
        let width_at = |shift: u32| {
            NonZeroU8::new(((byte >> shift) & FIELD_MASK) + 1)
                .expect("two-bit field plus one is never zero")
        };

        Ok(Self {
            num_headers_length_bytes: (byte >> 4) & FIELD_MASK,
            name_length_bytes: width_at(2),
            value_length_bytes: width_at(0),
        })
    }
}
198
199const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
200    headers_total_bytes: 0,
201    flag: EMPTY_HEADER_FLAG,
202};
203
204#[derive(Debug, PartialEq, Eq, Clone, Copy)]
205struct EncodingInfo {
206    headers_total_bytes: usize,
207    flag: HeaderFlag,
208}
209
210impl TryFrom<&[Header]> for EncodingInfo {
211    type Error = HeaderValidationError;
212
213    fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
214        // Given number of KV pairs, determine how many bytes are required for storing
215        // the length number.
216        fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
217            let size = 8 - elems.leading_zeros() / 8;
218            if size > 3 {
219                Err(HeaderValidationError::TooMany)
220            } else {
221                Ok(size as u8)
222            }
223        }
224
225        // Given max length of a name (key) or value, determine how many bytes are required for
226        // storing this number.
227        fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
228            if elems == 0 {
229                return Ok(NonZeroU8::new(1u8).unwrap());
230            }
231            let size = 8 - (elems.leading_zeros() / 8);
232            if size > 4 {
233                Err(HeaderValidationError::TooLong)
234            } else {
235                Ok(NonZeroU8::new(size as u8).unwrap())
236            }
237        }
238
239        if headers.is_empty() {
240            return Ok(EMPTY_HEADERS_ENCODING_INFO);
241        }
242
243        let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
244            (0usize, 0usize, 0usize),
245            |(size_bytes_acc, name_max, value_max), Header { name, value }| {
246                if name.is_empty() {
247                    return Err(HeaderValidationError::NameEmpty);
248                }
249                let name_len = name.len();
250                let value_len = value.len();
251                Ok((
252                    size_bytes_acc + name_len + value_len,
253                    name_max.max(name_len),
254                    value_max.max(value_len),
255                ))
256            },
257        )?;
258
259        let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
260        let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
261        let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
262
263        Ok(Self {
264            headers_total_bytes,
265            flag: HeaderFlag {
266                num_headers_length_bytes,
267                name_length_bytes,
268                value_length_bytes,
269            },
270        })
271    }
272}
273
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::Bytes;

    use super::{
        Encodable as _, EnvelopeRecord, Header, HeaderFlag, MeteredSize, RecordDecodeError,
    };

    /// Encodes `headers` + `body`, decodes the bytes again, and checks both parts survive.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let encoded: Bytes = EnvelopeRecord::try_from_parts(headers.clone(), body.clone())
            .unwrap()
            .to_bytes();
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn framed_with_headers() {
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn framed_duplicate_keys() {
        // Duplicate keys preserved in original order.
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_empty_value() {
        roundtrip_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::new(),
            }],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_multi_byte_length_prefixes() {
        // A 256-byte name needs a 2-byte length prefix; a 70_000-byte value needs 3 bytes.
        roundtrip_parts(
            vec![Header {
                name: Bytes::from(vec![b'n'; 256]),
                value: Bytes::from(vec![b'v'; 70_000]),
            }],
            Bytes::from("body"),
        );
    }

    #[test]
    fn encoded_size_matches_to_bytes() {
        let cases = [
            (vec![], Bytes::new()),
            (vec![], Bytes::from("hello")),
            (
                vec![Header {
                    name: Bytes::from("k"),
                    value: Bytes::new(),
                }],
                Bytes::from("b"),
            ),
            (
                vec![Header {
                    name: Bytes::from("key"),
                    value: Bytes::from(vec![b'v'; 300]),
                }],
                Bytes::new(),
            ),
        ];
        for (headers, body) in cases {
            let record = EnvelopeRecord::try_from_parts(headers, body).unwrap();
            assert_eq!(record.encoded_size(), record.to_bytes().len());
        }
    }

    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let record = EnvelopeRecord::try_from_parts(
            vec![
                Header {
                    name: Bytes::from("alpha"),
                    value: Bytes::from("1"),
                },
                Header {
                    name: Bytes::from("beta"),
                    value: Bytes::from("two"),
                },
            ],
            Bytes::from("body"),
        )
        .unwrap();

        assert_eq!(
            record.metered_size(),
            8 + (2 * record.headers().len())
                + ("alpha".len() + "1".len() + "beta".len() + "two".len())
                + "body".len()
        );
    }

    #[test]
    fn empty_header_name_rejected() {
        assert!(EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::new(),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .is_err());
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn flag_roundtrips_every_valid_byte() {
        // Every byte with the two reserved bits clear must survive decode + encode.
        for byte in 0u8..64 {
            let flag = HeaderFlag::try_from(byte).unwrap();
            assert_eq!(u8::from(flag), byte);
        }
    }

    #[test]
    fn missing_or_reserved_flag_rejected() {
        assert!(matches!(
            EnvelopeRecord::try_from(Bytes::new()),
            Err(RecordDecodeError::InvalidValue(..))
        ));
        // Either of the top two (reserved) bits makes the flag invalid.
        for flag in [0b0100_0000u8, 0b1000_0000, 0b1111_1111] {
            assert!(
                matches!(
                    EnvelopeRecord::try_from(Bytes::copy_from_slice(&[flag])),
                    Err(RecordDecodeError::InvalidValue(..))
                ),
                "expected InvalidValue for flag {flag:#010b}"
            );
        }
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn truncated_returns_error() {
        let record = EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .unwrap();
        let encoded = record.to_bytes();

        // Truncation anywhere before the end should error
        // (with empty body, there's no trailing data that can be truncated safely).
        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    EnvelopeRecord::try_from(truncated),
                    Err(RecordDecodeError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }
}