1use alloc::{borrow::Cow, boxed::Box, collections::BTreeSet, string::String, sync::Arc, vec::Vec};
2use core::cmp::Ordering;
3
4mod encode;
5mod validate;
6mod wire;
7
8use crate::{
9 CalendarInterval, Decimal, Duration, Envelope, EnvelopeMode, Error, ErrorKind, Field, Message,
10 Result, Schema, SchemaId, SchemaRegistry, TimestampPrecision, TpackValue, TypeDescriptor,
11 ValueMapEntry, Variant, empty_registry,
12};
13
/// Four-byte magic prefix (`TPAK`) that opens every encoded message.
pub const MAGIC: [u8; 4] = [b'T', b'P', b'A', b'K'];
/// Format version byte that follows the magic; the decoder rejects any other value.
pub const VERSION: u8 = 1;

/// Number of nanoseconds in one day; decoded time-of-day values must stay strictly below it.
const NANOS_PER_DAY: u64 = 24 * 60 * 60 * 1_000_000_000;
18
/// Selects whether canonical-form rules are enforced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CanonicalMode {
    /// Canonical-form checks are disabled.
    Off,
    /// Canonical-form checks are enforced.
    Strict,
}

impl CanonicalMode {
    /// Returns `true` only for [`CanonicalMode::Strict`].
    pub fn is_strict(self) -> bool {
        match self {
            CanonicalMode::Strict => true,
            CanonicalMode::Off => false,
        }
    }
}
30
/// Upper bounds applied to attacker-controllable sizes while encoding and decoding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Limits {
    /// Largest accepted serialized schema, in bytes.
    pub max_schema_len: usize,
    /// Largest accepted schema identifier, in bytes.
    pub max_schema_id_len: usize,
    /// Deepest accepted nesting of type descriptors and of values.
    pub max_depth: usize,
    /// Largest accepted number of fields in one struct descriptor.
    pub max_fields: usize,
    /// Largest accepted number of union variants or enum symbols in one descriptor.
    pub max_variants: usize,
    /// Cap handed to the list/map count validation.
    pub max_collection_len: usize,
    /// Largest accepted text component, in bytes.
    pub max_string_len: usize,
    /// Largest accepted byte-string component, in bytes.
    pub max_bytes_len: usize,
    /// Largest accepted extension payload or extension schema parameter blob, in bytes.
    pub max_extension_len: usize,
    /// Largest number of bytes read for one multi-byte varint.
    pub max_varint_bytes: usize,
}

impl Default for Limits {
    /// Default limits: 1 MiB schemas, 16 MiB string/bytes/extension payloads,
    /// depth 128, 16 384 fields/variants and one million collection entries.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        Limits {
            max_schema_len: MIB,
            max_schema_id_len: 1 << 10,
            max_depth: 128,
            max_fields: 1 << 14,
            max_variants: 1 << 14,
            max_collection_len: 1_000_000,
            max_string_len: 16 * MIB,
            max_bytes_len: 16 * MIB,
            max_extension_len: 16 * MIB,
            max_varint_bytes: 10,
        }
    }
}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub struct DecodeOptions {
69 pub canonical: CanonicalMode,
70 pub allow_schema_ref: bool,
71 pub validate_embedded_schema_on_cache_hit: bool,
78 pub limits: Limits,
79}
80
81impl Default for DecodeOptions {
82 fn default() -> Self {
83 Self {
84 canonical: CanonicalMode::Off,
85 allow_schema_ref: true,
86 validate_embedded_schema_on_cache_hit: true,
87 limits: Limits::default(),
88 }
89 }
90}
91
92#[derive(Debug, Clone, Copy, PartialEq, Eq)]
94pub struct EncodeOptions {
95 pub canonical: CanonicalMode,
96 pub limits: Limits,
97}
98
99impl Default for EncodeOptions {
100 fn default() -> Self {
101 Self {
102 canonical: CanonicalMode::Off,
103 limits: Limits::default(),
104 }
105 }
106}
107
108pub struct Decoder<'de> {
109 input: &'de [u8],
110 pos: usize,
111 options: DecodeOptions,
112}
113
114impl<'de> Decoder<'de> {
115 pub fn new(input: &'de [u8]) -> Self {
116 Self::with_options(input, DecodeOptions::default())
117 }
118
119 pub fn with_options(input: &'de [u8], options: DecodeOptions) -> Self {
120 Self {
121 input,
122 pos: 0,
123 options,
124 }
125 }
126
127 pub fn position(&self) -> usize {
128 self.pos
129 }
130
131 pub fn is_eof(&self) -> bool {
132 self.pos == self.input.len()
133 }
134
135 pub fn decode_message(&mut self) -> Result<Message<'de>> {
136 self.decode_message_with_registry(&empty_registry())
137 }
138
139 pub fn decode_message_with_registry<R: SchemaRegistry + ?Sized>(
140 &mut self,
141 registry: &R,
142 ) -> Result<Message<'de>> {
143 self.read_header()?;
144 let mode = match self.read_u8()? {
145 0x00 => EnvelopeMode::FullSchema,
146 0x01 => EnvelopeMode::FullSchemaWithId,
147 0x02 => EnvelopeMode::SchemaRef,
148 other => return Err(Error::new(ErrorKind::UnknownEnvelopeMode(other))),
149 };
150
151 let (schema_id, schema, used_cached_schema) = match mode {
152 EnvelopeMode::FullSchema => {
153 let schema = self.decode_schema_block()?;
154 (None, Arc::new(schema), false)
155 }
156 EnvelopeMode::FullSchemaWithId => {
157 let schema_id = self.read_schema_id(false)?;
158 let schema_len = self.read_len("schema length")?;
159 if schema_len > self.options.limits.max_schema_len {
160 return Err(Error::new(ErrorKind::SchemaLengthExceeded));
161 }
162
163 let schema_start = self.pos;
164 let schema_end = schema_start
165 .checked_add(schema_len)
166 .ok_or(Error::new(ErrorKind::SchemaLengthExceeded))?;
167 if schema_end > self.input.len() {
168 return Err(Error::new(ErrorKind::UnexpectedEof));
169 }
170
171 if let Some(schema) = registry.get(schema_id.as_bytes()) {
172 if self.options.validate_embedded_schema_on_cache_hit {
173 self.validate_cached_schema_bytes(schema_len, schema.as_ref())?;
174 } else {
175 self.pos = schema_end;
179 }
180 (Some(schema_id), schema, true)
181 } else {
182 let schema = self.decode_schema_at_exact_len(schema_len)?;
183 (Some(schema_id), Arc::new(schema), false)
184 }
185 }
186 EnvelopeMode::SchemaRef => {
187 if !self.options.allow_schema_ref {
188 return Err(Error::new(ErrorKind::SchemaRefNotAllowed));
189 }
190 let schema_id = self.read_schema_id(true)?;
191 let schema = registry
192 .get(schema_id.as_bytes())
193 .ok_or(Error::new(ErrorKind::UnknownSchemaId))?;
194 (Some(schema_id), schema, true)
195 }
196 };
197
198 let value = self.decode_value_for(&schema.root, 0)?;
199 if !self.is_eof() {
200 return Err(Error::new(ErrorKind::TrailingBytes));
201 }
202
203 Ok(Message {
204 envelope: Envelope {
205 mode,
206 schema_id,
207 used_cached_schema,
208 },
209 schema,
210 value,
211 })
212 }
213
214 pub fn decode_schema(&mut self) -> Result<Schema> {
215 let schema = Schema::new(self.decode_type_descriptor(0)?);
216 validate::validate_schema(&schema, &self.options.limits)?;
217 Ok(schema)
218 }
219
220 pub fn decode_value(&mut self, schema: &Schema) -> Result<TpackValue<'de>> {
221 let value = self.decode_value_for(&schema.root, 0)?;
222 if !self.is_eof() {
223 return Err(Error::new(ErrorKind::TrailingBytes));
224 }
225 Ok(value)
226 }
227
228 fn read_header(&mut self) -> Result<()> {
229 if self.read_bytes(4)? != MAGIC {
230 return Err(Error::new(ErrorKind::InvalidMagic));
231 }
232 let version = self.read_u8()?;
233 if version != VERSION {
234 return Err(Error::new(ErrorKind::UnsupportedVersion(version)));
235 }
236 Ok(())
237 }
238
239 fn decode_schema_block(&mut self) -> Result<Schema> {
240 let schema_len = self.read_len("schema length")?;
241 if schema_len > self.options.limits.max_schema_len {
242 return Err(Error::new(ErrorKind::SchemaLengthExceeded));
243 }
244 self.decode_schema_at_exact_len(schema_len)
245 }
246
247 fn decode_schema_at_exact_len(&mut self, schema_len: usize) -> Result<Schema> {
248 let start = self.pos;
249 let schema = self.decode_schema()?;
250 let consumed = self.pos - start;
251 if consumed != schema_len {
252 return Err(Error::new(ErrorKind::SchemaLengthMismatch));
253 }
254 Ok(schema)
255 }
256
257 fn validate_cached_schema_bytes(
258 &mut self,
259 schema_len: usize,
260 cached_schema: &Schema,
261 ) -> Result<()> {
262 let embedded_schema = self.decode_schema_at_exact_len(schema_len)?;
263 if &embedded_schema != cached_schema {
264 return Err(Error::new(ErrorKind::EmbeddedSchemaMismatch));
265 }
266 Ok(())
267 }
268
269 fn read_schema_id(&mut self, require_non_empty: bool) -> Result<SchemaId<'de>> {
270 let len = self.read_len("schema id length")?;
271 if len > self.options.limits.max_schema_id_len {
272 return Err(Error::new(ErrorKind::InvalidSchemaId));
273 }
274 if require_non_empty && len == 0 {
275 return Err(Error::new(ErrorKind::InvalidSchemaId));
276 }
277 Ok(SchemaId::borrowed(self.read_bytes(len)?))
278 }
279
280 fn decode_type_descriptor(&mut self, depth: usize) -> Result<TypeDescriptor> {
281 if depth > self.options.limits.max_depth {
282 return Err(Error::limit("schema depth"));
283 }
284 let tag = self.read_u8()?;
285 let ty = match tag {
286 0x00 => TypeDescriptor::Null,
287 0x01 => TypeDescriptor::Bool,
288 0x02 => TypeDescriptor::I8,
289 0x03 => TypeDescriptor::I16,
290 0x04 => TypeDescriptor::I32,
291 0x05 => TypeDescriptor::I64,
292 0x06 => TypeDescriptor::U8,
293 0x07 => TypeDescriptor::U16,
294 0x08 => TypeDescriptor::U32,
295 0x09 => TypeDescriptor::U64,
296 0x0A => TypeDescriptor::F32,
297 0x0B => TypeDescriptor::F64,
298 0x0C => TypeDescriptor::Decimal,
299 0x0D => {
300 let precision = self.read_uvarint()?;
301 let scale = self.read_uvarint()?;
302 if precision == 0 || scale > precision {
303 return Err(Error::new(ErrorKind::InvalidDecimalParameters));
304 }
305 TypeDescriptor::DecimalFixed { precision, scale }
306 }
307 0x0E => TypeDescriptor::String {
308 max_len: Some(self.read_uvarint()?),
309 },
310 0x0F => TypeDescriptor::String { max_len: None },
311 0x10 => TypeDescriptor::Bytes {
312 max_len: Some(self.read_uvarint()?),
313 },
314 0x11 => TypeDescriptor::Bytes { max_len: None },
315 0x12 => TypeDescriptor::Date,
316 0x13 => TypeDescriptor::Time,
317 0x14 => TypeDescriptor::DateTime,
318 0x15 => TypeDescriptor::DateTimeTz,
319 0x16 => {
320 let precision = match self.read_u8()? {
321 0 => TimestampPrecision::Seconds,
322 1 => TimestampPrecision::Milliseconds,
323 2 => TimestampPrecision::Microseconds,
324 3 => TimestampPrecision::Nanoseconds,
325 other => return Err(Error::new(ErrorKind::InvalidTimestampPrecision(other))),
326 };
327 TypeDescriptor::Timestamp(precision)
328 }
329 0x17 => TypeDescriptor::Duration,
330 0x18 => TypeDescriptor::BigInt,
331 0x19 => TypeDescriptor::BigUInt,
332 0x1A => TypeDescriptor::CalendarInterval,
333 0x20 => {
334 let count = self.read_count("struct field count")?;
335 if count > self.options.limits.max_fields {
336 return Err(Error::limit("struct field count"));
337 }
338 let mut fields = Vec::with_capacity(count);
339 let mut seen_ids = BTreeSet::new();
340 let mut seen_names = BTreeSet::new();
341 for _ in 0..count {
342 let id = self.read_uvarint()?;
343 if id == 0 {
344 return Err(Error::new(ErrorKind::StructFieldIdZero));
345 }
346 let name = self.read_text_owned()?;
347 if name.is_empty() {
348 return Err(Error::new(ErrorKind::StructFieldNameEmpty));
349 }
350 let flags = self.read_uvarint()?;
351 if flags != 0 {
352 return Err(Error::new(ErrorKind::StructFieldFlagsNonZero(flags)));
353 }
354 let ty = self.decode_type_descriptor(depth + 1)?;
355 if !seen_ids.insert(id) || !seen_names.insert(name.clone()) {
356 return Err(Error::new(ErrorKind::DuplicateStructFieldDefinition));
357 }
358 fields.push(Field { id, name, ty });
359 }
360 TypeDescriptor::Struct(fields)
361 }
362 0x21 => {
363 let max_count = wire::max_count_from_wire(self.read_uvarint()?);
364 let element = Box::new(self.decode_type_descriptor(depth + 1)?);
365 TypeDescriptor::List { max_count, element }
366 }
367 0x22 => {
368 let max_count = wire::max_count_from_wire(self.read_uvarint()?);
369 let key = Box::new(self.decode_type_descriptor(depth + 1)?);
370 if !validate::is_valid_map_key_type(&key) {
371 return Err(Error::new(ErrorKind::InvalidMapKeyType));
372 }
373 let value = Box::new(self.decode_type_descriptor(depth + 1)?);
374 TypeDescriptor::Map {
375 max_count,
376 key,
377 value,
378 }
379 }
380 0x23 => {
381 let count = self.read_count("union variant count")?;
382 if count > self.options.limits.max_variants {
383 return Err(Error::limit("union variant count"));
384 }
385 let mut variants = Vec::with_capacity(count);
386 let mut seen_names = BTreeSet::new();
387 for _ in 0..count {
388 let name = self.read_text_owned()?;
389 if name.is_empty() {
390 return Err(Error::new(ErrorKind::UnionVariantNameEmpty));
391 }
392 if !seen_names.insert(name.clone()) {
393 return Err(Error::new(ErrorKind::DuplicateUnionVariantName));
394 }
395 let ty = self.decode_type_descriptor(depth + 1)?;
396 variants.push(Variant { name, ty });
397 }
398 TypeDescriptor::Union(variants)
399 }
400 0x24 => {
401 let count = self.read_count("enum symbol count")?;
402 if count > self.options.limits.max_variants {
403 return Err(Error::limit("enum symbol count"));
404 }
405 let mut symbols = Vec::with_capacity(count);
406 let mut seen_symbols = BTreeSet::new();
407 for _ in 0..count {
408 let symbol = self.read_text_owned()?;
409 if symbol.is_empty() {
410 return Err(Error::new(ErrorKind::EnumSymbolEmpty));
411 }
412 if !seen_symbols.insert(symbol.clone()) {
413 return Err(Error::new(ErrorKind::DuplicateEnumSymbol));
414 }
415 symbols.push(symbol);
416 }
417 TypeDescriptor::Enum(symbols)
418 }
419 0x25 => {
420 let inner = Box::new(self.decode_type_descriptor(depth + 1)?);
421 TypeDescriptor::Optional(inner)
422 }
423 0x26 => {
424 let authority = self.read_text_owned()?;
425 let type_label = self.read_text_owned()?;
426 let schema_params = self.read_bytes_owned(self.options.limits.max_extension_len)?;
427 TypeDescriptor::Extension {
428 authority,
429 type_name: type_label,
430 schema_params,
431 }
432 }
433 other => return Err(Error::new(ErrorKind::UnknownTypeTag(other))),
434 };
435 Ok(ty)
436 }
437
438 fn decode_value_for(&mut self, ty: &TypeDescriptor, depth: usize) -> Result<TpackValue<'de>> {
439 if depth > self.options.limits.max_depth {
440 return Err(Error::limit("value depth"));
441 }
442 let value = match ty {
443 TypeDescriptor::Null => TpackValue::Null,
444 TypeDescriptor::Bool => match self.read_u8()? {
445 0 => TpackValue::Bool(false),
446 1 => TpackValue::Bool(true),
447 _ => return Err(Error::invalid("invalid bool value")),
448 },
449 TypeDescriptor::I8 => TpackValue::I8(self.read_i8()?),
450 TypeDescriptor::I16 => TpackValue::I16(i16::from_be_bytes(self.read_array()?)),
451 TypeDescriptor::I32 => TpackValue::I32(i32::from_be_bytes(self.read_array()?)),
452 TypeDescriptor::I64 => TpackValue::I64(i64::from_be_bytes(self.read_array()?)),
453 TypeDescriptor::U8 => TpackValue::U8(self.read_u8()?),
454 TypeDescriptor::U16 => TpackValue::U16(u16::from_be_bytes(self.read_array()?)),
455 TypeDescriptor::U32 => TpackValue::U32(u32::from_be_bytes(self.read_array()?)),
456 TypeDescriptor::U64 => TpackValue::U64(u64::from_be_bytes(self.read_array()?)),
457 TypeDescriptor::F32 => {
458 let bits = u32::from_be_bytes(self.read_array()?);
459 if self.options.canonical.is_strict()
460 && f32::from_bits(bits).is_nan()
461 && bits != 0x7FC0_0000
462 {
463 return Err(Error::invalid("non-canonical f32 NaN"));
464 }
465 TpackValue::F32(f32::from_bits(bits))
466 }
467 TypeDescriptor::F64 => {
468 let bits = u64::from_be_bytes(self.read_array()?);
469 if self.options.canonical.is_strict()
470 && f64::from_bits(bits).is_nan()
471 && bits != 0x7FF8_0000_0000_0000
472 {
473 return Err(Error::invalid("non-canonical f64 NaN"));
474 }
475 TpackValue::F64(f64::from_bits(bits))
476 }
477 TypeDescriptor::Decimal => {
478 let scale = self.read_svarint()?;
479 let coefficient = self.read_svarint()?;
480 TpackValue::Decimal(Decimal { scale, coefficient })
481 }
482 TypeDescriptor::DecimalFixed { precision, .. } => {
483 let coefficient = self.read_svarint()?;
484 if validate::decimal_digits_abs(coefficient) > *precision {
485 return Err(Error::invalid("Decimal(P,S) coefficient exceeds precision"));
486 }
487 TpackValue::DecimalFixed(coefficient)
488 }
489 TypeDescriptor::String { max_len } => {
490 let value = self.read_text_borrowed(*max_len)?;
491 TpackValue::String(Cow::Borrowed(value))
492 }
493 TypeDescriptor::Bytes { max_len } => {
494 let value = self.read_byte_component(*max_len)?;
495 TpackValue::Bytes(Cow::Borrowed(value))
496 }
497 TypeDescriptor::Date => TpackValue::Date(self.read_svarint()?),
498 TypeDescriptor::Time => {
499 let nanos = self.read_uvarint()?;
500 if nanos >= NANOS_PER_DAY {
501 return Err(Error::invalid("time value exceeds nanos-per-day"));
502 }
503 TpackValue::Time(nanos)
504 }
505 TypeDescriptor::DateTime => {
506 let days = self.read_svarint()?;
507 let nanos = self.read_uvarint()?;
508 if nanos >= NANOS_PER_DAY {
509 return Err(Error::invalid("datetime time value exceeds nanos-per-day"));
510 }
511 TpackValue::DateTime { days, nanos }
512 }
513 TypeDescriptor::DateTimeTz => {
514 let days = self.read_svarint()?;
515 let nanos = self.read_uvarint()?;
516 if nanos >= NANOS_PER_DAY {
517 return Err(Error::invalid(
518 "datetime-tz time value exceeds nanos-per-day",
519 ));
520 }
521 let timezone = self.read_text_borrowed(None)?;
522 TpackValue::DateTimeTz {
523 days,
524 nanos,
525 timezone: Cow::Borrowed(timezone),
526 }
527 }
528 TypeDescriptor::Timestamp(_) => TpackValue::Timestamp(self.read_svarint()?),
529 TypeDescriptor::Duration => {
530 let seconds = self.read_svarint()?;
531 let nanos = self.read_svarint()?;
532 validate::validate_duration(seconds, nanos)?;
533 TpackValue::Duration(Duration { seconds, nanos })
534 }
535 TypeDescriptor::BigInt => TpackValue::BigInt(self.read_svarint()?),
536 TypeDescriptor::BigUInt => TpackValue::BigUInt(self.read_uvarint()?),
537 TypeDescriptor::CalendarInterval => {
538 let months = self.read_svarint()?;
539 let days = self.read_svarint()?;
540 let nanos = self.read_svarint()?;
541 TpackValue::CalendarInterval(CalendarInterval {
542 months,
543 days,
544 nanos,
545 })
546 }
547 TypeDescriptor::Struct(fields) => {
548 let mut values = Vec::with_capacity(fields.len());
549 for field in fields {
550 let value = self
551 .decode_value_for(&field.ty, depth + 1)
552 .map_err(|err| err.at_field(field.name.clone()))?;
553 values.push((field.id, value));
554 }
555 TpackValue::Struct(values)
556 }
557 TypeDescriptor::List { max_count, element } => {
558 let count = self.read_count("list count")?;
559 validate::validate_count("list count", count, *max_count, &self.options.limits)?;
560 let mut values = Vec::with_capacity(count);
561 for index in 0..count {
562 let value = self
563 .decode_value_for(element, depth + 1)
564 .map_err(|err| err.at_index(index))?;
565 values.push(value);
566 }
567 TpackValue::List(values)
568 }
569 TypeDescriptor::Map {
570 max_count,
571 key,
572 value,
573 } => {
574 let count = self.read_count("map count")?;
575 validate::validate_count("map count", count, *max_count, &self.options.limits)?;
576 let mut entries = Vec::with_capacity(count);
577 let mut seen_key_bytes = if self.options.canonical.is_strict() {
578 None
579 } else {
580 Some(BTreeSet::new())
581 };
582 let mut last_key_bytes: Option<&'de [u8]> = None;
583 for _ in 0..count {
584 let key_start = self.pos;
585 let key_value = self.decode_value_for(key, depth + 1)?;
586 let raw_key_bytes = &self.input[key_start..self.pos];
587 validate::reject_nan_map_key(&key_value)?;
588 if self.options.canonical.is_strict() {
589 if let Some(previous) = last_key_bytes {
593 match previous.cmp(raw_key_bytes) {
594 Ordering::Less => {}
595 Ordering::Equal => {
596 return Err(Error::invalid("duplicate map key"));
597 }
598 Ordering::Greater => {
599 return Err(Error::invalid("non-canonical map key order"));
600 }
601 }
602 }
603 last_key_bytes = Some(raw_key_bytes);
604 }
605 if !self.options.canonical.is_strict() {
606 let canonical_key = encode::value(
607 key,
608 &key_value,
609 EncodeOptions {
610 canonical: CanonicalMode::Strict,
611 limits: self.options.limits,
612 },
613 )?;
614 if !seen_key_bytes
615 .as_mut()
616 .expect("non-strict mode allocates a map-key set")
617 .insert(canonical_key)
618 {
619 return Err(Error::invalid("duplicate map key"));
620 }
621 }
622 let value = self.decode_value_for(value, depth + 1)?;
623 entries.push(ValueMapEntry {
624 key: key_value,
625 value,
626 });
627 }
628 TpackValue::Map(entries)
629 }
630 TypeDescriptor::Union(variants) => {
631 let index = self.read_uvarint()?;
632 let variant = variants
633 .get(usize::try_from(index).map_err(|_| Error::limit("variant index"))?)
634 .ok_or(Error::invalid("union variant index out of range"))?;
635 let value = self.decode_value_for(&variant.ty, depth + 1)?;
636 TpackValue::Union {
637 index,
638 value: Box::new(value),
639 }
640 }
641 TypeDescriptor::Enum(symbols) => {
642 let index = self.read_uvarint()?;
643 symbols
644 .get(usize::try_from(index).map_err(|_| Error::limit("enum index"))?)
645 .ok_or(Error::invalid("enum symbol index out of range"))?;
646 TpackValue::Enum(index)
647 }
648 TypeDescriptor::Optional(inner) => match self.read_u8()? {
649 0 => TpackValue::Optional(None),
650 1 => TpackValue::Optional(Some(Box::new(self.decode_value_for(inner, depth + 1)?))),
651 _ => return Err(Error::invalid("invalid optional presence marker")),
652 },
653 TypeDescriptor::Extension { .. } => {
654 let bytes = self.read_extension_component()?;
655 TpackValue::Extension(Cow::Borrowed(bytes))
656 }
657 };
658 Ok(value)
659 }
660
661 fn read_u8(&mut self) -> Result<u8> {
662 let byte = *self
663 .input
664 .get(self.pos)
665 .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
666 self.pos += 1;
667 Ok(byte)
668 }
669
670 fn read_i8(&mut self) -> Result<i8> {
671 Ok(i8::from_be_bytes([self.read_u8()?]))
672 }
673
674 fn read_array<const N: usize>(&mut self) -> Result<[u8; N]> {
675 let bytes = self.read_bytes(N)?;
676 let mut out = [0u8; N];
677 out.copy_from_slice(bytes);
678 Ok(out)
679 }
680
681 fn read_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
682 let end = self
683 .pos
684 .checked_add(len)
685 .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
686 let bytes = self
687 .input
688 .get(self.pos..end)
689 .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
690 self.pos = end;
691 Ok(bytes)
692 }
693
694 fn read_uvarint(&mut self) -> Result<u64> {
695 if let Some(&byte) = self.input.get(self.pos) {
698 if byte < 0x80 {
699 self.pos += 1;
700 return Ok(u64::from(byte));
701 }
702 }
703 self.read_uvarint_slow()
704 }
705
706 #[cold]
707 fn read_uvarint_slow(&mut self) -> Result<u64> {
708 let start = self.pos;
709 let mut value = 0u64;
710 for i in 0..self.options.limits.max_varint_bytes {
711 let byte = self.read_u8()?;
712 let payload = u64::from(byte & 0x7F);
713 if i == 9 && payload > 1 {
714 return Err(Error::new(ErrorKind::VarintOverflow));
715 }
716 value |= payload << (7 * i);
717 if byte & 0x80 == 0 {
718 let encoded_len = self.pos - start;
719 if self.options.canonical.is_strict() && encoded_len != wire::uvarint_len(value) {
720 return Err(Error::new(ErrorKind::OverlongVarint));
721 }
722 return Ok(value);
723 }
724 }
725 Err(Error::new(ErrorKind::VarintOverflow))
726 }
727
728 fn read_svarint(&mut self) -> Result<i64> {
729 let raw = self.read_uvarint()?;
730 Ok(((raw >> 1) as i64) ^ (-((raw & 1) as i64)))
731 }
732
733 fn read_len(&mut self, name: &'static str) -> Result<usize> {
734 usize::try_from(self.read_uvarint()?).map_err(|_| Error::limit(name))
735 }
736
737 fn read_count(&mut self, name: &'static str) -> Result<usize> {
738 usize::try_from(self.read_uvarint()?).map_err(|_| Error::limit(name))
739 }
740
741 fn read_text_owned(&mut self) -> Result<String> {
742 Ok(String::from(self.read_text_borrowed(None)?))
743 }
744
745 fn read_text_borrowed(&mut self, schema_max: Option<u64>) -> Result<&'de str> {
746 let bytes = self.read_limited_component(
747 "string length",
748 schema_max,
749 self.options.limits.max_string_len,
750 )?;
751 Ok(core::str::from_utf8(bytes)?)
752 }
753
754 fn read_bytes_owned(&mut self, limit: usize) -> Result<Vec<u8>> {
755 Ok(self
756 .read_limited_component("byte string length", None, limit)?
757 .to_vec())
758 }
759
760 fn read_byte_component(&mut self, schema_max: Option<u64>) -> Result<&'de [u8]> {
761 self.read_limited_component(
762 "byte string length",
763 schema_max,
764 self.options.limits.max_bytes_len,
765 )
766 }
767
768 fn read_extension_component(&mut self) -> Result<&'de [u8]> {
769 self.read_limited_component(
770 "extension payload size",
771 None,
772 self.options.limits.max_extension_len,
773 )
774 }
775
776 fn read_limited_component(
777 &mut self,
778 limit_name: &'static str,
779 schema_max: Option<u64>,
780 max_len: usize,
781 ) -> Result<&'de [u8]> {
782 let len = self.read_len(limit_name)?;
783 let limit = schema_max
784 .and_then(|max| usize::try_from(max).ok())
785 .unwrap_or(max_len)
786 .min(max_len);
787 if len > limit {
788 return Err(Error::limit(limit_name));
789 }
790 self.read_bytes(len)
791 }
792}
793
794pub struct Encoder {
795 out: Vec<u8>,
796 options: EncodeOptions,
797}
798
799impl Encoder {
800 pub fn new() -> Self {
801 Self::with_options(EncodeOptions::default())
802 }
803
804 pub fn with_options(options: EncodeOptions) -> Self {
805 Self {
806 out: Vec::new(),
807 options,
808 }
809 }
810
811 pub fn into_vec(self) -> Vec<u8> {
812 self.out
813 }
814
815 pub fn encode_message(
816 &mut self,
817 schema: &Schema,
818 value: &TpackValue<'_>,
819 mode: EnvelopeMode,
820 schema_id: Option<&[u8]>,
821 ) -> Result<()> {
822 let schema_bytes = encode::schema(schema, self.options)?;
823 self.out.extend_from_slice(&MAGIC);
824 self.out.push(VERSION);
825 self.out.push(mode.tag());
826 match mode {
827 EnvelopeMode::FullSchema => {
828 wire::write_uvarint(&mut self.out, schema_bytes.len() as u64);
829 self.out.extend_from_slice(&schema_bytes);
830 }
831 EnvelopeMode::FullSchemaWithId => {
832 let schema_id = schema_id.unwrap_or(&[]);
833 if schema_id.len() > self.options.limits.max_schema_id_len {
834 return Err(Error::new(ErrorKind::InvalidSchemaId));
835 }
836 wire::write_uvarint(&mut self.out, schema_id.len() as u64);
837 self.out.extend_from_slice(schema_id);
838 wire::write_uvarint(&mut self.out, schema_bytes.len() as u64);
839 self.out.extend_from_slice(&schema_bytes);
840 }
841 EnvelopeMode::SchemaRef => {
842 let schema_id = schema_id.ok_or(Error::new(ErrorKind::InvalidSchemaId))?;
843 if schema_id.is_empty() || schema_id.len() > self.options.limits.max_schema_id_len {
844 return Err(Error::new(ErrorKind::InvalidSchemaId));
845 }
846 wire::write_uvarint(&mut self.out, schema_id.len() as u64);
847 self.out.extend_from_slice(schema_id);
848 }
849 }
850 encode::ValueEncoder::new(&mut self.out, self.options).write_value(&schema.root, value)?;
851 Ok(())
852 }
853
854 pub fn encode_schema(&mut self, schema: &Schema) -> Result<()> {
855 let schema_bytes = encode::schema(schema, self.options)?;
856 self.out.extend_from_slice(&schema_bytes);
857 Ok(())
858 }
859
860 pub fn encode_value(&mut self, schema: &Schema, value: &TpackValue<'_>) -> Result<()> {
861 encode::ValueEncoder::new(&mut self.out, self.options).write_value(&schema.root, value)
862 }
863}
864
865impl Default for Encoder {
866 fn default() -> Self {
867 Self::new()
868 }
869}
870
871pub fn decode_message(input: &[u8]) -> Result<Message<'_>> {
872 Decoder::new(input).decode_message()
873}
874
875pub fn encode_message(
876 schema: &Schema,
877 value: &TpackValue<'_>,
878 mode: EnvelopeMode,
879 schema_id: Option<&[u8]>,
880) -> Result<Vec<u8>> {
881 let mut encoder = Encoder::new();
882 encoder.encode_message(schema, value, mode, schema_id)?;
883 Ok(encoder.into_vec())
884}
885
886pub fn encode_schema(schema: &Schema) -> Result<Vec<u8>> {
887 encode::schema(schema, EncodeOptions::default())
888}
889
890pub fn encode_value(schema: &Schema, value: &TpackValue<'_>) -> Result<Vec<u8>> {
891 encode::value(&schema.root, value, EncodeOptions::default())
892}
893
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{borrow::Cow, vec};

    /// Five-field flat record schema used by the round-trip fixture.
    fn flat_schema() -> Schema {
        let price = TypeDescriptor::DecimalFixed {
            precision: 18,
            scale: 4,
        };
        let fields = vec![
            Field::new(1, "id", TypeDescriptor::String { max_len: Some(64) }),
            Field::new(2, "price", price),
            Field::new(3, "tax", TypeDescriptor::Decimal),
            Field::new(4, "qty", TypeDescriptor::I32),
            Field::new(5, "ts", TypeDescriptor::I64),
        ];
        Schema::new(TypeDescriptor::Struct(fields))
    }

    /// Value matching [`flat_schema`], field by field.
    fn flat_value<'a>() -> TpackValue<'a> {
        let tax = Decimal {
            scale: 3,
            coefficient: 13_725,
        };
        TpackValue::Struct(vec![
            (1, TpackValue::String(Cow::Borrowed("prod_001"))),
            (2, TpackValue::DecimalFixed(2_999_900)),
            (3, TpackValue::Decimal(tax)),
            (4, TpackValue::I32(10)),
            (5, TpackValue::I64(1_715_000_000)),
        ])
    }

    /// Expected wire bytes for [`flat_value`] in a full-schema envelope.
    fn flat_example_bytes() -> Vec<u8> {
        vec![
            // magic "TPAK", version, envelope mode, schema length (40)
            0x54, 0x50, 0x41, 0x4B, 0x01, 0x00, 0x28,
            // struct tag, field count
            0x20, 0x05,
            // field 1: "id", string with max length 64
            0x01, 0x02, 0x69, 0x64, 0x00, 0x0E, 0x40,
            // field 2: "price", fixed decimal (18, 4)
            0x02, 0x05, 0x70, 0x72, 0x69, 0x63, 0x65, 0x00, 0x0D, 0x12, 0x04,
            // field 3: "tax", decimal
            0x03, 0x03, 0x74, 0x61, 0x78, 0x00, 0x0C,
            // field 4: "qty", i32
            0x04, 0x03, 0x71, 0x74, 0x79, 0x00, 0x04,
            // field 5: "ts", i64
            0x05, 0x02, 0x74, 0x73, 0x00, 0x05,
            // value: "prod_001"
            0x08, 0x70, 0x72, 0x6F, 0x64, 0x5F, 0x30, 0x30, 0x31,
            // value: price coefficient
            0xB8, 0x99, 0xEE, 0x02,
            // value: tax scale and coefficient
            0x06, 0xBA, 0xD6, 0x01,
            // value: qty
            0x00, 0x00, 0x00, 0x0A,
            // value: ts
            0x00, 0x00, 0x00, 0x00, 0x66, 0x38, 0xD2, 0xC0,
        ]
    }

    #[test]
    fn draft_flat_record_roundtrips_exactly() {
        let schema = flat_schema();
        let value = flat_value();

        let wire_bytes =
            encode_message(&schema, &value, EnvelopeMode::FullSchema, None).expect("encode");
        assert_eq!(wire_bytes, flat_example_bytes());

        let message = decode_message(&wire_bytes).expect("decode");
        assert_eq!(message.schema.as_ref(), &schema);
        assert_eq!(message.value, value);
    }

    #[test]
    fn canonical_rejects_overlong_varint() {
        // Re-encode the one-byte schema length 0x28 as the two-byte form 0xA8 0x00.
        let mut bytes = flat_example_bytes();
        bytes.insert(7, 0x00);
        bytes[6] = 0xA8;

        let strict = DecodeOptions {
            canonical: CanonicalMode::Strict,
            ..DecodeOptions::default()
        };
        let mut decoder = Decoder::with_options(&bytes, strict);
        let err = decoder.decode_message().unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::OverlongVarint));
    }

    #[test]
    fn rejects_duplicate_map_keys() {
        let schema = Schema::new(TypeDescriptor::Map {
            max_count: None,
            key: Box::new(TypeDescriptor::String { max_len: None }),
            value: Box::new(TypeDescriptor::I32),
        });
        let first = ValueMapEntry {
            key: TpackValue::String(Cow::Borrowed("a")),
            value: TpackValue::I32(1),
        };
        let second = ValueMapEntry {
            key: TpackValue::String(Cow::Borrowed("a")),
            value: TpackValue::I32(2),
        };
        let value = TpackValue::Map(vec![first, second]);

        let outcome = encode_message(&schema, &value, EnvelopeMode::FullSchema, None);
        assert!(outcome.is_err());
    }

    #[test]
    fn encode_schema_helper_rejects_oversized_serialized_schema() {
        let only_field = Field::new(1, "schema_name", TypeDescriptor::Null);
        let schema = Schema::new(TypeDescriptor::Struct(vec![only_field]));

        let schema_len = encode::schema(&schema, EncodeOptions::default())
            .expect("encode schema")
            .len();

        // Allow one byte less than the schema actually needs.
        let tight_limits = Limits {
            max_schema_len: schema_len - 1,
            ..Limits::default()
        };
        let options = EncodeOptions {
            limits: tight_limits,
            ..EncodeOptions::default()
        };

        let err = encode::schema(&schema, options).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::SchemaLengthExceeded));
    }
}
1017}