1use std::num::NonZeroU8;
2
3use bytes::{Buf, BufMut, Bytes};
4
5use super::{Encodable, Header, InternalRecordError, PublicRecordError};
6use crate::deep_size::DeepSize;
7
8#[derive(Debug, PartialEq, thiserror::Error)]
9pub enum HeaderValidationError {
10 #[error("too many")]
11 TooMany,
12 #[error("too long")]
13 TooLong,
14 #[error("empty name")]
15 NameEmpty,
16}
17
18#[derive(PartialEq, Eq, Clone)]
19pub struct EnvelopeRecord {
20 headers: Vec<Header>,
21 body: Bytes,
22 encoding_info: EncodingInfo,
23}
24
25impl std::fmt::Debug for EnvelopeRecord {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("EnvelopeRecord")
28 .field("headers.len", &self.headers.len())
29 .field("body.len", &self.body.len())
30 .finish()
31 }
32}
33
34impl DeepSize for EnvelopeRecord {
35 fn deep_size(&self) -> usize {
36 self.headers.deep_size() + self.body.deep_size()
37 }
38}
39
40impl EnvelopeRecord {
41 pub fn headers(&self) -> &[Header] {
42 &self.headers
43 }
44
45 pub fn body(&self) -> &Bytes {
46 &self.body
47 }
48
49 pub fn into_parts(self) -> (Vec<Header>, Bytes) {
50 (self.headers, self.body)
51 }
52
53 pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, PublicRecordError> {
54 let encoding_info = headers.as_slice().try_into()?;
55 Ok(Self {
56 headers,
57 body,
58 encoding_info,
59 })
60 }
61}
62
63impl Encodable for EnvelopeRecord {
64 fn encoded_size(&self) -> usize {
65 1 + self.encoding_info.flag.num_headers_length_bytes as usize
66 + self.headers.len()
67 * (self.encoding_info.flag.name_length_bytes.get() as usize
68 + self.encoding_info.flag.value_length_bytes.get() as usize)
69 + self.encoding_info.headers_total_bytes
70 + self.body.len()
71 }
72
73 fn encode_into(&self, buf: &mut impl BufMut) {
74 buf.put_u8(self.encoding_info.flag.into());
76 buf.put_uint(
77 self.headers.len() as u64,
78 self.encoding_info.flag.num_headers_length_bytes as usize,
79 );
80 for Header { name, value } in &self.headers {
82 buf.put_uint(
83 name.len() as u64,
84 self.encoding_info.flag.name_length_bytes.get() as usize,
85 );
86 buf.put_slice(name);
87 buf.put_uint(
88 value.len() as u64,
89 self.encoding_info.flag.value_length_bytes.get() as usize,
90 );
91 buf.put_slice(value);
92 }
93 buf.put_slice(&self.body);
94 }
95}
96
97impl TryFrom<Bytes> for EnvelopeRecord {
98 type Error = InternalRecordError;
99
100 fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
101 if buf.is_empty() {
102 return Err(InternalRecordError::InvalidValue("HeaderFlag", "missing"));
103 }
104
105 let flag: HeaderFlag = buf
106 .get_u8()
107 .try_into()
108 .map_err(|info| InternalRecordError::InvalidValue("HeaderFlag", info))?;
109 if flag.num_headers_length_bytes == 0 {
110 return Ok(Self {
111 encoding_info: EMPTY_HEADERS_ENCODING_INFO,
112 headers: vec![],
113 body: buf,
114 });
115 }
116
117 let num_headers = buf
118 .try_get_uint(flag.num_headers_length_bytes as usize)
119 .map_err(|_| InternalRecordError::Truncated("NumHeaders"))?;
120
121 let mut headers_total_bytes = 0;
122 let mut headers: Vec<Header> = Vec::with_capacity(num_headers as usize);
123 for _ in 0..num_headers {
124 let name_len = buf
125 .try_get_uint(flag.name_length_bytes.get() as usize)
126 .map_err(|_| InternalRecordError::Truncated("HeaderNameLen"))?
127 as usize;
128 if buf.remaining() < name_len {
129 return Err(InternalRecordError::Truncated("HeaderName"));
130 }
131 let name = buf.split_to(name_len);
132
133 let value_len = buf
134 .try_get_uint(flag.value_length_bytes.get() as usize)
135 .map_err(|_| InternalRecordError::Truncated("HeaderValueLen"))?
136 as usize;
137 if buf.remaining() < value_len {
138 return Err(InternalRecordError::Truncated("HeaderValue"));
139 }
140 let value = buf.split_to(value_len);
141
142 headers_total_bytes += name.len() + value.len();
143 headers.push(Header { name, value })
144 }
145
146 Ok(Self {
147 encoding_info: EncodingInfo {
148 headers_total_bytes,
149 flag,
150 },
151 headers,
152 body: buf,
153 })
154 }
155}
156
157const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
158 num_headers_length_bytes: 0,
159 name_length_bytes: NonZeroU8::new(1).unwrap(),
160 value_length_bytes: NonZeroU8::new(1).unwrap(),
161};
162
/// The leading byte of an encoded envelope, giving the widths of the integer prefixes
/// that follow it.
///
/// Packed as `0b00NN_KKVV`:
/// - the top two bits are reserved and must be zero;
/// - `NN` is the header-count width in bytes, where 0 means there are no headers;
/// - `KK + 1` is the width of every name-length prefix;
/// - `VV + 1` is the width of every value-length prefix.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
struct HeaderFlag {
    /// Width in bytes (0..=3) of the header count; 0 means the record has no headers.
    num_headers_length_bytes: u8,
    /// Width in bytes (1..=4) of each header name's length prefix.
    name_length_bytes: NonZeroU8,
    /// Width in bytes (1..=4) of each header value's length prefix.
    value_length_bytes: NonZeroU8,
}
169
170impl From<HeaderFlag> for u8 {
171 fn from(value: HeaderFlag) -> Self {
172 (value.num_headers_length_bytes << 4)
173 | ((value.name_length_bytes.get() - 1) << 2)
174 | (value.value_length_bytes.get() - 1)
175 }
176}
177
178impl TryFrom<u8> for HeaderFlag {
179 type Error = &'static str;
180
181 fn try_from(value: u8) -> Result<Self, Self::Error> {
182 if (value & (0b11u8 << 6)) != 0u8 {
183 return Err("reserved bit set");
184 }
185 Ok(Self {
186 num_headers_length_bytes: (0b110000 & value) >> 4,
187 name_length_bytes: NonZeroU8::new(((0b1100 & value) >> 2) + 1).unwrap(),
188 value_length_bytes: NonZeroU8::new((0b11 & value) + 1).unwrap(),
189 })
190 }
191}
192
193const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
194 headers_total_bytes: 0,
195 flag: EMPTY_HEADER_FLAG,
196};
197
198#[derive(Debug, PartialEq, Eq, Clone, Copy)]
199struct EncodingInfo {
200 headers_total_bytes: usize,
201 flag: HeaderFlag,
202}
203
204impl TryFrom<&[Header]> for EncodingInfo {
205 type Error = HeaderValidationError;
206
207 fn try_from(headers: &[Header]) -> Result<Self, Self::Error> {
208 fn size_bytes_headers_len(elems: u64) -> Result<u8, HeaderValidationError> {
211 let size = 8 - elems.leading_zeros() / 8;
212 if size > 3 {
213 Err(HeaderValidationError::TooMany)
214 } else {
215 Ok(size as u8)
216 }
217 }
218
219 fn size_bytes_name_value_len(elems: u64) -> Result<NonZeroU8, HeaderValidationError> {
222 if elems == 0 {
223 return Ok(NonZeroU8::new(1u8).unwrap());
224 }
225 let size = 8 - (elems.leading_zeros() / 8);
226 if size > 4 {
227 Err(HeaderValidationError::TooLong)
228 } else {
229 Ok(NonZeroU8::new(size as u8).unwrap())
230 }
231 }
232
233 if headers.is_empty() {
234 return Ok(EMPTY_HEADERS_ENCODING_INFO);
235 }
236
237 let (headers_total_bytes, name_max, value_max) = headers.iter().try_fold(
238 (0usize, 0usize, 0usize),
239 |(size_bytes_acc, name_max, value_max), Header { name, value }| {
240 if name.is_empty() {
241 return Err(HeaderValidationError::NameEmpty);
242 }
243 let name_len = name.len();
244 let value_len = value.len();
245 Ok((
246 size_bytes_acc + name_len + value_len,
247 name_max.max(name_len),
248 value_max.max(value_len),
249 ))
250 },
251 )?;
252
253 let num_headers_length_bytes = size_bytes_headers_len(headers.len() as u64)?;
254 let name_length_bytes = size_bytes_name_value_len(name_max as u64)?;
255 let value_length_bytes = size_bytes_name_value_len(value_max as u64)?;
256
257 Ok(Self {
258 headers_total_bytes,
259 flag: HeaderFlag {
260 num_headers_length_bytes,
261 name_length_bytes,
262 value_length_bytes,
263 },
264 })
265 }
266}
267
#[cfg(test)]
mod test {
    use std::num::NonZeroU8;

    use bytes::Bytes;

    use super::{
        Encodable as _, EncodingInfo, EnvelopeRecord, Header, HeaderFlag, HeaderValidationError,
        InternalRecordError,
    };

    /// Encodes `headers` + `body`, decodes the result and checks both parts survive intact.
    fn roundtrip_parts(headers: Vec<Header>, body: Bytes) {
        let encoded: Bytes = EnvelopeRecord::try_from_parts(headers.clone(), body.clone())
            .unwrap()
            .to_bytes();
        let decoded = EnvelopeRecord::try_from(encoded).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn framed_with_headers() {
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("key_1"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("key_2"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("key_3"),
                    value: Bytes::from("val_3"),
                },
                Header {
                    name: Bytes::from("key_4"),
                    value: Bytes::from("val_4"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_no_headers() {
        roundtrip_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn framed_duplicate_keys() {
        // Duplicate names are preserved in their original order.
        roundtrip_parts(
            vec![
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_1"),
                },
                Header {
                    name: Bytes::from("b"),
                    value: Bytes::from("val_2"),
                },
                Header {
                    name: Bytes::from("a"),
                    value: Bytes::from("val_3"),
                },
            ],
            Bytes::from("hello"),
        );
    }

    #[test]
    fn framed_wide_length_prefixes() {
        // 256 headers need a 2-byte count and a 300-byte value needs a 2-byte value-length
        // prefix. The short names ("k0".."k255") still fit a 1-byte prefix.
        let mut headers: Vec<Header> = (0..256)
            .map(|i| Header {
                name: Bytes::from(format!("k{i}")),
                value: Bytes::new(),
            })
            .collect();
        headers[0].value = Bytes::from(vec![b'v'; 300]);

        let encoded = EnvelopeRecord::try_from_parts(headers.clone(), Bytes::new())
            .unwrap()
            .to_bytes();
        assert_eq!(encoded.first(), Some(&0b0010_0001));

        roundtrip_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn flag_ex1() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 2,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00100000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 2,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00100000);
    }

    #[test]
    fn flag_ex2() {
        assert_eq!(
            Ok(HeaderFlag {
                num_headers_length_bytes: 1,
                name_length_bytes: NonZeroU8::new(1).unwrap(),
                value_length_bytes: NonZeroU8::new(1).unwrap(),
            }),
            0b00010000.try_into()
        );

        let u8_repr: u8 = HeaderFlag {
            num_headers_length_bytes: 1,
            name_length_bytes: NonZeroU8::new(1).unwrap(),
            value_length_bytes: NonZeroU8::new(1).unwrap(),
        }
        .into();
        assert_eq!(u8_repr, 0b00010000);
    }

    #[test]
    fn reserved_flag_bits_rejected() {
        for flag in [0b0100_0000u8, 0b1000_0000, 0b1111_1111] {
            assert_eq!(HeaderFlag::try_from(flag), Err("reserved bit set"));
            assert!(matches!(
                EnvelopeRecord::try_from(Bytes::from(vec![flag])),
                Err(InternalRecordError::InvalidValue("HeaderFlag", _))
            ));
        }
    }

    #[test]
    fn empty_buffer_rejected() {
        assert!(matches!(
            EnvelopeRecord::try_from(Bytes::new()),
            Err(InternalRecordError::InvalidValue("HeaderFlag", "missing"))
        ));
    }

    #[test]
    fn empty_header_name_rejected() {
        let headers = [Header {
            name: Bytes::new(),
            value: Bytes::from("value"),
        }];
        assert_eq!(
            EncodingInfo::try_from(&headers[..]),
            Err(HeaderValidationError::NameEmpty)
        );
    }

    #[test]
    fn empty_envelope_size() {
        assert_eq!(
            1,
            EnvelopeRecord::try_from_parts(vec![], Bytes::new())
                .unwrap()
                .to_bytes()
                .len()
        );
    }

    #[test]
    fn header_count_exceeding_payload_is_truncated() {
        // The flag announces a 1-byte header count of 255, but no header bytes follow.
        let encoded = Bytes::from_static(&[0b0001_0000, 0xFF]);
        assert!(matches!(
            EnvelopeRecord::try_from(encoded),
            Err(InternalRecordError::Truncated("HeaderNameLen"))
        ));
    }

    #[test]
    fn truncated_returns_error() {
        let record = EnvelopeRecord::try_from_parts(
            vec![Header {
                name: Bytes::from("key"),
                value: Bytes::from("value"),
            }],
            Bytes::new(),
        )
        .unwrap();
        let encoded = record.to_bytes();

        // Every strict prefix that still contains the flag byte must fail as truncated.
        for len in 1..encoded.len() {
            let truncated = encoded.slice(..len);
            assert!(
                matches!(
                    EnvelopeRecord::try_from(truncated),
                    Err(InternalRecordError::Truncated(_))
                ),
                "expected Truncated error for len {len}"
            );
        }
    }
}