1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, MeteredSize, RecordDecodeError, RecordPartsError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10 #[error("too many")]
11 TooMany,
12 #[error("too long")]
13 TooLong,
14 #[error("empty name")]
15 NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20 headers: Vec<Header>,
21 body: Bytes,
22 encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("EnvelopeRecord")
28 .field("headers.len", &self.headers.len())
29 .field("body.len", &self.body.len())
30 .finish()
31 }
32}
33
34impl DeepSize for EnvelopeRecord {
35 fn deep_size(&self) -> usize {
36 self.headers.deep_size() + self.body.deep_size()
37 }
38}
39
40impl MeteredSize for EnvelopeRecord {
41 fn metered_size(&self) -> usize {
42 8 + (2 * self.headers.len()) + self.encoding_info.headers_total_bytes + self.body.len()
43 }
44}
45
46impl EnvelopeRecord {
47 pub fn headers(&self) -> &[Header] {
48 &self.headers
49 }
50
51 pub fn body(&self) -> &Bytes {
52 &self.body
53 }
54
55 pub fn into_parts(self) -> (Vec<Header>, Bytes) {
56 (self.headers, self.body)
57 }
58
59 pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, RecordPartsError> {
60 let encoding_info = headers.as_slice().try_into()?;
61 Ok(Self {
62 headers,
63 body,
64 encoding_info,
65 })
66 }
67}
68
69impl Encodable for EnvelopeRecord {
70 fn encoded_size(&self) -> usize {
71 1 + self.encoding_info.flag.num_headers_length_bytes as usize
72 + self.headers.len()
73 * (self.encoding_info.flag.name_length_bytes.get() as usize
74 + self.encoding_info.flag.value_length_bytes.get() as usize)
75 + self.encoding_info.headers_total_bytes
76 + self.body.len()
77 }
78
79 fn encode_into(&self, buf: &mut impl BufMut) {
80 buf.put_u8(self.encoding_info.flag.into());
82 buf.put_uint(
83 self.headers.len() as u64,
84 self.encoding_info.flag.num_headers_length_bytes as usize,
85 );
86 for Header { name, value } in &self.headers {
88 buf.put_uint(
89 name.len() as u64,
90 self.encoding_info.flag.name_length_bytes.get() as usize,
91 );
92 buf.put_slice(name);
93 buf.put_uint(
94 value.len() as u64,
95 self.encoding_info.flag.value_length_bytes.get() as usize,
96 );
97 buf.put_slice(value);
98 }
99 buf.put_slice(&self.body);
100 }
101}
102
103impl TryFrom<Bytes> for EnvelopeRecord {
104 type Error = RecordDecodeError;
105
106 fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
107 if buf.is_empty() {
108 return Err(RecordDecodeError::InvalidValue("HeaderFlag", "missing"));
109 }
110
111 let flag: HeaderFlag = buf
112 .get_u8()
113 .try_into()
114 .map_err(|info| RecordDecodeError::InvalidValue("HeaderFlag", info))?;
115 if flag.num_headers_length_bytes == 0 {
116 return Ok(Self {
117 encoding_info: EMPTY_HEADERS_ENCODING_INFO,
118 headers: vec![],
119 body: buf,
120 });
121 }
122
123 let num_headers = buf
124 .try_get_uint(flag.num_headers_length_bytes as usize)
125 .map_err(|_| RecordDecodeError::Truncated("NumHeaders"))?;
126
127 let mut headers_total_bytes = 0;
128 let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
129 for _ in 0..num_headers {
130 let name_len = buf
131 .try_get_uint(flag.name_length_bytes.get() as usize)
132 .map_err(|_| RecordDecodeError::Truncated("HeaderNameLen"))?
133 as usize;
134 if buf.remaining() < name_len {
135 return Err(RecordDecodeError::Truncated("HeaderName"));
136 }
137 let name = buf.split_to(name_len);
138
139 let value_len = buf
140 .try_get_uint(flag.value_length_bytes.get() as usize)
141 .map_err(|_| RecordDecodeError::Truncated("HeaderValueLen"))?
142 as usize;
143 if buf.remaining() < value_len {
144 return Err(RecordDecodeError::Truncated("HeaderValue"));
145 }
146 let value = buf.split_to(value_len);
147
148 headers_total_bytes += name.len() + value.len();
149 headers.push(Header { name, value })
150 }
151
152 Ok(Self {
153 encoding_info: EncodingInfo {
154 headers_total_bytes,
155 flag,
156 },
157 headers,
158 body: buf,
159 })
160 }
161}
162
/// Flag for a record without headers.
///
/// It has no header-count field and 1-byte name/value prefixes (unused), so it
/// serializes to the byte `0`.
const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
    num_headers_length_bytes: 0,
    name_length_bytes: NonZeroU8::MIN,
    value_length_bytes: NonZeroU8::MIN,
};

/// First byte of an encoded envelope: the widths of the length fields that follow.
///
/// Bit layout, most significant first: `RR CC NN VV`
/// - `RR`: reserved, must be zero;
/// - `CC`: width in bytes of the header-count field (0..=3; 0 means no headers);
/// - `NN`: width in bytes of each header-name length prefix, minus one (so 1..=4);
/// - `VV`: width in bytes of each header-value length prefix, minus one (so 1..=4).
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HeaderFlag {
    num_headers_length_bytes: u8,
    name_length_bytes: NonZeroU8,
    value_length_bytes: NonZeroU8,
}

impl From<HeaderFlag> for u8 {
    /// Packs the three widths into the `RR CC NN VV` layout (reserved bits left zero).
    fn from(flag: HeaderFlag) -> Self {
        let count_bits = flag.num_headers_length_bytes << 4;
        let name_bits = (flag.name_length_bytes.get() - 1) << 2;
        let value_bits = flag.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a flag byte, rejecting any byte with a reserved bit set.
    fn try_from(byte: u8) -> Result<Self, Self::Error> {
        const RESERVED_MASK: u8 = 0b1100_0000;
        if (byte & RESERVED_MASK) != 0 {
            return Err("reserved bit set");
        }

        // Extracts the 2-bit field whose lowest bit sits at `shift`.
        let field = |shift: u32| (byte >> shift) & 0b11;
        Ok(Self {
            num_headers_length_bytes: field(4),
            // Stored minus one on the wire, so add it back; never exceeds 4.
            name_length_bytes: NonZeroU8::MIN.saturating_add(field(2)),
            value_length_bytes: NonZeroU8::MIN.saturating_add(field(0)),
        })
    }
}
198
199const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
200 headers_total_bytes: 0,
201 flag: EMPTY_HEADER_FLAG,
202};
203
204#[derive(Debug, PartialEq, Eq, Clone, Copy)]
205struct EncodingInfo {
206 headers_total_bytes: usize,
207 flag: HeaderFlag,
208}
209
210impl TryFrom<&[Header]> for EncodingInfo {
211 type Error = HeaderValidationError;
212
213 fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
214 fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
217 let size = 8 - elems.leading_zeros() / 8;
218 if size > 3 {
219 Err(HeaderValidationError::TooMany)
220 } else {
221 Ok(size as u8)
222 }
223 }
224
225 fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
228 if elems == 0 {
229 return Ok(NonZeroU8::new(1u8).unwrap());
230 }
231 let size = 8 - (elems.leading_zeros() / 8);
232 if size > 4 {
233 Err(HeaderValidationError::TooLong)
234 } else {
235 Ok(NonZeroU8::new(size as u8).unwrap())
236 }
237 }
238
239 if headers.is_empty() {
240 return Ok(EMPTY_HEADERS_ENCODING_INFO);
241 }
242
243 let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
244 (0usize, 0usize, 0usize),
245 |(size_bytes_acc, name_max, value_max), Header { name, value }| {
246 if name.is_empty() {
247 return Err(HeaderValidationError::NameEmpty);
248 }
249 let name_len = name.len();
250 let value_len = value.len();
251 Ok((
252 size_bytes_acc + name_len + value_len,
253 name_max.max(name_len),
254 value_max.max(value_len),
255 ))
256 },
257 )?;
258
259 let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
260 let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
261 let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
262
263 Ok(Self {
264 headers_total_bytes,
265 flag: HeaderFlag {
266 num_headers_length_bytes,
267 name_length_bytes,
268 value_length_bytes,
269 },
270 })
271 }
272}
273
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::Bytes;

    use super::{
        Encodable as _, EnvelopeRecord, Header, HeaderFlag, MeteredSize, RecordDecodeError,
    };

    /// Builds a header from static string name/value.
    fn header(name: &'static str, value: &'static str) -> Header {
        Header {
            name: Bytes::from(name),
            value: Bytes::from(value),
        }
    }

    /// Encodes `headers`/`body`, checks `encoded_size` against the real encoding, then
    /// decodes and checks both parts come back unchanged.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let record = EnvelopeRecord::try_from_parts(headers.clone(), body.clone()).unwrap();
        let expected_len = record.encoded_size();
        let encoded: Bytes = record.to_bytes();
        assert_eq!(encoded.len(), expected_len);
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn framed_with_headers() {
        roundtrip_parts(
            vec![
                header("key_1", "val_1"),
                header("key_2", "val_2"),
                header("key_3", "val_3"),
                header("key_4", "val_4"),
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn framed_duplicate_keys() {
        roundtrip_parts(
            vec![
                header("b", "val_1"),
                header("b", "val_2"),
                header("a", "val_3"),
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_wide_length_prefixes() {
        // Name needs a 2-byte prefix, value a 3-byte prefix.
        roundtrip_parts(
            vec![Header {
                name: Bytes::from(vec![b'n'; 300]),
                value: Bytes::from(vec![b'v'; 70_000]),
            }],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_many_headers() {
        // More than 255 headers needs a 2-byte count field.
        let headers = (0..300)
            .map(|i| Header {
                name: Bytes::from(format!("k{i}")),
                value: Bytes::new(),
            })
            .collect();
        roundtrip_parts(headers, Bytes::new());
    }

    #[test]
    fn metered_size_uses_cached_header_bytes() {
        let record = EnvelopeRecord::try_from_parts(
            vec![header("alpha", "1"), header("beta", "two")],
            Bytes::from("body"),
        )
        .unwrap();

        assert_eq!(
            record.metered_size(),
            8 + (2 * record.headers().len())
                + ("alpha".len() + "1".len() + "beta".len() + "two".len())
                + "body".len()
        );
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn zero_count_flag_decodes_everything_as_body() {
        let decoded = EnvelopeRecord::try_from(Bytes::from_static(&[0, b'h', b'i'])).unwrap();
        assert!(decoded.headers().is_empty());
        assert_eq!(decoded.body(), &Bytes::from("hi"));
    }

    #[test]
    fn truncated_returns_error() {
        let record =
            EnvelopeRecord::try_from_parts(vec![header("key", "value")], Bytes::new()).unwrap();
        let encoded = record.to_bytes();

        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    EnvelopeRecord::try_from(truncated),
                    Err(RecordDecodeError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }

    #[test]
    fn huge_header_count_without_data_is_truncated() {
        // Flag declares a 3-byte count; the count claims 2^24 - 1 headers but no header
        // bytes follow. Decoding must fail cleanly rather than preallocate for them.
        let encoded = Bytes::from_static(&[0b0011_0000, 0xFF, 0xFF, 0xFF]);
        assert!(matches!(
            EnvelopeRecord::try_from(encoded),
            Err(RecordDecodeError::Truncated("HeaderNameLen"))
        ));
    }

    #[test]
    fn empty_input_is_rejected() {
        assert!(matches!(
            EnvelopeRecord::try_from(Bytes::new()),
            Err(RecordDecodeError::InvalidValue("HeaderFlag", "missing"))
        ));
    }

    #[test]
    fn reserved_flag_bits_are_rejected() {
        for flag in [0b0100_0000u8, 0b1000_0000, 0b1100_0000] {
            assert!(matches!(
                EnvelopeRecord::try_from(Bytes::from(vec![flag])),
                Err(RecordDecodeError::InvalidValue("HeaderFlag", _))
            ));
        }
    }

    #[test]
    fn empty_header_name_is_rejected() {
        assert!(EnvelopeRecord::try_from_parts(vec![header("", "v")], Bytes::new()).is_err());
    }
}