1use bytes::Bytes;
2
3use super::{Header, MeteredSize, RecordPartsError};
4use crate::deep_size::DeepSize;
5
/// Largest number of headers a single record may carry: 2^24 - 1.
const MAX_HEADER_COUNT: usize = (1 << 24) - 1;
/// Largest length, in bytes, of any individual header name or value (`u32::MAX`).
const MAX_HEADER_NAME_OR_VALUE_LEN: usize = u32::MAX as usize;
8
/// Reasons a list of record headers can fail validation.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum HeaderValidationError {
    /// There are more headers than `MAX_HEADER_COUNT` allows.
    #[error("too many")]
    TooMany,
    /// A header name or value exceeds `MAX_HEADER_NAME_OR_VALUE_LEN` bytes, or the combined
    /// size of all header names and values is too large to be represented.
    #[error("too long")]
    TooLong,
    /// A header has a zero-length name.
    #[error("empty name")]
    NameEmpty,
}
18
/// A record consisting of an ordered list of headers and an opaque body.
///
/// [`EnvelopeRecord::try_from_parts`] validates the headers and caches a sizing summary
/// of them, so later size queries do not need to walk the headers again.
#[derive(PartialEq, Eq, Clone)]
pub struct EnvelopeRecord {
    /// Headers in the order they were supplied; repeated names are permitted.
    headers: Vec<Header>,
    /// Record payload.
    body: Bytes,
    /// Packed summary (total bytes and length widths) derived from `headers` at construction.
    header_sizing: HeaderSizing,
}
25
/// Sizing summary for a record's headers, packed into a single `u64`.
///
/// Layout, from most to least significant bit:
///
/// - bits 62..=63: byte width of the widest header-name length, minus one
/// - bits 60..=61: byte width of the widest header-value length, minus one
/// - bits 0..=59: total number of bytes across all header names and values
///
/// Each width lies in `1..=4`, so storing `width - 1` fits in two bits.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
struct HeaderSizing(u64);

impl HeaderSizing {
    /// Mask selecting the low 60 bits that hold the total header byte count.
    const TOTAL_BYTES_MASK: u64 = (1 << 60) - 1;
    /// Bit offset of the two-bit name-length-width field.
    const NAME_LENGTH_WIDTH_SHIFT: u32 = 62;
    /// Bit offset of the two-bit value-length-width field.
    const VALUE_LENGTH_WIDTH_SHIFT: u32 = 60;
    /// Mask for one two-bit width field after it has been shifted down to bit 0.
    const WIDTH_FIELD_MASK: u64 = 0b11;

    /// Packs a total byte count and two length widths (each expected in `1..=4`).
    ///
    /// Inputs are only checked with `debug_assert!`. Callers must guarantee that
    /// `total_bytes` fits within [`Self::TOTAL_BYTES_MASK`] and that both widths are in range.
    fn new(total_bytes: usize, name_length_width: u8, value_length_width: u8) -> Self {
        debug_assert!(total_bytes as u64 <= Self::TOTAL_BYTES_MASK);
        debug_assert!((1..=4).contains(&name_length_width));
        debug_assert!((1..=4).contains(&value_length_width));

        // Widths are stored biased by -1 so that 1..=4 maps onto the two-bit range 0..=3.
        let name_field = u64::from(name_length_width - 1) << Self::NAME_LENGTH_WIDTH_SHIFT;
        let value_field = u64::from(value_length_width - 1) << Self::VALUE_LENGTH_WIDTH_SHIFT;
        Self(name_field | value_field | total_bytes as u64)
    }

    /// Total number of bytes across all header names and values.
    fn total_bytes(self) -> usize {
        let Self(packed) = self;
        (packed & Self::TOTAL_BYTES_MASK) as usize
    }

    /// Byte width (1 to 4) used to encode header-name lengths.
    fn name_length_width_bytes(self) -> usize {
        self.width_at(Self::NAME_LENGTH_WIDTH_SHIFT)
    }

    /// Byte width (1 to 4) used to encode header-value lengths.
    fn value_length_width_bytes(self) -> usize {
        self.width_at(Self::VALUE_LENGTH_WIDTH_SHIFT)
    }

    /// Decodes the two-bit width field that starts at bit `shift`, undoing the `- 1` bias.
    fn width_at(self, shift: u32) -> usize {
        let biased = (self.0 >> shift) & Self::WIDTH_FIELD_MASK;
        biased as usize + 1
    }
}
58
59impl std::fmt::Debug for EnvelopeRecord {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("EnvelopeRecord")
62 .field("headers.len", &self.headers.len())
63 .field("body.len", &self.body.len())
64 .finish()
65 }
66}
67
68impl DeepSize for EnvelopeRecord {
69 fn deep_size(&self) -> usize {
70 self.headers.deep_size() + self.body.deep_size()
71 }
72}
73
74impl MeteredSize for EnvelopeRecord {
75 fn metered_size(&self) -> usize {
76 8 + (2 * self.headers.len()) + self.header_sizing.total_bytes() + self.body.len()
77 }
78}
79
80impl EnvelopeRecord {
81 pub fn headers(&self) -> &[Header] {
82 &self.headers
83 }
84
85 pub fn body(&self) -> &Bytes {
86 &self.body
87 }
88
89 pub fn headers_total_bytes(&self) -> usize {
91 self.header_sizing.total_bytes()
92 }
93
94 #[doc(hidden)]
95 pub fn header_name_length_width_bytes(&self) -> usize {
96 self.header_sizing.name_length_width_bytes()
97 }
98
99 #[doc(hidden)]
100 pub fn header_value_length_width_bytes(&self) -> usize {
101 self.header_sizing.value_length_width_bytes()
102 }
103
104 pub fn into_parts(self) -> (Vec<Header>, Bytes) {
105 (self.headers, self.body)
106 }
107
108 pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
109 let header_sizing = validate_headers(&headers)?;
110 Ok(Self {
111 headers,
112 body,
113 header_sizing,
114 })
115 }
116}
117
118fn validate_headers(headers: &[Header]) -> Result<HeaderSizing, HeaderValidationError> {
119 if headers.len() > MAX_HEADER_COUNT {
120 return Err(HeaderValidationError::TooMany);
121 }
122
123 let mut total_bytes = 0usize;
124 let mut name_length_width_bytes = 1u8;
125 let mut value_length_width_bytes = 1u8;
126
127 for Header { name, value } in headers {
128 if name.is_empty() {
129 return Err(HeaderValidationError::NameEmpty);
130 }
131 if name.len() > MAX_HEADER_NAME_OR_VALUE_LEN || value.len() > MAX_HEADER_NAME_OR_VALUE_LEN {
132 return Err(HeaderValidationError::TooLong);
133 }
134
135 total_bytes = total_bytes
136 .checked_add(name.len())
137 .and_then(|total| total.checked_add(value.len()))
138 .ok_or(HeaderValidationError::TooLong)?;
139 if total_bytes as u64 > HeaderSizing::TOTAL_BYTES_MASK {
140 return Err(HeaderValidationError::TooLong);
141 }
142
143 name_length_width_bytes = name_length_width_bytes.max(length_width_bytes(name.len())?);
144 value_length_width_bytes = value_length_width_bytes.max(length_width_bytes(value.len())?);
145 }
146
147 Ok(HeaderSizing::new(
148 total_bytes,
149 name_length_width_bytes,
150 value_length_width_bytes,
151 ))
152}
153
154fn length_width_bytes(len: usize) -> Result<u8, HeaderValidationError> {
155 if len == 0 {
156 return Ok(1);
157 }
158
159 let width = 8 - len.leading_zeros() / 8;
160 if width <= 4 {
161 Ok(width as u8)
162 } else {
163 Err(HeaderValidationError::TooLong)
164 }
165}
166
// Tests for `EnvelopeRecord` construction and accessors, header validation, metered sizing,
// and the `HeaderSizing` bit-packing.
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use proptest::prelude::*;

    use super::{
        EnvelopeRecord, Header, HeaderSizing, HeaderValidationError, MeteredSize, RecordPartsError,
        length_width_bytes,
    };

    /// Builds a record from `headers` and `body`, then asserts both come back unchanged
    /// through the accessors.
    fn assert_parts_preserved(headers: Vec<Header>, body: Bytes) {
        let record = EnvelopeRecord::try_from_parts(headers.clone(), body.clone()).unwrap();
        assert_eq!(record.headers(), headers);
        assert_eq!(record.body(), &body);
    }

    /// Several distinct headers keep their order and contents.
    #[test]
    fn preserves_headers() {
        assert_parts_preserved(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    /// A record with no headers is valid.
    #[test]
    fn preserves_no_headers() {
        assert_parts_preserved(vec![], Bytes::from("hello"));
    }

    /// A zero-length header name is rejected and surfaced as
    /// `RecordPartsError::Header(HeaderValidationError::NameEmpty)`.
    #[test]
    fn rejects_empty_header_name() {
        assert_eq!(
            EnvelopeRecord::try_from_parts(
                vec![Header {
                    name: Bytes::new(),
                    value: Bytes::from_static(b"value"),
                }],
                Bytes::from_static(b"body"),
            ),
            Err(RecordPartsError::Header(HeaderValidationError::NameEmpty))
        );
    }

    /// Repeated header names are accepted, and the original order is kept
    /// (no sorting or de-duplication).
    #[test]
    fn preserves_duplicate_keys() {
        assert_parts_preserved(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    /// Checks the metered size formula: 8 + 2 per header + all name/value bytes + body length.
    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let record = EnvelopeRecord::try_from_parts(
            vec![
                Header {
                    name: Bytes::from("alpha"),
                    value: Bytes::from("1"),
                },
                Header {
                    name: Bytes::from("beta"),
                    value: Bytes::from("two"),
                },
            ],
            Bytes::from("body"),
        )
        .unwrap();

        assert_eq!(
            record.metered_size(),
            8 + (2 * record.headers().len())
                + ("alpha".len() + "1".len() + "beta".len() + "two".len())
                + "body".len()
        );
    }

    /// The cached widths reflect the largest name and value across all headers, not the
    /// first header. A 256-byte name needs a 2-byte length; a 65,536-byte value needs a
    /// 3-byte length.
    #[test]
    fn header_sizing_is_cached_from_validated_headers() {
        let long_name = Bytes::from(vec![b'n'; 256]);
        let long_value = Bytes::from(vec![b'v'; 65_536]);
        let record = EnvelopeRecord::try_from_parts(
            vec![
                Header {
                    name: Bytes::from_static(b"a"),
                    value: Bytes::from_static(b"value"),
                },
                Header {
                    name: long_name.clone(),
                    value: long_value.clone(),
                },
            ],
            Bytes::from_static(b"body"),
        )
        .unwrap();

        assert_eq!(
            record.headers_total_bytes(),
            "a".len() + "value".len() + long_name.len() + long_value.len()
        );
        assert_eq!(record.header_name_length_width_bytes(), 2);
        assert_eq!(record.header_value_length_width_bytes(), 3);
    }

    proptest! {
        // Packing any in-range (total, name width, value width) triple into a `HeaderSizing`
        // and unpacking it returns the same values.
        #[test]
        fn header_sizing_pack_roundtrips(
            total_bytes in 0usize..=HeaderSizing::TOTAL_BYTES_MASK as usize,
            name_length_width in 1u8..=4,
            value_length_width in 1u8..=4,
        ) {
            let summary = HeaderSizing::new(
                total_bytes,
                name_length_width,
                value_length_width,
            );

            prop_assert_eq!(summary.total_bytes(), total_bytes);
            prop_assert_eq!(
                summary.name_length_width_bytes(),
                name_length_width as usize,
            );
            prop_assert_eq!(
                summary.value_length_width_bytes(),
                value_length_width as usize,
            );
        }
    }

    /// Checks the width at each byte boundary from one to four bytes, and the first length
    /// that no longer fits in four bytes.
    #[test]
    fn length_width_bytes_covers_encoding_boundaries() {
        assert_eq!(length_width_bytes(0), Ok(1));
        assert_eq!(length_width_bytes(1), Ok(1));
        assert_eq!(length_width_bytes(0xff), Ok(1));
        assert_eq!(length_width_bytes(0x100), Ok(2));
        assert_eq!(length_width_bytes(0xffff), Ok(2));
        assert_eq!(length_width_bytes(0x1_0000), Ok(3));
        assert_eq!(length_width_bytes(0xff_ffff), Ok(3));
        assert_eq!(length_width_bytes(0x100_0000), Ok(4));
        assert_eq!(length_width_bytes(u32::MAX as usize), Ok(4));

        // `u32::MAX + 1` is only representable when `usize` is wider than 32 bits.
        if let Some(too_long) = (u32::MAX as usize).checked_add(1) {
            assert_eq!(
                length_width_bytes(too_long),
                Err(HeaderValidationError::TooLong)
            );
        }
    }
}