1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, MeteredSize, RecordDecodeError, RecordPartsError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10 #[error("too many")]
11 TooMany,
12 #[error("too long")]
13 TooLong,
14 #[error("empty name")]
15 NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20 headers: Vec<Header>,
21 body: Bytes,
22 encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("EnvelopeRecord")
28 .field("headers.len", &self.headers.len())
29 .field("body.len", &self.body.len())
30 .finish()
31 }
32}
33
34impl DeepSize for EnvelopeRecord {
35 fn deep_size(&self) -> usize {
36 self.headers.deep_size() + self.body.deep_size()
37 }
38}
39
40impl MeteredSize for EnvelopeRecord {
41 fn metered_size(&self) -> usize {
42 8 + (2 * self.headers.len()) + self.encoding_info.headers_total_bytes + self.body.len()
43 }
44}
45
46impl EnvelopeRecord {
47 pub fn headers(&self) -> &[Header] {
48 &self.headers
49 }
50
51 pub fn body(&self) -> &Bytes {
52 &self.body
53 }
54
55 pub fn into_parts(self) -> (Vec<Header>, Bytes) {
56 (self.headers, self.body)
57 }
58
59 pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
60 let encoding_info = headers.as_slice().try_into()?;
61 Ok(Self {
62 headers,
63 body,
64 encoding_info,
65 })
66 }
67}
68
69impl Encodable for EnvelopeRecord {
70 fn encoded_size(&self) -> usize {
71 1 + self.encoding_info.flag.num_headers_length_bytes as usize
72 + self.headers.len()
73 * (self.encoding_info.flag.name_length_bytes.get() as usize
74 + self.encoding_info.flag.value_length_bytes.get() as usize)
75 + self.encoding_info.headers_total_bytes
76 + self.body.len()
77 }
78
79 fn encode_into(&self, buf: &mut impl BufMut) {
80 buf.put_u8(self.encoding_info.flag.into());
82 buf.put_uint(
83 self.headers.len() as u64,
84 self.encoding_info.flag.num_headers_length_bytes as usize,
85 );
86 for Header { name, value } in &self.headers {
88 buf.put_uint(
89 name.len() as u64,
90 self.encoding_info.flag.name_length_bytes.get() as usize,
91 );
92 buf.put_slice(name);
93 buf.put_uint(
94 value.len() as u64,
95 self.encoding_info.flag.value_length_bytes.get() as usize,
96 );
97 buf.put_slice(value);
98 }
99 buf.put_slice(&self.body);
100 }
101}
102
103impl TryFrom<Bytes> for EnvelopeRecord {
104 type Error = RecordDecodeError;
105
106 fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
107 if buf.is_empty() {
108 return Err(RecordDecodeError::InvalidValue("HeaderFlag", "missing"));
109 }
110
111 let flag: HeaderFlag = buf
112 .get_u8()
113 .try_into()
114 .map_err(|info| RecordDecodeError::InvalidValue("HeaderFlag", info))?;
115 if flag.num_headers_length_bytes == 0 {
116 return Ok(Self {
117 encoding_info: EMPTY_HEADERS_ENCODING_INFO,
118 headers: vec![],
119 body: buf,
120 });
121 }
122
123 let num_headers = buf
124 .try_get_uint(flag.num_headers_length_bytes as usize)
125 .map_err(|_| RecordDecodeError::Truncated("NumHeaders"))?;
126
127 let mut headers_total_bytes = 0;
128 let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
129 for _ in 0..num_headers {
130 let name_len = buf
131 .try_get_uint(flag.name_length_bytes.get() as usize)
132 .map_err(|_| RecordDecodeError::Truncated("HeaderNameLen"))?
133 as usize;
134 if name_len == 0 {
135 return Err(RecordDecodeError::InvalidValue("HeaderName", "empty"));
136 }
137 if buf.remaining() < name_len {
138 return Err(RecordDecodeError::Truncated("HeaderName"));
139 }
140 let name = buf.split_to(name_len);
141
142 let value_len = buf
143 .try_get_uint(flag.value_length_bytes.get() as usize)
144 .map_err(|_| RecordDecodeError::Truncated("HeaderValueLen"))?
145 as usize;
146 if buf.remaining() < value_len {
147 return Err(RecordDecodeError::Truncated("HeaderValue"));
148 }
149 let value = buf.split_to(value_len);
150
151 headers_total_bytes += name.len() + value.len();
152 headers.push(Header { name, value })
153 }
154
155 Ok(Self {
156 encoding_info: EncodingInfo {
157 headers_total_bytes,
158 flag,
159 },
160 headers,
161 body: buf,
162 })
163 }
164}
165
/// Flag used for records without headers.
///
/// It has a zero-width header count and one-byte name/value length prefixes, and encodes
/// to `0`.
const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
    num_headers_length_bytes: 0,
    name_length_bytes: NonZeroU8::MIN,
    value_length_bytes: NonZeroU8::MIN,
};

/// Unpacked form of the leading flag byte of an encoded record.
///
/// Bit layout, most significant first: `RR HH NN VV`.
/// - `RR` are reserved bits and must be zero.
/// - `HH` is the width in bytes of the header count (0..=3).
/// - `NN` / `VV` are the widths of each name / value length prefix, minus one (1..=4 bytes).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HeaderFlag {
    /// Width in bytes of the header-count field; 0 means the record has no headers.
    num_headers_length_bytes: u8,
    /// Width in bytes of each header-name length prefix.
    name_length_bytes: NonZeroU8,
    /// Width in bytes of each header-value length prefix.
    value_length_bytes: NonZeroU8,
}

impl From<HeaderFlag> for u8 {
    /// Packs the flag into its single-byte wire form.
    ///
    /// The fields must be within their encodable ranges (count width <= 3, prefix widths
    /// <= 4). Larger values are not rejected and would spill into neighbouring bit fields.
    fn from(flag: HeaderFlag) -> Self {
        let count_bits = flag.num_headers_length_bytes << 4;
        let name_bits = (flag.name_length_bytes.get() - 1) << 2;
        let value_bits = flag.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a flag byte, rejecting any byte with either of the two high (reserved) bits set.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        const RESERVED_MASK: u8 = 0b1100_0000;
        if (byte & RESERVED_MASK) != 0 {
            return Err("reserved bit set");
        }

        let count_width = (byte >> 4) & 0b11;
        let name_width_minus_one = (byte >> 2) & 0b11;
        let value_width_minus_one = byte & 0b11;
        Ok(Self {
            num_headers_length_bytes: count_width,
            // Each 2-bit field is at most 3, so 1 + field is in 1..=4 and never saturates.
            name_length_bytes: NonZeroU8::MIN.saturating_add(name_width_minus_one),
            value_length_bytes: NonZeroU8::MIN.saturating_add(value_width_minus_one),
        })
    }
}
201
202const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
203 headers_total_bytes: 0,
204 flag: EMPTY_HEADER_FLAG,
205};
206
207#[derive(Debug, PartialEq, Eq, Clone, Copy)]
208struct EncodingInfo {
209 headers_total_bytes: usize,
210 flag: HeaderFlag,
211}
212
213impl TryFrom<&[Header]> for EncodingInfo {
214 type Error = HeaderValidationError;
215
216 fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
217 fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
220 let size = 8 - elems.leading_zeros() / 8;
221 if size > 3 {
222 Err(HeaderValidationError::TooMany)
223 } else {
224 Ok(size as u8)
225 }
226 }
227
228 fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
231 if elems == 0 {
232 return Ok(NonZeroU8::new(1u8).unwrap());
233 }
234 let size = 8 - (elems.leading_zeros() / 8);
235 if size > 4 {
236 Err(HeaderValidationError::TooLong)
237 } else {
238 Ok(NonZeroU8::new(size as u8).unwrap())
239 }
240 }
241
242 if headers.is_empty() {
243 return Ok(EMPTY_HEADERS_ENCODING_INFO);
244 }
245
246 let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
247 (0usize, 0usize, 0usize),
248 |(size_bytes_acc, name_max, value_max), Header { name, value }| {
249 if name.is_empty() {
250 return Err(HeaderValidationError::NameEmpty);
251 }
252 let name_len = name.len();
253 let value_len = value.len();
254 Ok((
255 size_bytes_acc + name_len + value_len,
256 name_max.max(name_len),
257 value_max.max(value_len),
258 ))
259 },
260 )?;
261
262 let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
263 let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
264 let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
265
266 Ok(Self {
267 headers_total_bytes,
268 flag: HeaderFlag {
269 num_headers_length_bytes,
270 name_length_bytes,
271 value_length_bytes,
272 },
273 })
274 }
275}
276
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::{BufMut, Bytes, BytesMut};

    use super::{
        Encodable as _, EnvelopeRecord, Header, HeaderFlag, MeteredSize, RecordDecodeError,
    };

    /// Builds a record from `headers`/`body`, encodes it with `to_bytes`, decodes the
    /// result, and asserts that the headers and body survive unchanged.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let encoded: Bytes = EnvelopeRecord::try_from_parts(headers.clone(), body.clone())
            .unwrap()
            .to_bytes();
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn framed_with_headers() {
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn decode_rejects_empty_header_name() {
        let mut encoded = BytesMut::new();
        encoded.put_u8(
            HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }
            .into(),
        );
        encoded.put_u8(1); // number of headers
        encoded.put_u8(0); // name length: zero, which is invalid
        encoded.put_u8(5); // value length
        encoded.put_slice(b"value");
        encoded.put_slice(b"body");

        assert_eq!(
            EnvelopeRecord::try_from(encoded.freeze()),
            Err(RecordDecodeError::InvalidValue("HeaderName", "empty"))
        );
    }

    #[test]
    fn decode_huge_header_count_without_data_is_truncated() {
        let mut encoded = BytesMut::new();
        encoded.put_u8(
            HeaderFlag {
                num_headers_length_bytes: 3,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }
            .into(),
        );
        // Announce the maximum representable header count (2^24 - 1) but supply no headers.
        // Decoding must fail cleanly rather than trusting the count.
        encoded.put_uint(0xFF_FFFF, 3);

        assert_eq!(
            EnvelopeRecord::try_from(encoded.freeze()),
            Err(RecordDecodeError::Truncated("HeaderNameLen"))
        );
    }

    #[test]
    fn framed_duplicate_keys() {
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let record = EnvelopeRecord::try_from_parts(
            vec![
                Header {
                    name: Bytes::from("alpha"),
                    value: Bytes::from("1"),
                },
                Header {
                    name: Bytes::from("beta"),
                    value: Bytes::from("two"),
                },
            ],
            Bytes::from("body"),
        )
        .unwrap();

        assert_eq!(
            record.metered_size(),
            8 + (2 * record.headers().len())
                + ("alpha".len() + "1".len() + "beta".len() + "two".len())
                + "body".len()
        );
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn truncated_returns_error() {
        let record = EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .unwrap();
        let encoded = record.to_bytes();

        // Every strict, non-empty prefix of the encoding must be reported as truncated.
        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    EnvelopeRecord::try_from(truncated),
                    Err(RecordDecodeError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }
}