Skip to main content

s2_common/record/
envelope.rs

1use bytes::Bytes;
2
3use super::{Header, MeteredSize, RecordPartsError};
4use crate::deep_size::DeepSize;
5
/// Upper bound on the number of headers in a single record (2^24 - 1).
const MAX_HEADER_COUNT: usize = (1 << 24) - 1;
/// Upper bound, in bytes, on the length of any one header name or value.
const MAX_HEADER_NAME_OR_VALUE_LEN: usize = u32::MAX as usize;
8
9#[derive(Debug, PartialEq, thiserror::Error)]
10pub enum HeaderValidationError {
11    #[error("too many")]
12    TooMany,
13    #[error("too long")]
14    TooLong,
15    #[error("empty name")]
16    NameEmpty,
17}
18
19#[derive(PartialEq, Eq, Clone)]
20pub struct EnvelopeRecord {
21    headers: Vec<Header>,
22    body: Bytes,
23    header_sizing: HeaderSizing,
24}
25
/// Packed summary of a record's header sizes, stored in a single `u64`.
///
/// Bits 0..60 hold the total number of header name and value bytes. Bits
/// 60..62 hold the value length width minus one. Bits 62..64 hold the name
/// length width minus one. The default (all zeros) therefore means "no header
/// bytes, widths of 1".
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
struct HeaderSizing(u64);
28
29impl HeaderSizing {
30    const TOTAL_BYTES_MASK: u64 = (1 << 60) - 1;
31    const NAME_LENGTH_WIDTH_SHIFT: u32 = 62;
32    const VALUE_LENGTH_WIDTH_SHIFT: u32 = 60;
33
34    fn new(total_bytes: usize, name_length_width: u8, value_length_width: u8) -> Self {
35        debug_assert!(total_bytes as u64 <= Self::TOTAL_BYTES_MASK);
36        debug_assert!((1..=4).contains(&name_length_width));
37        debug_assert!((1..=4).contains(&value_length_width));
38
39        Self(
40            total_bytes as u64
41                | (u64::from(name_length_width - 1) << Self::NAME_LENGTH_WIDTH_SHIFT)
42                | (u64::from(value_length_width - 1) << Self::VALUE_LENGTH_WIDTH_SHIFT),
43        )
44    }
45
46    fn total_bytes(self) -> usize {
47        (self.0 & Self::TOTAL_BYTES_MASK) as usize
48    }
49
50    fn name_length_width_bytes(self) -> usize {
51        (((self.0 >> Self::NAME_LENGTH_WIDTH_SHIFT) & 0b11) + 1) as usize
52    }
53
54    fn value_length_width_bytes(self) -> usize {
55        (((self.0 >> Self::VALUE_LENGTH_WIDTH_SHIFT) & 0b11) + 1) as usize
56    }
57}
58
59impl std::fmt::Debug for EnvelopeRecord {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("EnvelopeRecord")
62            .field("headers.len", &self.headers.len())
63            .field("body.len", &self.body.len())
64            .finish()
65    }
66}
67
68impl DeepSize for EnvelopeRecord {
69    fn deep_size(&self) -> usize {
70        self.headers.deep_size() + self.body.deep_size()
71    }
72}
73
74impl MeteredSize for EnvelopeRecord {
75    fn metered_size(&self) -> usize {
76        8 + (2 * self.headers.len()) + self.header_sizing.total_bytes() + self.body.len()
77    }
78}
79
80impl EnvelopeRecord {
81    pub fn headers(&self) -> &[Header] {
82        &self.headers
83    }
84
85    pub fn body(&self) -> &Bytes {
86        &self.body
87    }
88
89    /// Total bytes across all header names and values.
90    pub fn headers_total_bytes(&self) -> usize {
91        self.header_sizing.total_bytes()
92    }
93
94    #[doc(hidden)]
95    pub fn header_name_length_width_bytes(&self) -> usize {
96        self.header_sizing.name_length_width_bytes()
97    }
98
99    #[doc(hidden)]
100    pub fn header_value_length_width_bytes(&self) -> usize {
101        self.header_sizing.value_length_width_bytes()
102    }
103
104    pub fn into_parts(self) -> (Vec<Header>, Bytes) {
105        (self.headers, self.body)
106    }
107
108    pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
109        let header_sizing = validate_headers(&headers)?;
110        Ok(Self {
111            headers,
112            body,
113            header_sizing,
114        })
115    }
116}
117
118fn validate_headers(headers: &[Header]) -> Result<HeaderSizing, HeaderValidationError> {
119    if headers.len() > MAX_HEADER_COUNT {
120        return Err(HeaderValidationError::TooMany);
121    }
122
123    let mut total_bytes = 0usize;
124    let mut name_length_width_bytes = 1u8;
125    let mut value_length_width_bytes = 1u8;
126
127    for Header { name, value } in headers {
128        if name.is_empty() {
129            return Err(HeaderValidationError::NameEmpty);
130        }
131        if name.len() > MAX_HEADER_NAME_OR_VALUE_LEN || value.len() > MAX_HEADER_NAME_OR_VALUE_LEN {
132            return Err(HeaderValidationError::TooLong);
133        }
134
135        total_bytes = total_bytes
136            .checked_add(name.len())
137            .and_then(|total| total.checked_add(value.len()))
138            .ok_or(HeaderValidationError::TooLong)?;
139        if total_bytes as u64 > HeaderSizing::TOTAL_BYTES_MASK {
140            return Err(HeaderValidationError::TooLong);
141        }
142
143        name_length_width_bytes = name_length_width_bytes.max(length_width_bytes(name.len())?);
144        value_length_width_bytes = value_length_width_bytes.max(length_width_bytes(value.len())?);
145    }
146
147    Ok(HeaderSizing::new(
148        total_bytes,
149        name_length_width_bytes,
150        value_length_width_bytes,
151    ))
152}
153
154fn length_width_bytes(len: usize) -> Result<u8, HeaderValidationError> {
155    if len == 0 {
156        return Ok(1);
157    }
158
159    let width = 8 - len.leading_zeros() / 8;
160    if width <= 4 {
161        Ok(width as u8)
162    } else {
163        Err(HeaderValidationError::TooLong)
164    }
165}
166
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use proptest::prelude::*;

    use super::{
        EnvelopeRecord, Header, HeaderSizing, HeaderValidationError, MeteredSize, RecordPartsError,
        length_width_bytes,
    };

    /// Shorthand for a header built from static name and value strings.
    fn header(name: &'static str, value: &'static str) -> Header {
        Header {
            name: Bytes::from(name),
            value: Bytes::from(value),
        }
    }

    /// Builds a record from `headers` and `body`, then checks that both are
    /// returned unchanged.
    fn assert_parts_preserved(headers: Vec<Header>, body: Bytes) {
        let record = EnvelopeRecord::try_from_parts(headers.clone(), body.clone()).unwrap();
        assert_eq!(record.headers(), headers);
        assert_eq!(record.body(), &body);
    }

    #[test]
    fn preserves_headers() {
        let headers = (1..=4)
            .map(|i| Header {
                name: Bytes::from(format!("key_{i}")),
                value: Bytes::from(format!("val_{i}")),
            })
            .collect();
        assert_parts_preserved(headers, Bytes::from("hello"));
    }

    #[test]
    fn preserves_no_headers() {
        assert_parts_preserved(Vec::new(), Bytes::from("hello"));
    }

    #[test]
    fn rejects_empty_header_name() {
        let nameless = Header {
            name: Bytes::new(),
            value: Bytes::from_static(b"value"),
        };
        let result = EnvelopeRecord::try_from_parts(vec![nameless], Bytes::from_static(b"body"));
        assert_eq!(
            result,
            Err(RecordPartsError::Header(HeaderValidationError::NameEmpty))
        );
    }

    #[test]
    fn preserves_duplicate_keys() {
        // Repeated names are kept, in the order they were given.
        assert_parts_preserved(
            vec![
                header("b", "val_1"),
                header("b", "val_2"),
                header("a", "val_3"),
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let pairs = [("alpha", "1"), ("beta", "two")];
        let record = EnvelopeRecord::try_from_parts(
            pairs.iter().map(|&(name, value)| header(name, value)).collect(),
            Bytes::from("body"),
        )
        .unwrap();

        let header_bytes: usize = pairs
            .iter()
            .map(|(name, value)| name.len() + value.len())
            .sum();
        let expected = 8 + 2 * record.headers().len() + header_bytes + "body".len();
        assert_eq!(record.metered_size(), expected);
    }

    #[test]
    fn header_sizing_is_cached_from_validated_headers() {
        let long_name = Bytes::from(vec![b'n'; 256]);
        let long_value = Bytes::from(vec![b'v'; 65_536]);
        let headers = vec![
            header("a", "value"),
            Header {
                name: long_name.clone(),
                value: long_value.clone(),
            },
        ];
        let record =
            EnvelopeRecord::try_from_parts(headers, Bytes::from_static(b"body")).unwrap();

        let expected_total = "a".len() + "value".len() + long_name.len() + long_value.len();
        assert_eq!(record.headers_total_bytes(), expected_total);
        // A 256-byte name needs two length bytes; a 65_536-byte value needs three.
        assert_eq!(record.header_name_length_width_bytes(), 2);
        assert_eq!(record.header_value_length_width_bytes(), 3);
    }

    proptest! {
        #[test]
        fn header_sizing_pack_roundtrips(
            total_bytes in 0usize..=HeaderSizing::TOTAL_BYTES_MASK as usize,
            name_length_width in 1u8..=4,
            value_length_width in 1u8..=4,
        ) {
            let sizing = HeaderSizing::new(total_bytes, name_length_width, value_length_width);

            prop_assert_eq!(sizing.total_bytes(), total_bytes);
            prop_assert_eq!(sizing.name_length_width_bytes(), usize::from(name_length_width));
            prop_assert_eq!(sizing.value_length_width_bytes(), usize::from(value_length_width));
        }
    }

    #[test]
    fn length_width_bytes_covers_encoding_boundaries() {
        let cases: [(usize, u8); 9] = [
            (0, 1),
            (1, 1),
            (0xff, 1),
            (0x100, 2),
            (0xffff, 2),
            (0x1_0000, 3),
            (0xff_ffff, 3),
            (0x100_0000, 4),
            (u32::MAX as usize, 4),
        ];
        for (len, width) in cases {
            assert_eq!(length_width_bytes(len), Ok(width));
        }

        // Only representable where `usize` is wider than 32 bits.
        if let Some(too_long) = (u32::MAX as usize).checked_add(1) {
            assert_eq!(
                length_width_bytes(too_long),
                Err(HeaderValidationError::TooLong)
            );
        }
    }
}