1mod batcher;
2mod command;
3mod envelope;
4mod fencing;
5mod metering;
6
7pub use batcher::{RecordBatch, RecordBatcher};
8use bytes::{Buf, BufMut, Bytes, BytesMut};
9pub use command::CommandRecord;
10use command::{CommandOp, CommandPayloadError};
11use enum_ordinalize::Ordinalize;
12pub use envelope::EnvelopeRecord;
13use envelope::HeaderValidationError;
14pub use fencing::{FencingToken, FencingTokenTooLongError, MAX_FENCING_TOKEN_LENGTH};
15pub use metering::{Metered, MeteredSize};
16
17use crate::deep_size::DeepSize;
18
/// Sequence number, represented as an unsigned 64-bit integer.
pub type SeqNum = u64;
/// Non-zero variant of a sequence number.
pub type NonZeroSeqNum = std::num::NonZeroU64;
/// Timestamp, represented as an unsigned 64-bit integer.
// NOTE(review): the unit/epoch is not established in this file; confirm with callers.
pub type Timestamp = u64;

/// A position expressed as a sequence number paired with a timestamp.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct StreamPosition {
    /// Sequence number component of the position.
    pub seq_num: SeqNum,
    /// Timestamp component of the position.
    pub timestamp: Timestamp,
}

impl StreamPosition {
    /// The smallest representable position: both components at their minimum value.
    pub const MIN: Self = Self {
        seq_num: SeqNum::MIN,
        timestamp: Timestamp::MIN,
    };
}

impl std::fmt::Display for StreamPosition {
    /// Renders the position as `<seq_num> @ <timestamp>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { seq_num, timestamp } = self;
        write!(f, "{seq_num} @ {timestamp}")
    }
}
41
42impl DeepSize for StreamPosition {
43 fn deep_size(&self) -> usize {
44 self.seq_num.deep_size() + self.timestamp.deep_size()
45 }
46}
47
/// Errors raised while decoding the internal (encoded) record representation.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum InternalRecordError {
    /// The input ended before the named part could be read
    /// (this file uses `"MagicByte"` and `"MeteredSize"` as part names).
    #[error("truncated: {0}")]
    Truncated(&'static str),
    /// The named part (first field) held an invalid value; the second field says why.
    #[error("invalid value [{0}]: {1}")]
    InvalidValue(&'static str, &'static str),
}
55
/// Errors raised when building a [`Record`] from headers and a body
/// (see [`Record::try_from_parts`]).
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum PublicRecordError {
    /// The value of the lone empty-named header did not match any known command op.
    #[error("unknown command")]
    UnknownCommand,
    /// The body was rejected as the payload of the given command op.
    #[error("invalid `{0}` command: {1}")]
    CommandPayload(CommandOp, CommandPayloadError),
    /// Header validation failed; converted automatically from [`HeaderValidationError`].
    #[error("invalid header: {0}")]
    Header(#[from] HeaderValidationError),
}
66
/// A record header: a name/value pair of raw bytes.
///
/// A record whose only header has an empty name is interpreted as a command
/// record by [`Record::try_from_parts`], with the command op id in `value`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Header {
    /// Header name bytes.
    pub name: Bytes,
    /// Header value bytes.
    pub value: Bytes,
}
72
73impl DeepSize for Header {
74 fn deep_size(&self) -> usize {
75 self.name.len() + self.value.len()
76 }
77}
78
/// Discriminates the kind of an encoded record.
///
/// The `u8` ordinal is what gets stored in the low three bits of the magic byte
/// (see the `MagicByte` <-> `u8` conversions).
#[derive(Clone, Copy, Debug, PartialEq, Ordinalize)]
#[repr(u8)]
pub enum RecordType {
    /// Command record (ordinal 1).
    Command = 1,
    /// Envelope record (ordinal 2).
    Envelope = 2,
}
85
/// Decoded form of the first byte of an encoded record.
///
/// Bit layout of the raw byte, as implemented by the `u8` conversions:
/// bits 0-2 hold the [`RecordType`] ordinal and bits 3-4 hold
/// `metered_size_varlen - 1`.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct MagicByte {
    /// Kind of record that follows.
    pub record_type: RecordType,
    /// Number of bytes (1 to 3 when parsed from a `u8`) used to encode the
    /// metered size that immediately follows the magic byte.
    pub metered_size_varlen: u8,
}
91
/// Interprets `bytes` as a big-endian unsigned integer of 1 to 4 bytes.
///
/// # Panics
///
/// Panics if `bytes` is empty or longer than the width of a `u32`.
fn read_vint_u32_be(bytes: &[u8]) -> u32 {
    let width = bytes.len();
    if width == 0 || width > std::mem::size_of::<u32>() {
        panic!("invalid variable int bytes = {} len", width)
    }
    // Most significant byte first: shift what we have so far and append the next octet.
    bytes
        .iter()
        .fold(0u32, |value, &octet| (value << 8) | u32::from(octet))
}
103
104pub fn try_metered_size(record_bytes: &[u8]) -> Result<u32, &'static str> {
105 let magic_byte_u8 = *record_bytes.first().ok_or("byte range is empty")?;
106 let magic_byte = MagicByte::try_from(magic_byte_u8)?;
107 Ok(read_vint_u32_be(
108 record_bytes
109 .get(1..1 + magic_byte.metered_size_varlen as usize)
110 .ok_or("byte range doesn't include bytes for metered size")?,
111 ))
112}
113
114impl MeteredSize for Record {
115 fn metered_size(&self) -> usize {
116 8 + (match self {
117 Record::Command(command) => 2 + command.op().to_id().len() + command.payload().len(),
118 Record::Envelope(envelope) => {
119 (2 * envelope.headers().len())
120 + envelope.headers().deep_size()
121 + envelope.body().len()
122 }
123 })
124 }
125}
126
127impl TryFrom<u8> for MagicByte {
128 type Error = &'static str;
129
130 fn try_from(value: u8) -> Result<Self, Self::Error> {
131 let record_type =
132 RecordType::from_ordinal(value & 0b111).ok_or("invalid record type ordinal")?;
133 Ok(Self {
134 record_type,
135 metered_size_varlen: match (value >> 3) & 0b11 {
136 0 => 1u8,
137 1 => 2u8,
138 2 => 3u8,
139 _ => Err("invalid metered_size_varlen")?,
140 },
141 })
142 }
143}
144
145impl From<MagicByte> for u8 {
146 fn from(value: MagicByte) -> Self {
147 ((value.metered_size_varlen - 1) << 3) | value.record_type as u8
148 }
149}
150
/// A record, which is either a command or an envelope.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Record {
    /// A command record.
    Command(CommandRecord),
    /// An envelope record (headers plus body).
    Envelope(EnvelopeRecord),
}
156
157impl DeepSize for Record {
158 fn deep_size(&self) -> usize {
159 match self {
160 Self::Command(c) => c.deep_size(),
161 Self::Envelope(e) => e.deep_size(),
162 }
163 }
164}
165
166impl Record {
167 pub fn try_from_parts(headers: Vec<Header>, body: Bytes) -> Result<Self, PublicRecordError> {
168 if headers.len() == 1 {
169 let header = &headers[0];
170 if header.name.is_empty() {
171 let op = CommandOp::from_id(header.value.as_ref())
172 .ok_or(PublicRecordError::UnknownCommand)?;
173 let command_record = CommandRecord::try_from_parts(op, body.as_ref())
174 .map_err(|e| PublicRecordError::CommandPayload(op, e))?;
175 return Ok(Self::Command(command_record));
176 }
177 }
178 let envelope = EnvelopeRecord::try_from_parts(headers, body)?;
179 Ok(Self::Envelope(envelope))
180 }
181
182 pub fn sequenced(self, position: StreamPosition) -> SequencedRecord {
183 SequencedRecord {
184 position,
185 record: self,
186 }
187 }
188
189 pub fn into_parts(self) -> (Vec<Header>, Bytes) {
190 match self {
191 Record::Envelope(e) => e.into_parts(),
192 Record::Command(c) => {
193 let op = c.op();
194 let header = Header {
195 name: Bytes::new(),
196 value: Bytes::from_static(op.to_id()),
197 };
198 (vec![header], c.payload())
199 }
200 }
201 }
202}
203
204pub fn decode_if_command_record(
205 record: &[u8],
206) -> Result<Option<CommandRecord>, InternalRecordError> {
207 if record.is_empty() {
208 return Err(InternalRecordError::Truncated("MagicByte"));
209 }
210 let magic_byte = MagicByte::try_from(record[0])
211 .map_err(|msg| InternalRecordError::InvalidValue("MagicByte", msg))?;
212 match magic_byte.record_type {
213 RecordType::Command => {
214 let offset = 1 + magic_byte.metered_size_varlen as usize;
215 if record.len() < offset {
216 return Err(InternalRecordError::Truncated("MeteredSize"));
217 }
218 Ok(Some(CommandRecord::try_from(&record[offset..])?))
219 }
220 RecordType::Envelope => Ok(None),
221 }
222}
223
224pub trait Encodable {
225 fn to_bytes(&self) -> Bytes {
226 let expected_size = self.encoded_size();
227 let mut buf = BytesMut::with_capacity(expected_size);
228 self.encode_into(&mut buf);
229 assert_eq!(buf.len(), expected_size, "no reallocation");
230 buf.freeze()
231 }
232
233 fn encoded_size(&self) -> usize;
234
235 fn encode_into(&self, buf: &mut impl BufMut);
236}
237
238impl Encodable for Metered<&Record> {
239 fn encoded_size(&self) -> usize {
240 1 + self.magic_byte().metered_size_varlen as usize
241 + match &**self {
242 Record::Command(r) => r.encoded_size(),
243 Record::Envelope(r) => r.encoded_size(),
244 }
245 }
246
247 fn encode_into(&self, buf: &mut impl BufMut) {
248 let magic_byte = self.magic_byte();
249 buf.put_u8(magic_byte.into());
250 buf.put_uint(
251 self.metered_size() as u64,
252 magic_byte.metered_size_varlen as usize,
253 );
254 match &**self {
255 Record::Command(r) => r.encode_into(buf),
256 Record::Envelope(r) => r.encode_into(buf),
257 }
258 }
259}
260
/// A [`Record`] together with the [`StreamPosition`] assigned to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequencedRecord {
    /// Position (sequence number and timestamp) of the record.
    pub position: StreamPosition,
    /// The record itself.
    pub record: Record,
}
266
267impl MeteredSize for SequencedRecord {
268 fn metered_size(&self) -> usize {
269 self.record.metered_size()
270 }
271}
272
273impl DeepSize for SequencedRecord {
274 fn deep_size(&self) -> usize {
275 self.position.deep_size() + self.record.deep_size()
276 }
277}
278
279impl Metered<Record> {
280 pub fn sequenced(self, position: StreamPosition) -> Metered<SequencedRecord> {
281 Metered {
282 size: self.metered_size(),
283 inner: self.inner.sequenced(position),
284 }
285 }
286}
287
288impl Metered<&Record> {
289 fn magic_byte(&self) -> MagicByte {
290 let metered_size = self.metered_size();
291 let metered_size_varlen = 8 - (metered_size.leading_zeros() / 8) as u8;
292 if metered_size_varlen > 3 {
293 panic!("illegal metered size varlen {metered_size} for record")
294 }
295 let record_type = match self.inner {
296 Record::Command(_) => RecordType::Command,
297 Record::Envelope(_) => RecordType::Envelope,
298 };
299 MagicByte {
300 record_type,
301 metered_size_varlen,
302 }
303 }
304}
305
306impl TryFrom<Bytes> for Metered<Record> {
307 type Error = InternalRecordError;
308
309 fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
310 if buf.is_empty() {
311 return Err(InternalRecordError::Truncated("MagicByte"));
312 }
313 let magic_byte = MagicByte::try_from(buf.get_u8())
314 .map_err(|msg| InternalRecordError::InvalidValue("MagicByte", msg))?;
315
316 let metered_size = buf.get_uint(magic_byte.metered_size_varlen as usize) as usize;
317
318 Ok(Self {
319 size: metered_size,
320 inner: match magic_byte.record_type {
321 RecordType::Command => Record::Command(CommandRecord::try_from(buf.as_ref())?),
322 RecordType::Envelope => Record::Envelope(EnvelopeRecord::try_from(buf)?),
323 },
324 })
325 }
326}
327
328impl Metered<SequencedRecord> {
329 pub fn parts(&self) -> (StreamPosition, Metered<&Record>) {
330 (
331 self.position,
332 Metered {
333 size: self.size,
334 inner: &self.inner.record,
335 },
336 )
337 }
338
339 pub fn into_parts(self) -> (StreamPosition, Metered<Record>) {
340 (
341 self.position,
342 Metered {
343 size: self.size,
344 inner: self.inner.record,
345 },
346 )
347 }
348}
349
// Unit and property tests for record construction, encoding round-trips,
// magic byte parsing, and the variable-length integer reader.
#[cfg(test)]
mod test {
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    // Generates byte strings that are either short (fewer than 10 bytes,
    // optionally empty) or long (100 to 999 bytes).
    fn bytes_strategy(allow_empty: bool) -> impl Strategy<Value = Bytes> {
        prop_oneof![
            prop::collection::vec(any::<u8>(), (if allow_empty { 0 } else { 1 })..10)
                .prop_map(Bytes::from),
            prop::collection::vec(any::<u8>(), 100..1000).prop_map(Bytes::from),
        ]
    }

    // Generates a header with a non-empty name and a possibly empty value.
    fn header_strategy() -> impl Strategy<Value = Header> {
        (bytes_strategy(false), bytes_strategy(true))
            .prop_map(|(name, value)| Header { name, value })
    }

    // Generates either a small header list (fewer than 10) or a large one (200 to 299).
    fn headers_strategy() -> impl Strategy<Value = Vec<Header>> {
        prop_oneof![
            prop::collection::vec(header_strategy(), 0..10),
            prop::collection::vec(header_strategy(), 200..300),
        ]
    }

    // Encoding a metered record and decoding it again yields an equal metered
    // record, and sequencing preserves both the position and the record.
    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_envelope(
            seq_num in any::<SeqNum>(),
            timestamp in any::<Timestamp>(),
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers, body).unwrap();
            let metered_record: Metered<Record> = record.clone().into();
            let encoded_record = metered_record.as_ref().to_bytes();
            let decoded_record = Metered::try_from(encoded_record).unwrap();
            prop_assert_eq!(&decoded_record, &metered_record);
            let sequenced = decoded_record.sequenced(StreamPosition { seq_num, timestamp });
            assert_eq!(sequenced.position, StreamPosition {seq_num, timestamp});
            assert_eq!(sequenced.record, record);
        }
    );

    // The metered size read back from the encoded prefix by `try_metered_size`
    // matches the size computed from the record itself.
    proptest!(
        #![proptest_config(ProptestConfig::with_cases(10))]
        #[test]
        fn roundtrip_metered(
            headers in headers_strategy(),
            body in bytes_strategy(true),
        ) {
            let record = Record::try_from_parts(headers.clone(), body.clone()).unwrap();
            let encoded_record = Metered::from(&record).to_bytes();
            assert_eq!(record.metered_size(), try_metered_size(encoded_record.as_ref()).unwrap() as usize);
        }
    );

    // A lone header with an empty name is treated as a command; "hi" is not a
    // known command op, so construction fails with `UnknownCommand`.
    #[test]
    fn empty_header_name_solo() {
        let headers = vec![Header {
            name: Bytes::new(),
            value: Bytes::from("hi"),
        }];
        let body = Bytes::from("hello");
        assert_eq!(
            Record::try_from_parts(headers, body),
            Err(PublicRecordError::UnknownCommand)
        );
    }

    // With more than one header the command path is not taken, so an empty
    // header name is reported as a header validation error.
    #[test]
    fn empty_header_name_among_others() {
        let headers = vec![
            Header {
                name: Bytes::from("boku"),
                value: Bytes::from("hi"),
            },
            Header {
                name: Bytes::new(),
                value: Bytes::from("hi"),
            },
        ];
        let body = Bytes::from("hello");
        assert_eq!(
            Record::try_from_parts(headers, body),
            Err(PublicRecordError::Header(HeaderValidationError::NameEmpty))
        );
    }

    // Builds command records from (op, payload) pairs. Each `#[should_panic]`
    // sits directly above the one case it is meant to cover, so the attribute
    // order below is significant; those cases expect the first `unwrap` to
    // panic with the stated payload error.
    #[rstest]
    #[case::fence_empty(b"fence", b"")]
    #[case::fence_uuid(b"fence", b"my-special-uuid")]
    #[should_panic(expected = "FencingTokenTooLongError(49)")]
    #[case::fence_too_long(b"fence", b"toolongtoolongtoolongtoolongtoolongtoolongtoolong")]
    #[case::trim_0(b"trim", b"\x00\x00\x00\x00\x00\x00\x00\x00")]
    #[should_panic(expected = "TrimPointSize(0)")]
    #[case::trim_empty(b"trim", b"")]
    #[should_panic(expected = "TrimPointSize(9)")]
    #[case::trim_overflow(b"trim", b"\x00\x00\x00\x00\x00\x00\x00\x00\x00")]
    fn command_records(#[case] op: &'static [u8], #[case] payload: &'static [u8]) {
        let headers = vec![Header {
            name: Bytes::new(),
            value: Bytes::from_static(op),
        }];
        let body = Bytes::from_static(payload);
        let record = Record::try_from_parts(headers.clone(), body.clone()).unwrap();
        let record_metered = record.metered_size();
        // The parsed record must be a command that echoes back the op id and payload.
        match &record {
            Record::Command(cmd) => {
                assert_eq!(cmd.op().to_id(), op);
                assert_eq!(cmd.payload().as_ref(), payload);
            }
            Record::Envelope(e) => panic!("Command expected, got Envelope: {e:?}"),
        }
        let sequenced_record = record.sequenced(StreamPosition {
            seq_num: 42,
            timestamp: 100_000,
        });
        // Sequencing must not change the metered size.
        let sequenced_metered = sequenced_record.metered_size();
        assert_eq!(record_metered, sequenced_metered);
        assert_eq!(
            sequenced_record.position,
            StreamPosition {
                seq_num: 42,
                timestamp: 100_000,
            }
        );
        assert_eq!(
            sequenced_record.record,
            Record::try_from_parts(headers, body).unwrap()
        );
    }

    // Checks both directions of the `MagicByte` <-> `u8` conversion. The last
    // case has 0b101 in the low three bits, which is not a record type ordinal,
    // so the `unwrap` is expected to panic.
    #[rstest]
    #[case(0b0000_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 1})]
    #[case(0b0001_0010, MagicByte { record_type: RecordType::Envelope, metered_size_varlen: 3})]
    #[case(0b0000_1001, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    #[should_panic(expected = "invalid record type ordinal")]
    #[case(0b0000_1101, MagicByte { record_type: RecordType::Command, metered_size_varlen: 2})]
    fn magic_byte_parsing(#[case] as_u8: u8, #[case] magic_byte: MagicByte) {
        assert_eq!(MagicByte::try_from(as_u8).unwrap(), magic_byte);
        assert_eq!(u8::from(magic_byte), as_u8);
    }

    // Reads big-endian integers of 3 and 4 bytes from different windows over
    // the same buffer, placing the single 1 byte at different significances.
    #[test]
    fn test_read_varint() {
        let data = [0u8, 0, 0, 1, 0, 0, 0];

        assert_eq!(read_vint_u32_be(&data[..4]), 1u32);
        assert_eq!(read_vint_u32_be(&data[2..5]), 2u32.pow(8));
        assert_eq!(read_vint_u32_be(&data[2..6]), 2u32.pow(16));
        assert_eq!(read_vint_u32_be(&data[3..]), 2u32.pow(24));
    }
}