1use std::num::NonZeroU8;
41
42use bytes::{Buf, BufMut, Bytes, BytesMut};
43use s2_common::record::{
44 CommandOp, CommandPayloadError, CommandRecord, EnvelopeRecord, Header, HeaderValidationError,
45 RecordPartsError,
46};
47
48#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
49pub enum StoredRecordDecodeError {
50 #[error("truncated: {0}")]
51 Truncated(&'static str),
52 #[error("invalid value [{0}]: {1}")]
53 InvalidValue(&'static str, &'static str),
54}
55
56pub(crate) trait WireEncode {
57 fn to_bytes(&self) -> Bytes {
58 let expected_size = self.encoded_size();
59 let mut buf = BytesMut::with_capacity(expected_size);
60 self.encode_into(&mut buf);
61 assert_eq!(buf.len(), expected_size, "no reallocation");
62 buf.freeze()
63 }
64
65 fn encoded_size(&self) -> usize;
66
67 fn encode_into(&self, buf: &mut impl BufMut);
68}
69
/// Leading byte that identifies a stored fence command.
const COMMAND_ORDINAL_FENCE: u8 = 0;
/// Leading byte that identifies a stored trim command.
const COMMAND_ORDINAL_TRIM: u8 = 1;
72
73fn command_op_ordinal(op: CommandOp) -> u8 {
74 match op {
75 CommandOp::Fence => COMMAND_ORDINAL_FENCE,
76 CommandOp::Trim => COMMAND_ORDINAL_TRIM,
77 }
78}
79
80fn command_op_from_ordinal(ordinal: u8) -> Option<CommandOp> {
81 match ordinal {
82 COMMAND_ORDINAL_FENCE => Some(CommandOp::Fence),
83 COMMAND_ORDINAL_TRIM => Some(CommandOp::Trim),
84 _ => None,
85 }
86}
87
88impl From<CommandPayloadError> for StoredRecordDecodeError {
89 fn from(e: CommandPayloadError) -> Self {
90 match e {
91 CommandPayloadError::InvalidUtf8(_) => StoredRecordDecodeError::InvalidValue(
92 "CommandPayload",
93 "fencing token not valid utf8",
94 ),
95 CommandPayloadError::FencingTokenTooLong(_) => {
96 StoredRecordDecodeError::InvalidValue("CommandPayload", "fencing token too long")
97 }
98 CommandPayloadError::TrimPointSize(_) => {
99 StoredRecordDecodeError::InvalidValue("CommandPayload", "trim point size")
100 }
101 }
102 }
103}
104
105impl WireEncode for CommandRecord {
106 fn encoded_size(&self) -> usize {
107 1 + match self {
108 CommandRecord::Fence(token) => token.len(),
109 CommandRecord::Trim(trim_point) => size_of_val(trim_point),
110 }
111 }
112
113 fn encode_into(&self, buf: &mut impl BufMut) {
114 buf.put_u8(command_op_ordinal(self.op()));
115 match self {
116 CommandRecord::Fence(token) => {
117 buf.put_slice(token.as_bytes());
118 }
119 CommandRecord::Trim(trim_point) => {
120 buf.put_u64(*trim_point);
121 }
122 }
123 }
124}
125
126pub(super) fn decode_command_record(
127 record: &[u8],
128) -> Result<CommandRecord, StoredRecordDecodeError> {
129 if record.is_empty() {
130 return Err(StoredRecordDecodeError::Truncated("CommandOrdinal"));
131 }
132 let op = command_op_from_ordinal(record[0]).ok_or(StoredRecordDecodeError::InvalidValue(
133 "CommandOrdinal",
134 "unknown",
135 ))?;
136 CommandRecord::try_from_parts(op, &record[1..]).map_err(Into::into)
137}
138
139const EMPTY_HEADER_FLAG: HeaderFlag = HeaderFlag {
140 num_headers_length_bytes: 0,
141 name_length_bytes: NonZeroU8::new(1).unwrap(),
142 value_length_bytes: NonZeroU8::new(1).unwrap(),
143};
144
/// Leading byte of an encoded envelope record, describing how the header
/// section is framed.
///
/// Bit layout, most significant bit first:
///
/// ```text
/// | 7 6      | 5 4                | 3 2              | 1 0               |
/// | reserved | header count width | name width - 1   | value width - 1   |
/// ```
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct HeaderFlag {
    /// Number of bytes used to store the header count (two bits, so `0..=3`).
    num_headers_length_bytes: u8,
    /// Number of bytes used for each header name length prefix (`1..=4`).
    name_length_bytes: NonZeroU8,
    /// Number of bytes used for each header value length prefix (`1..=4`).
    value_length_bytes: NonZeroU8,
}

impl HeaderFlag {
    /// Bits that must be zero in a decodable flag.
    const RESERVED_MASK: u8 = 0b1100_0000;
    const NUM_HEADERS_LENGTH_MASK: u8 = 0b0011_0000;
    const NUM_HEADERS_LENGTH_SHIFT: u8 = 4;
    const NAME_LENGTH_MASK: u8 = 0b0000_1100;
    const NAME_LENGTH_SHIFT: u8 = 2;
    const VALUE_LENGTH_MASK: u8 = 0b0000_0011;
}

impl From<HeaderFlag> for u8 {
    /// Packs the flag into its single-byte form; the two length-prefix widths
    /// are stored minus one so that `1..=4` fits in two bits.
    fn from(value: HeaderFlag) -> Self {
        let count_bits = value.num_headers_length_bytes << HeaderFlag::NUM_HEADERS_LENGTH_SHIFT;
        let name_bits = (value.name_length_bytes.get() - 1) << HeaderFlag::NAME_LENGTH_SHIFT;
        let value_bits = value.value_length_bytes.get() - 1;
        count_bits | name_bits | value_bits
    }
}

impl TryFrom<u8> for HeaderFlag {
    type Error = &'static str;

    /// Unpacks a flag byte, rejecting it when either reserved bit is set.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        if value & HeaderFlag::RESERVED_MASK != 0 {
            return Err("reserved bit set");
        }

        let count_bits =
            (value & HeaderFlag::NUM_HEADERS_LENGTH_MASK) >> HeaderFlag::NUM_HEADERS_LENGTH_SHIFT;
        let name_bits = (value & HeaderFlag::NAME_LENGTH_MASK) >> HeaderFlag::NAME_LENGTH_SHIFT;
        let value_bits = value & HeaderFlag::VALUE_LENGTH_MASK;

        // A two-bit field plus one is always in 1..=4, so this never sees zero.
        let width = |bits: u8| NonZeroU8::new(bits + 1).unwrap();

        Ok(Self {
            num_headers_length_bytes: count_bits,
            name_length_bytes: width(name_bits),
            value_length_bytes: width(value_bits),
        })
    }
}
193
194const EMPTY_HEADERS_ENCODING_INFO: EncodingInfo = EncodingInfo {
195 headers_total_bytes: 0,
196 flag: EMPTY_HEADER_FLAG,
197};
198
199#[derive(Debug, PartialEq, Eq, Clone, Copy)]
200struct EncodingInfo {
201 headers_total_bytes: usize,
202 flag: HeaderFlag,
203}
204
205impl EncodingInfo {
206 fn for_record(record: &EnvelopeRecord) -> Self {
207 Self::from_header_sizing(
208 record.headers().len(),
209 record.headers_total_bytes(),
210 record.header_name_length_width_bytes(),
211 record.header_value_length_width_bytes(),
212 )
213 .expect("envelope record headers should be validated")
214 }
215
216 fn from_header_sizing(
217 header_count: usize,
218 headers_total_bytes: usize,
219 name_length_width_bytes: usize,
220 value_length_width_bytes: usize,
221 ) -> Result<Self, HeaderValidationError> {
222 fn size_bytes_header_count(count: u64) -> Result<u8, HeaderValidationError> {
223 let size = 8 - count.leading_zeros() / 8;
224 if size <= 3 {
225 Ok(size as u8)
226 } else {
227 Err(HeaderValidationError::TooMany)
228 }
229 }
230
231 fn header_part_width(width: usize) -> Result<NonZeroU8, HeaderValidationError> {
232 let width = u8::try_from(width).map_err(|_| HeaderValidationError::TooLong)?;
233 if (1..=4).contains(&width) {
234 Ok(NonZeroU8::new(width).expect("header part width should be non-zero"))
235 } else {
236 Err(HeaderValidationError::TooLong)
237 }
238 }
239
240 if header_count == 0 {
241 return Ok(EMPTY_HEADERS_ENCODING_INFO);
242 }
243
244 let num_headers_length_bytes = size_bytes_header_count(header_count as u64)?;
245 let name_length_bytes = header_part_width(name_length_width_bytes)?;
246 let value_length_bytes = header_part_width(value_length_width_bytes)?;
247
248 Ok(Self {
249 headers_total_bytes,
250 flag: HeaderFlag {
251 num_headers_length_bytes,
252 name_length_bytes,
253 value_length_bytes,
254 },
255 })
256 }
257}
258
259impl WireEncode for EnvelopeRecord {
260 fn encoded_size(&self) -> usize {
261 let encoding_info = EncodingInfo::for_record(self);
262 1 + encoding_info.flag.num_headers_length_bytes as usize
263 + self.headers().len()
264 * (encoding_info.flag.name_length_bytes.get() as usize
265 + encoding_info.flag.value_length_bytes.get() as usize)
266 + encoding_info.headers_total_bytes
267 + self.body().len()
268 }
269
270 fn encode_into(&self, buf: &mut impl BufMut) {
271 let encoding_info = EncodingInfo::for_record(self);
272 buf.put_u8(encoding_info.flag.into());
273 buf.put_uint(
274 self.headers().len() as u64,
275 encoding_info.flag.num_headers_length_bytes as usize,
276 );
277 for Header { name, value } in self.headers() {
278 buf.put_uint(
279 name.len() as u64,
280 encoding_info.flag.name_length_bytes.get() as usize,
281 );
282 buf.put_slice(name);
283 buf.put_uint(
284 value.len() as u64,
285 encoding_info.flag.value_length_bytes.get() as usize,
286 );
287 buf.put_slice(value);
288 }
289 buf.put_slice(self.body());
290 }
291}
292
293pub(super) fn decode_envelope_record(
294 mut buf: Bytes,
295) -> Result<EnvelopeRecord, StoredRecordDecodeError> {
296 if buf.is_empty() {
297 return Err(StoredRecordDecodeError::InvalidValue(
298 "HeaderFlag",
299 "missing",
300 ));
301 }
302
303 let flag: HeaderFlag = buf
304 .get_u8()
305 .try_into()
306 .map_err(|info| StoredRecordDecodeError::InvalidValue("HeaderFlag", info))?;
307 if flag.num_headers_length_bytes == 0 {
308 return EnvelopeRecord::try_from_parts(vec![], buf).map_err(record_parts_decode_error);
309 }
310
311 let num_headers = buf
312 .try_get_uint(flag.num_headers_length_bytes as usize)
313 .map_err(|_| StoredRecordDecodeError::Truncated("NumHeaders"))?;
314 let num_headers = usize::try_from(num_headers)
315 .map_err(|_| StoredRecordDecodeError::InvalidValue("NumHeaders", "too many"))?;
316
317 let mut headers: Vec<Header> = Vec::with_capacity(num_headers);
318 for _ in 0..num_headers {
319 let name_len = buf
320 .try_get_uint(flag.name_length_bytes.get() as usize)
321 .map_err(|_| StoredRecordDecodeError::Truncated("HeaderNameLen"))?
322 as usize;
323 if name_len == 0 {
324 return Err(StoredRecordDecodeError::InvalidValue("HeaderName", "empty"));
325 }
326 if buf.remaining() < name_len {
327 return Err(StoredRecordDecodeError::Truncated("HeaderName"));
328 }
329 let name = buf.split_to(name_len);
330
331 let value_len = buf
332 .try_get_uint(flag.value_length_bytes.get() as usize)
333 .map_err(|_| StoredRecordDecodeError::Truncated("HeaderValueLen"))?
334 as usize;
335 if buf.remaining() < value_len {
336 return Err(StoredRecordDecodeError::Truncated("HeaderValue"));
337 }
338 let value = buf.split_to(value_len);
339
340 headers.push(Header { name, value })
341 }
342
343 EnvelopeRecord::try_from_parts(headers, buf).map_err(record_parts_decode_error)
344}
345
346fn record_parts_decode_error(error: RecordPartsError) -> StoredRecordDecodeError {
347 match error {
348 RecordPartsError::Header(HeaderValidationError::NameEmpty) => {
349 StoredRecordDecodeError::InvalidValue("HeaderName", "empty")
350 }
351 RecordPartsError::Header(HeaderValidationError::TooMany) => {
352 StoredRecordDecodeError::InvalidValue("NumHeaders", "too many")
353 }
354 RecordPartsError::Header(HeaderValidationError::TooLong) => {
355 StoredRecordDecodeError::InvalidValue("Header", "too long")
356 }
357 RecordPartsError::UnknownCommand | RecordPartsError::CommandPayload(_, _) => {
358 StoredRecordDecodeError::InvalidValue("EnvelopeRecord", "unexpected command record")
359 }
360 }
361}
362
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};
    use rstest::rstest;
    use s2_common::record::{FencingToken, FencingTokenTooLongError, SeqNum};

    use super::*;

    /// Builds a header from string literals.
    fn header(name: &'static str, value: &'static str) -> Header {
        Header {
            name: Bytes::from(name),
            value: Bytes::from(value),
        }
    }

    /// Builds a flag from plain widths; the name/value widths must be non-zero.
    fn flag(count_width: u8, name_width: u8, value_width: u8) -> HeaderFlag {
        HeaderFlag {
            num_headers_length_bytes: count_width,
            name_length_bytes: NonZeroU8::new(name_width).unwrap(),
            value_length_bytes: NonZeroU8::new(value_width).unwrap(),
        }
    }

    /// Checks the reported size, the encoded size and the decode round trip.
    fn roundtrip_command(cmd: CommandRecord, expected_len: usize) {
        assert_eq!(cmd.encoded_size(), expected_len);
        let wire = cmd.to_bytes();
        assert_eq!(wire.len(), expected_len);
        assert_eq!(decode_command_record(wire.as_ref()), Ok(cmd));
    }

    #[rstest]
    #[case::empty("")]
    #[case::arbit("arbitrary")]
    #[case::full("0123456789012345")]
    fn command_fence_roundtrip(#[case] token: &str) {
        let fencing_token = token.parse::<FencingToken>().unwrap();
        roundtrip_command(CommandRecord::Fence(fencing_token), 1 + token.len());
    }

    #[rstest]
    #[case::zero(0)]
    #[case::large(SeqNum::MAX)]
    fn command_trim_roundtrip(#[case] trim_point: SeqNum) {
        roundtrip_command(CommandRecord::Trim(trim_point), 1 + size_of::<SeqNum>());
    }

    #[test]
    fn decode_invalid_command() {
        // No ordinal byte at all.
        assert_eq!(
            decode_command_record(&[]),
            Err(StoredRecordDecodeError::Truncated("CommandOrdinal"))
        );

        // Ordinal that maps to no command.
        assert_eq!(
            decode_command_record(&[0xff]),
            Err(StoredRecordDecodeError::InvalidValue(
                "CommandOrdinal",
                "unknown"
            ))
        );

        // Fence payload that is not UTF-8.
        assert_eq!(
            decode_command_record(&[command_op_ordinal(CommandOp::Fence), 0xff, 0xff]),
            Err(StoredRecordDecodeError::InvalidValue(
                "CommandPayload",
                "fencing token not valid utf8"
            ))
        );

        // Fence payload of 40 ASCII digits: "0123456789" repeated four times.
        let mut oversized_fence = vec![command_op_ordinal(CommandOp::Fence)];
        oversized_fence.extend(b"0123456789".iter().copied().cycle().take(40));
        assert_eq!(
            decode_command_record(&oversized_fence),
            Err(CommandPayloadError::FencingTokenTooLong(FencingTokenTooLongError(40)).into())
        );

        // Trim payload of the wrong size.
        assert_eq!(
            decode_command_record(&[command_op_ordinal(CommandOp::Trim), 0xff]),
            Err(CommandPayloadError::TrimPointSize(1).into())
        );
    }

    /// Encodes a record built from the parts, decodes it, and compares.
    fn roundtrip_envelope_parts(headers: Vec<Header>, body: Bytes) {
        let record = EnvelopeRecord::try_from_parts(headers.clone(), body.clone()).unwrap();
        let wire: Bytes = record.to_bytes();
        let decoded = decode_envelope_record(wire).unwrap();
        assert_eq!(decoded.headers(), headers);
        assert_eq!(decoded.body(), &body);
    }

    #[test]
    fn envelope_framed_with_headers() {
        let headers = vec![
            header("key_1", "val_1"),
            header("key_2", "val_2"),
            header("key_3", "val_3"),
            header("key_4", "val_4"),
        ];
        roundtrip_envelope_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn envelope_framed_no_headers() {
        roundtrip_envelope_parts(vec![], Bytes::from("hello"));
    }

    #[test]
    fn envelope_decode_rejects_empty_header_name() {
        let mut wire = BytesMut::new();
        wire.put_u8(flag(1, 1, 1).into());
        // One header, whose name length is zero and value length is five.
        wire.put_slice(&[1, 0, 5]);
        wire.put_slice(b"value");
        wire.put_slice(b"body");

        assert_eq!(
            decode_envelope_record(wire.freeze()),
            Err(StoredRecordDecodeError::InvalidValue("HeaderName", "empty"))
        );
    }

    #[test]
    fn envelope_framed_duplicate_keys() {
        let headers = vec![
            header("b", "val_1"),
            header("b", "val_2"),
            header("a", "val_3"),
        ];
        roundtrip_envelope_parts(headers, Bytes::from("hello"));
    }

    #[test]
    fn flag_ex1() {
        let expected = flag(2, 1, 1);
        assert_eq!(HeaderFlag::try_from(0b0010_0000u8), Ok(expected));
        assert_eq!(u8::from(expected), 0b0010_0000);
    }

    #[test]
    fn flag_ex2() {
        let expected = flag(1, 1, 1);
        assert_eq!(HeaderFlag::try_from(0b0001_0000u8), Ok(expected));
        assert_eq!(u8::from(expected), 0b0001_0000);
    }

    #[rstest]
    #[case::one_byte_widths(1, 1)]
    #[case::two_byte_widths(2, 2)]
    #[case::three_byte_widths(3, 3)]
    #[case::four_byte_widths(4, 4)]
    #[case::mixed_widths(2, 4)]
    fn encoding_info_uses_cached_header_length_widths(
        #[case] name_length_width_bytes: usize,
        #[case] value_length_width_bytes: usize,
    ) {
        let info = EncodingInfo::from_header_sizing(
            1,
            42,
            name_length_width_bytes,
            value_length_width_bytes,
        )
        .unwrap();

        assert_eq!(info.headers_total_bytes, 42);
        assert_eq!(
            info.flag,
            flag(
                1,
                name_length_width_bytes as u8,
                value_length_width_bytes as u8
            )
        );
    }

    #[rstest]
    #[case::zero_name_width(0, 1)]
    #[case::too_large_name_width(5, 1)]
    #[case::zero_value_width(1, 0)]
    #[case::too_large_value_width(1, 5)]
    fn encoding_info_rejects_invalid_cached_header_length_widths(
        #[case] name_length_width_bytes: usize,
        #[case] value_length_width_bytes: usize,
    ) {
        let outcome = EncodingInfo::from_header_sizing(
            1,
            42,
            name_length_width_bytes,
            value_length_width_bytes,
        );
        assert_eq!(outcome, Err(HeaderValidationError::TooLong));
    }

    #[test]
    fn empty_envelope_size() {
        let record = EnvelopeRecord::try_from_parts(vec![], Bytes::new()).unwrap();
        assert_eq!(1, record.to_bytes().len());
    }

    #[test]
    fn truncated_envelope_returns_error() {
        let record =
            EnvelopeRecord::try_from_parts(vec![header("key", "value")], Bytes::new()).unwrap();
        let wire = record.to_bytes();

        // Every strict, non-empty prefix must be reported as truncated.
        for len in 1..wire.len() {
            let outcome = decode_envelope_record(wire.slice(..len));
            assert!(
                matches!(outcome, Err(StoredRecordDecodeError::Truncated(_))),
                "expected Truncated error for len {len}"
            );
        }
    }
}