s2_common/record/
envelope.rs

1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, InternalRecordError, PublicRecordError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10    #[error("too many")]
11    TooMany,
12    #[error("too long")]
13    TooLong,
14    #[error("empty name")]
15    NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20    headers: Vec<Header>,
21    body: Bytes,
22    encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("EnvelopeRecord")
28            .field("headers.len", &self.headers.len())
29            .field("body.len", &self.body.len())
30            .finish()
31    }
32}
33
34impl DeepSize for EnvelopeRecord {
35    fn deep_size(&self) -> usize {
36        self.headers.deep_size() + self.body.deep_size()
37    }
38}
39
40impl EnvelopeRecord {
41    pub fn headers(&self) -> &[Header] {
42        &self.headers
43    }
44
45    pub fn body(&self) -> &Bytes {
46        &self.body
47    }
48
49    pub fn into_parts(self) -> (Vec<Header>, Bytes) {
50        (self.headers, self.body)
51    }
52
53    pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, PublicRecordError> {
54        let encoding_info = headers.as_slice().try_into()?;
55        Ok(Self {
56            headers,
57            body,
58            encoding_info,
59        })
60    }
61}
62
63impl Encodable for EnvelopeRecord {
64    fn encoded_size(&self) -> usize {
65        1 + self.encoding_info.flag.num_headers_length_bytes as usize
66            + self.headers.len()
67                * (self.encoding_info.flag.name_length_bytes.get() as usize
68                    + self.encoding_info.flag.value_length_bytes.get() as usize)
69            + self.encoding_info.headers_total_bytes
70            + self.body.len()
71    }
72
73    fn encode_into(&self, buf: &mut impl BufMut) {
74        // Write prefix: flag and number of headers.
75        buf.put_u8(self.encoding_info.flag.into());
76        buf.put_uint(
77            self.headers.len() as u64,
78            self.encoding_info.flag.num_headers_length_bytes as usize,
79        );
80        // Write headers.
81        for Header { name, value } in &self.headers {
82            buf.put_uint(
83                name.len() as u64,
84                self.encoding_info.flag.name_length_bytes.get() as usize,
85            );
86            buf.put_slice(name);
87            buf.put_uint(
88                value.len() as u64,
89                self.encoding_info.flag.value_length_bytes.get() as usize,
90            );
91            buf.put_slice(value);
92        }
93        buf.put_slice(&self.body);
94    }
95}
96
97impl TryFrom<Bytes> for EnvelopeRecord {
98    type Error = InternalRecordError;
99
100    fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
101        if buf.is_empty() {
102            return Err(InternalRecordError::InvalidValue("HeaderFlag", "missing"));
103        }
104
105        let flag: HeaderFlag = buf
106            .get_u8()
107            .try_into()
108            .map_err(|info| InternalRecordError::InvalidValue("HeaderFlag", info))?;
109        if flag.num_headers_length_bytes == 0 {
110            return Ok(Self {
111                encoding_info: EMPTY_HEADERS_ENCODING_INFO,
112                headers: vec![],
113                body: buf,
114            });
115        }
116
117        let num_headers = buf
118            .try_get_uint(flag.num_headers_length_bytes as usize)
119            .map_err(|_| InternalRecordError::Truncated("NumHeaders"))?;
120
121        let mut headers_total_bytes = 0;
122        let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
123        for _ in 0..num_headers {
124            let name_len = buf
125                .try_get_uint(flag.name_length_bytes.get() as usize)
126                .map_err(|_| InternalRecordError::Truncated("HeaderNameLen"))?
127                as usize;
128            if buf.remaining() < name_len {
129                return Err(InternalRecordError::Truncated("HeaderName"));
130            }
131            let name = buf.split_to(name_len);
132
133            let value_len = buf
134                .try_get_uint(flag.value_length_bytes.get() as usize)
135                .map_err(|_| InternalRecordError::Truncated("HeaderValueLen"))?
136                as usize;
137            if buf.remaining() < value_len {
138                return Err(InternalRecordError::Truncated("HeaderValue"));
139            }
140            let value = buf.split_to(value_len);
141
142            headers_total_bytes += name.len() + value.len();
143            headers.push(Header { name, value })
144        }
145
146        Ok(Self {
147            encoding_info: EncodingInfo {
148                headers_total_bytes,
149                flag,
150            },
151            headers,
152            body: buf,
153        })
154    }
155}
156
/// Flag for a record without headers.
///
/// The header-count width is zero, so no count or per-header fields follow. The name and
/// value widths are set to their minimum (1 byte), which packs to all-zero bits, so this
/// flag encodes as the byte `0`.
const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
    num_headers_length_bytes: 0,
    name_length_bytes: NonZeroU8::MIN,
    value_length_bytes: NonZeroU8::MIN,
};

/// Unpacked form of the one-byte flag that starts every envelope.
///
/// Bit layout, most significant bit first: `RR CC NN VV`.
/// - `RR`: reserved, must be zero.
/// - `CC`: width in bytes of the header-count field (0..=3).
/// - `NN`: width in bytes of each header-name length, stored minus one (1..=4 bytes).
/// - `VV`: width in bytes of each header-value length, stored minus one (1..=4 bytes).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct HeaderFlag {
    num_headers_length_bytes: u8,
    name_length_bytes: NonZeroU8,
    value_length_bytes: NonZeroU8,
}

impl From<HeaderFlag> for u8 {
    /// Packs the flag into its wire byte.
    ///
    /// Assumes the fields are in range (count width <= 3, length widths <= 4). Larger
    /// values would spill into the neighbouring bit fields.
    fn from(flag: HeaderFlag) -> Self {
        let count_bits = flag.num_headers_length_bytes << 4;
        let name_bits = (flag.name_length_bytes.get() - 1) << 2;
        let value_bits = flag.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a wire byte, rejecting it if either of the two reserved top bits is set.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        if (byte >> 6) != 0 {
            return Err("reserved bit set");
        }
        // Each two-bit width field stores `width - 1`, so `bits + 1` is always in 1..=4.
        let width = |bits: u8| NonZeroU8::new((bits & 0b11) + 1).expect("1..=4 is never zero");
        Ok(Self {
            num_headers_length_bytes: (byte >> 4) & 0b11,
            name_length_bytes: width(byte >> 2),
            value_length_bytes: width(byte),
        })
    }
}
192
193const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
194    headers_total_bytes: 0,
195    flag: EMPTY_HEADER_FLAG,
196};
197
198#[derive(Debug, PartialEq, Eq, Clone, Copy)]
199struct EncodingInfo {
200    headers_total_bytes: usize,
201    flag: HeaderFlag,
202}
203
204impl TryFrom<&[Header]> for EncodingInfo {
205    type Error = HeaderValidationError;
206
207    fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
208        // Given number of KV pairs, determine how many bytes are required for storing
209        // the length number.
210        fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
211            let size = 8 - elems.leading_zeros() / 8;
212            if size > 3 {
213                Err(HeaderValidationError::TooMany)
214            } else {
215                Ok(size as u8)
216            }
217        }
218
219        // Given max length of a name (key) or value, determine how many bytes are required for
220        // storing this number.
221        fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
222            if elems == 0 {
223                return Ok(NonZeroU8::new(1u8).unwrap());
224            }
225            let size = 8 - (elems.leading_zeros() / 8);
226            if size > 4 {
227                Err(HeaderValidationError::TooLong)
228            } else {
229                Ok(NonZeroU8::new(size as u8).unwrap())
230            }
231        }
232
233        if headers.is_empty() {
234            return Ok(EMPTY_HEADERS_ENCODING_INFO);
235        }
236
237        let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
238            (0usize, 0usize, 0usize),
239            |(size_bytes_acc, name_max, value_max), Header { name, value }| {
240                if name.is_empty() {
241                    return Err(HeaderValidationError::NameEmpty);
242                }
243                let name_len = name.len();
244                let value_len = value.len();
245                Ok((
246                    size_bytes_acc + name_len + value_len,
247                    name_max.max(name_len),
248                    value_max.max(value_len),
249                ))
250            },
251        )?;
252
253        let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
254        let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
255        let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
256
257        Ok(Self {
258            headers_total_bytes,
259            flag: HeaderFlag {
260                num_headers_length_bytes,
261                name_length_bytes,
262                value_length_bytes,
263            },
264        })
265    }
266}
267
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::Bytes;

    use super::{
        Encodable as _, EncodingInfo, EnvelopeRecord, Header, HeaderFlag, HeaderValidationError,
        InternalRecordError,
    };

    /// Encodes `headers` + `body`, decodes the result, and checks both parts survive unchanged.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let encoded: Bytes = EnvelopeRecord::try_from_parts(headers.clone(), body.clone())
            .unwrap()
            .to_bytes();
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    /// Shorthand for building a header from anything convertible into `Bytes`.
    fn header(name: impl Into<Bytes>, value: impl Into<Bytes>) -> Header {
        Header {
            name: name.into(),
            value: value.into(),
        }
    }

    #[test]
    fn framed_with_headers() {
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn framed_duplicate_keys() {
        // Duplicate keys preserved in original order.
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_wide_length_prefixes() {
        // A 300-byte name needs a 2-byte length prefix; a 70 000-byte value needs 3 bytes.
        roundtrip_parts(
            vec![header(vec![b'n'; 300], vec![b'v'; 70_000])],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_many_headers() {
        // More than 255 headers forces a 2-byte header count.
        let headers = (0..300)
            .map(|i| header(format!("key_{i}"), format!("val_{i}")))
            .collect();
        roundtrip_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn framed_empty_header_value() {
        roundtrip_parts(vec![header("key", Bytes::new())], Bytes::new());
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn encoded_size_matches_encoding() {
        let cases = [
            (vec![], Bytes::new()),
            (vec![], Bytes::from("hello")),
            (vec![header("k", Bytes::new())], Bytes::from("body")),
            (
                vec![header(vec![b'n'; 300], vec![b'v'; 70_000])],
                Bytes::new(),
            ),
        ];
        for (headers, body) in cases {
            let record = EnvelopeRecord::try_from_parts(headers, body).unwrap();
            assert_eq!(record.encoded_size(), record.to_bytes().len());
        }
    }

    #[test]
    fn empty_header_name_rejected() {
        let headers = vec![header(Bytes::new(), "value")];
        assert_eq!(
            EncodingInfo::try_from(headers.as_slice()),
            Err(HeaderValidationError::NameEmpty)
        );
        assert!(EnvelopeRecord::try_from_parts(headers, Bytes::new()).is_err());
    }

    #[test]
    fn invalid_flag_returns_error() {
        assert!(matches!(
            EnvelopeRecord::try_from(Bytes::new()),
            Err(InternalRecordError::InvalidValue("HeaderFlag", "missing"))
        ));
        // Either of the two top bits is reserved.
        for flag in [0b0100_0000u8, 0b1000_0000] {
            assert!(matches!(
                EnvelopeRecord::try_from(Bytes::copy_from_slice(&[flag])),
                Err(InternalRecordError::InvalidValue(
                    "HeaderFlag",
                    "reserved bit set"
                ))
            ));
        }
    }

    #[test]
    fn header_count_exceeding_input_is_truncated() {
        // A 2-byte count announcing 65 535 headers, followed by no header data at all.
        let encoded = Bytes::from_static(&[0b0010_0000, 0xFF, 0xFF]);
        assert!(matches!(
            EnvelopeRecord::try_from(encoded),
            Err(InternalRecordError::Truncated("HeaderNameLen"))
        ));
    }

    #[test]
    fn zero_header_count_with_count_field() {
        // A 1-byte count field holding 0: no headers, and the rest is body.
        let decoded =
            EnvelopeRecord::try_from(Bytes::from_static(&[0b0001_0000, 0x00, b'h', b'i']))
                .unwrap();
        assert!(decoded.headers().is_empty());
        assert_eq!(decoded.body(), &Bytes::from_static(b"hi"));
    }

    #[test]
    fn truncated_returns_error() {
        let record = EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .unwrap();
        let encoded = record.to_bytes();

        // Truncation anywhere before the end should error
        // (with empty body, there's no trailing data that can be truncated safely).
        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    EnvelopeRecord::try_from(truncated),
                    Err(InternalRecordError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }
}