Skip to main content

tpack_core/
codec.rs

1use alloc::{borrow::Cow, boxed::Box, collections::BTreeSet, string::String, sync::Arc, vec::Vec};
2use core::cmp::Ordering;
3
4mod encode;
5mod validate;
6mod wire;
7
8use crate::{
9    CalendarInterval, Decimal, Duration, Envelope, EnvelopeMode, Error, ErrorKind, Field, Message,
10    Result, Schema, SchemaId, SchemaRegistry, TimestampPrecision, TpackValue, TypeDescriptor,
11    ValueMapEntry, Variant, empty_registry,
12};
13
14pub const MAGIC: [u8; 4] = *b"TPAK";
15pub const VERSION: u8 = 0x01;
16
17const NANOS_PER_DAY: u64 = 86_400_000_000_000;
18
/// Selects whether canonical-form checks are applied by the codec.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CanonicalMode {
    /// Canonical-form checks are disabled.
    Off,
    /// Canonical-form checks are enabled.
    Strict,
}

impl CanonicalMode {
    /// Returns `true` for [`CanonicalMode::Strict`] and `false` for [`CanonicalMode::Off`].
    pub fn is_strict(self) -> bool {
        // Exhaustive on purpose: adding a mode later forces a decision here.
        match self {
            CanonicalMode::Strict => true,
            CanonicalMode::Off => false,
        }
    }
}
30
/// Resource limits applied during schema validation and message encode/decode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Limits {
    /// Maximum encoded schema size in bytes.
    ///
    /// This limit is enforced symmetrically on decode and encode paths.
    pub max_schema_len: usize,
    /// Maximum length in bytes of a schema id read from an envelope.
    pub max_schema_id_len: usize,
    /// Maximum nesting depth accepted while decoding type descriptors and values.
    pub max_depth: usize,
    /// Maximum number of fields a decoded struct descriptor may declare.
    pub max_fields: usize,
    /// Maximum number of union variants or enum symbols a decoded descriptor may declare.
    pub max_variants: usize,
    /// Upper bound for list/map element counts.
    // NOTE(review): enforced inside `validate::validate_count`, which is not
    // visible from this file section — confirm exact semantics there.
    pub max_collection_len: usize,
    /// Maximum length in bytes of any decoded text component.
    pub max_string_len: usize,
    /// Maximum length in bytes of a decoded byte-string value.
    pub max_bytes_len: usize,
    /// Maximum length in bytes of an extension payload or extension schema parameters.
    pub max_extension_len: usize,
    /// Maximum number of bytes the decoder will read for a single varint.
    pub max_varint_bytes: usize,
}

impl Default for Limits {
    /// Returns the stock limits: 1 MiB schemas, 16 MiB payload components,
    /// depth 128, 16 384 fields/variants, one million collection entries,
    /// and ten-byte varints.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        const PAYLOAD_CAP: usize = 16 * MIB;
        const DECLARATION_CAP: usize = 16_384;

        Limits {
            max_schema_len: MIB,
            max_schema_id_len: KIB,
            max_depth: 128,
            max_fields: DECLARATION_CAP,
            max_variants: DECLARATION_CAP,
            max_collection_len: 1_000_000,
            max_string_len: PAYLOAD_CAP,
            max_bytes_len: PAYLOAD_CAP,
            max_extension_len: PAYLOAD_CAP,
            max_varint_bytes: 10,
        }
    }
}
65
66/// Decoder behavior switches and resource limits.
67#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub struct DecodeOptions {
69    pub canonical: CanonicalMode,
70    pub allow_schema_ref: bool,
71    /// Validate embedded schema bytes on `FullSchemaWithId` registry hits.
72    ///
73    /// When enabled, the decoder reparses the embedded schema block and
74    /// requires it to match the cached schema before reusing the cached AST.
75    /// Disable this only when the registry entry is already trusted and the
76    /// embedded schema bytes do not need to be checked.
77    pub validate_embedded_schema_on_cache_hit: bool,
78    pub limits: Limits,
79}
80
81impl Default for DecodeOptions {
82    fn default() -> Self {
83        Self {
84            canonical: CanonicalMode::Off,
85            allow_schema_ref: true,
86            validate_embedded_schema_on_cache_hit: true,
87            limits: Limits::default(),
88        }
89    }
90}
91
92/// Encoder behavior switches and resource limits.
93#[derive(Debug, Clone, Copy, PartialEq, Eq)]
94pub struct EncodeOptions {
95    pub canonical: CanonicalMode,
96    pub limits: Limits,
97}
98
99impl Default for EncodeOptions {
100    fn default() -> Self {
101        Self {
102            canonical: CanonicalMode::Off,
103            limits: Limits::default(),
104        }
105    }
106}
107
/// Cursor-style decoder over a borrowed input buffer.
///
/// Decoded strings and byte strings borrow from the `'de` input rather than
/// being copied.
pub struct Decoder<'de> {
    /// The complete input being decoded.
    input: &'de [u8],
    /// Offset of the next unread byte in `input`.
    pos: usize,
    /// Behavior switches and resource limits for this decoder.
    options: DecodeOptions,
}
113
114impl<'de> Decoder<'de> {
115    pub fn new(input: &'de [u8]) -> Self {
116        Self::with_options(input, DecodeOptions::default())
117    }
118
119    pub fn with_options(input: &'de [u8], options: DecodeOptions) -> Self {
120        Self {
121            input,
122            pos: 0,
123            options,
124        }
125    }
126
127    pub fn position(&self) -> usize {
128        self.pos
129    }
130
131    pub fn is_eof(&self) -> bool {
132        self.pos == self.input.len()
133    }
134
135    pub fn decode_message(&mut self) -> Result<Message<'de>> {
136        self.decode_message_with_registry(&empty_registry())
137    }
138
139    pub fn decode_message_with_registry<R: SchemaRegistry + ?Sized>(
140        &mut self,
141        registry: &R,
142    ) -> Result<Message<'de>> {
143        self.read_header()?;
144        let mode = match self.read_u8()? {
145            0x00 => EnvelopeMode::FullSchema,
146            0x01 => EnvelopeMode::FullSchemaWithId,
147            0x02 => EnvelopeMode::SchemaRef,
148            other => return Err(Error::new(ErrorKind::UnknownEnvelopeMode(other))),
149        };
150
151        let (schema_id, schema, used_cached_schema) = match mode {
152            EnvelopeMode::FullSchema => {
153                let schema = self.decode_schema_block()?;
154                (None, Arc::new(schema), false)
155            }
156            EnvelopeMode::FullSchemaWithId => {
157                let schema_id = self.read_schema_id(false)?;
158                let schema_len = self.read_len("schema length")?;
159                if schema_len > self.options.limits.max_schema_len {
160                    return Err(Error::new(ErrorKind::SchemaLengthExceeded));
161                }
162
163                let schema_start = self.pos;
164                let schema_end = schema_start
165                    .checked_add(schema_len)
166                    .ok_or(Error::new(ErrorKind::SchemaLengthExceeded))?;
167                if schema_end > self.input.len() {
168                    return Err(Error::new(ErrorKind::UnexpectedEof));
169                }
170
171                if let Some(schema) = registry.get(schema_id.as_bytes()) {
172                    if self.options.validate_embedded_schema_on_cache_hit {
173                        self.validate_cached_schema_bytes(schema_len, schema.as_ref())?;
174                    } else {
175                        // Cache hits can skip the embedded schema by byte length and
176                        // reuse the shared AST when callers explicitly trust the
177                        // registry entry for this schema id.
178                        self.pos = schema_end;
179                    }
180                    (Some(schema_id), schema, true)
181                } else {
182                    let schema = self.decode_schema_at_exact_len(schema_len)?;
183                    (Some(schema_id), Arc::new(schema), false)
184                }
185            }
186            EnvelopeMode::SchemaRef => {
187                if !self.options.allow_schema_ref {
188                    return Err(Error::new(ErrorKind::SchemaRefNotAllowed));
189                }
190                let schema_id = self.read_schema_id(true)?;
191                let schema = registry
192                    .get(schema_id.as_bytes())
193                    .ok_or(Error::new(ErrorKind::UnknownSchemaId))?;
194                (Some(schema_id), schema, true)
195            }
196        };
197
198        let value = self.decode_value_for(&schema.root, 0)?;
199        if !self.is_eof() {
200            return Err(Error::new(ErrorKind::TrailingBytes));
201        }
202
203        Ok(Message {
204            envelope: Envelope {
205                mode,
206                schema_id,
207                used_cached_schema,
208            },
209            schema,
210            value,
211        })
212    }
213
214    pub fn decode_schema(&mut self) -> Result<Schema> {
215        let schema = Schema::new(self.decode_type_descriptor(0)?);
216        validate::validate_schema(&schema, &self.options.limits)?;
217        Ok(schema)
218    }
219
220    pub fn decode_value(&mut self, schema: &Schema) -> Result<TpackValue<'de>> {
221        let value = self.decode_value_for(&schema.root, 0)?;
222        if !self.is_eof() {
223            return Err(Error::new(ErrorKind::TrailingBytes));
224        }
225        Ok(value)
226    }
227
228    fn read_header(&mut self) -> Result<()> {
229        if self.read_bytes(4)? != MAGIC {
230            return Err(Error::new(ErrorKind::InvalidMagic));
231        }
232        let version = self.read_u8()?;
233        if version != VERSION {
234            return Err(Error::new(ErrorKind::UnsupportedVersion(version)));
235        }
236        Ok(())
237    }
238
239    fn decode_schema_block(&mut self) -> Result<Schema> {
240        let schema_len = self.read_len("schema length")?;
241        if schema_len > self.options.limits.max_schema_len {
242            return Err(Error::new(ErrorKind::SchemaLengthExceeded));
243        }
244        self.decode_schema_at_exact_len(schema_len)
245    }
246
247    fn decode_schema_at_exact_len(&mut self, schema_len: usize) -> Result<Schema> {
248        let start = self.pos;
249        let schema = self.decode_schema()?;
250        let consumed = self.pos - start;
251        if consumed != schema_len {
252            return Err(Error::new(ErrorKind::SchemaLengthMismatch));
253        }
254        Ok(schema)
255    }
256
257    fn validate_cached_schema_bytes(
258        &mut self,
259        schema_len: usize,
260        cached_schema: &Schema,
261    ) -> Result<()> {
262        let embedded_schema = self.decode_schema_at_exact_len(schema_len)?;
263        if &embedded_schema != cached_schema {
264            return Err(Error::new(ErrorKind::EmbeddedSchemaMismatch));
265        }
266        Ok(())
267    }
268
269    fn read_schema_id(&mut self, require_non_empty: bool) -> Result<SchemaId<'de>> {
270        let len = self.read_len("schema id length")?;
271        if len > self.options.limits.max_schema_id_len {
272            return Err(Error::new(ErrorKind::InvalidSchemaId));
273        }
274        if require_non_empty && len == 0 {
275            return Err(Error::new(ErrorKind::InvalidSchemaId));
276        }
277        Ok(SchemaId::borrowed(self.read_bytes(len)?))
278    }
279
280    fn decode_type_descriptor(&mut self, depth: usize) -> Result<TypeDescriptor> {
281        if depth > self.options.limits.max_depth {
282            return Err(Error::limit("schema depth"));
283        }
284        let tag = self.read_u8()?;
285        let ty = match tag {
286            0x00 => TypeDescriptor::Null,
287            0x01 => TypeDescriptor::Bool,
288            0x02 => TypeDescriptor::I8,
289            0x03 => TypeDescriptor::I16,
290            0x04 => TypeDescriptor::I32,
291            0x05 => TypeDescriptor::I64,
292            0x06 => TypeDescriptor::U8,
293            0x07 => TypeDescriptor::U16,
294            0x08 => TypeDescriptor::U32,
295            0x09 => TypeDescriptor::U64,
296            0x0A => TypeDescriptor::F32,
297            0x0B => TypeDescriptor::F64,
298            0x0C => TypeDescriptor::Decimal,
299            0x0D => {
300                let precision = self.read_uvarint()?;
301                let scale = self.read_uvarint()?;
302                if precision == 0 || scale > precision {
303                    return Err(Error::new(ErrorKind::InvalidDecimalParameters));
304                }
305                TypeDescriptor::DecimalFixed { precision, scale }
306            }
307            0x0E => TypeDescriptor::String {
308                max_len: Some(self.read_uvarint()?),
309            },
310            0x0F => TypeDescriptor::String { max_len: None },
311            0x10 => TypeDescriptor::Bytes {
312                max_len: Some(self.read_uvarint()?),
313            },
314            0x11 => TypeDescriptor::Bytes { max_len: None },
315            0x12 => TypeDescriptor::Date,
316            0x13 => TypeDescriptor::Time,
317            0x14 => TypeDescriptor::DateTime,
318            0x15 => TypeDescriptor::DateTimeTz,
319            0x16 => {
320                let precision = match self.read_u8()? {
321                    0 => TimestampPrecision::Seconds,
322                    1 => TimestampPrecision::Milliseconds,
323                    2 => TimestampPrecision::Microseconds,
324                    3 => TimestampPrecision::Nanoseconds,
325                    other => return Err(Error::new(ErrorKind::InvalidTimestampPrecision(other))),
326                };
327                TypeDescriptor::Timestamp(precision)
328            }
329            0x17 => TypeDescriptor::Duration,
330            0x18 => TypeDescriptor::BigInt,
331            0x19 => TypeDescriptor::BigUInt,
332            0x1A => TypeDescriptor::CalendarInterval,
333            0x20 => {
334                let count = self.read_count("struct field count")?;
335                if count > self.options.limits.max_fields {
336                    return Err(Error::limit("struct field count"));
337                }
338                let mut fields = Vec::with_capacity(count);
339                let mut seen_ids = BTreeSet::new();
340                let mut seen_names = BTreeSet::new();
341                for _ in 0..count {
342                    let id = self.read_uvarint()?;
343                    if id == 0 {
344                        return Err(Error::new(ErrorKind::StructFieldIdZero));
345                    }
346                    let name = self.read_text_owned()?;
347                    if name.is_empty() {
348                        return Err(Error::new(ErrorKind::StructFieldNameEmpty));
349                    }
350                    let flags = self.read_uvarint()?;
351                    if flags != 0 {
352                        return Err(Error::new(ErrorKind::StructFieldFlagsNonZero(flags)));
353                    }
354                    let ty = self.decode_type_descriptor(depth + 1)?;
355                    if !seen_ids.insert(id) || !seen_names.insert(name.clone()) {
356                        return Err(Error::new(ErrorKind::DuplicateStructFieldDefinition));
357                    }
358                    fields.push(Field { id, name, ty });
359                }
360                TypeDescriptor::Struct(fields)
361            }
362            0x21 => {
363                let max_count = wire::max_count_from_wire(self.read_uvarint()?);
364                let element = Box::new(self.decode_type_descriptor(depth + 1)?);
365                TypeDescriptor::List { max_count, element }
366            }
367            0x22 => {
368                let max_count = wire::max_count_from_wire(self.read_uvarint()?);
369                let key = Box::new(self.decode_type_descriptor(depth + 1)?);
370                if !validate::is_valid_map_key_type(&key) {
371                    return Err(Error::new(ErrorKind::InvalidMapKeyType));
372                }
373                let value = Box::new(self.decode_type_descriptor(depth + 1)?);
374                TypeDescriptor::Map {
375                    max_count,
376                    key,
377                    value,
378                }
379            }
380            0x23 => {
381                let count = self.read_count("union variant count")?;
382                if count > self.options.limits.max_variants {
383                    return Err(Error::limit("union variant count"));
384                }
385                let mut variants = Vec::with_capacity(count);
386                let mut seen_names = BTreeSet::new();
387                for _ in 0..count {
388                    let name = self.read_text_owned()?;
389                    if name.is_empty() {
390                        return Err(Error::new(ErrorKind::UnionVariantNameEmpty));
391                    }
392                    if !seen_names.insert(name.clone()) {
393                        return Err(Error::new(ErrorKind::DuplicateUnionVariantName));
394                    }
395                    let ty = self.decode_type_descriptor(depth + 1)?;
396                    variants.push(Variant { name, ty });
397                }
398                TypeDescriptor::Union(variants)
399            }
400            0x24 => {
401                let count = self.read_count("enum symbol count")?;
402                if count > self.options.limits.max_variants {
403                    return Err(Error::limit("enum symbol count"));
404                }
405                let mut symbols = Vec::with_capacity(count);
406                let mut seen_symbols = BTreeSet::new();
407                for _ in 0..count {
408                    let symbol = self.read_text_owned()?;
409                    if symbol.is_empty() {
410                        return Err(Error::new(ErrorKind::EnumSymbolEmpty));
411                    }
412                    if !seen_symbols.insert(symbol.clone()) {
413                        return Err(Error::new(ErrorKind::DuplicateEnumSymbol));
414                    }
415                    symbols.push(symbol);
416                }
417                TypeDescriptor::Enum(symbols)
418            }
419            0x25 => {
420                let inner = Box::new(self.decode_type_descriptor(depth + 1)?);
421                TypeDescriptor::Optional(inner)
422            }
423            0x26 => {
424                let authority = self.read_text_owned()?;
425                let type_label = self.read_text_owned()?;
426                let schema_params = self.read_bytes_owned(self.options.limits.max_extension_len)?;
427                TypeDescriptor::Extension {
428                    authority,
429                    type_name: type_label,
430                    schema_params,
431                }
432            }
433            other => return Err(Error::new(ErrorKind::UnknownTypeTag(other))),
434        };
435        Ok(ty)
436    }
437
438    fn decode_value_for(&mut self, ty: &TypeDescriptor, depth: usize) -> Result<TpackValue<'de>> {
439        if depth > self.options.limits.max_depth {
440            return Err(Error::limit("value depth"));
441        }
442        let value = match ty {
443            TypeDescriptor::Null => TpackValue::Null,
444            TypeDescriptor::Bool => match self.read_u8()? {
445                0 => TpackValue::Bool(false),
446                1 => TpackValue::Bool(true),
447                _ => return Err(Error::invalid("invalid bool value")),
448            },
449            TypeDescriptor::I8 => TpackValue::I8(self.read_i8()?),
450            TypeDescriptor::I16 => TpackValue::I16(i16::from_be_bytes(self.read_array()?)),
451            TypeDescriptor::I32 => TpackValue::I32(i32::from_be_bytes(self.read_array()?)),
452            TypeDescriptor::I64 => TpackValue::I64(i64::from_be_bytes(self.read_array()?)),
453            TypeDescriptor::U8 => TpackValue::U8(self.read_u8()?),
454            TypeDescriptor::U16 => TpackValue::U16(u16::from_be_bytes(self.read_array()?)),
455            TypeDescriptor::U32 => TpackValue::U32(u32::from_be_bytes(self.read_array()?)),
456            TypeDescriptor::U64 => TpackValue::U64(u64::from_be_bytes(self.read_array()?)),
457            TypeDescriptor::F32 => {
458                let bits = u32::from_be_bytes(self.read_array()?);
459                if self.options.canonical.is_strict()
460                    && f32::from_bits(bits).is_nan()
461                    && bits != 0x7FC0_0000
462                {
463                    return Err(Error::invalid("non-canonical f32 NaN"));
464                }
465                TpackValue::F32(f32::from_bits(bits))
466            }
467            TypeDescriptor::F64 => {
468                let bits = u64::from_be_bytes(self.read_array()?);
469                if self.options.canonical.is_strict()
470                    && f64::from_bits(bits).is_nan()
471                    && bits != 0x7FF8_0000_0000_0000
472                {
473                    return Err(Error::invalid("non-canonical f64 NaN"));
474                }
475                TpackValue::F64(f64::from_bits(bits))
476            }
477            TypeDescriptor::Decimal => {
478                let scale = self.read_svarint()?;
479                let coefficient = self.read_svarint()?;
480                TpackValue::Decimal(Decimal { scale, coefficient })
481            }
482            TypeDescriptor::DecimalFixed { precision, .. } => {
483                let coefficient = self.read_svarint()?;
484                if validate::decimal_digits_abs(coefficient) > *precision {
485                    return Err(Error::invalid("Decimal(P,S) coefficient exceeds precision"));
486                }
487                TpackValue::DecimalFixed(coefficient)
488            }
489            TypeDescriptor::String { max_len } => {
490                let value = self.read_text_borrowed(*max_len)?;
491                TpackValue::String(Cow::Borrowed(value))
492            }
493            TypeDescriptor::Bytes { max_len } => {
494                let value = self.read_byte_component(*max_len)?;
495                TpackValue::Bytes(Cow::Borrowed(value))
496            }
497            TypeDescriptor::Date => TpackValue::Date(self.read_svarint()?),
498            TypeDescriptor::Time => {
499                let nanos = self.read_uvarint()?;
500                if nanos >= NANOS_PER_DAY {
501                    return Err(Error::invalid("time value exceeds nanos-per-day"));
502                }
503                TpackValue::Time(nanos)
504            }
505            TypeDescriptor::DateTime => {
506                let days = self.read_svarint()?;
507                let nanos = self.read_uvarint()?;
508                if nanos >= NANOS_PER_DAY {
509                    return Err(Error::invalid("datetime time value exceeds nanos-per-day"));
510                }
511                TpackValue::DateTime { days, nanos }
512            }
513            TypeDescriptor::DateTimeTz => {
514                let days = self.read_svarint()?;
515                let nanos = self.read_uvarint()?;
516                if nanos >= NANOS_PER_DAY {
517                    return Err(Error::invalid(
518                        "datetime-tz time value exceeds nanos-per-day",
519                    ));
520                }
521                let timezone = self.read_text_borrowed(None)?;
522                TpackValue::DateTimeTz {
523                    days,
524                    nanos,
525                    timezone: Cow::Borrowed(timezone),
526                }
527            }
528            TypeDescriptor::Timestamp(_) => TpackValue::Timestamp(self.read_svarint()?),
529            TypeDescriptor::Duration => {
530                let seconds = self.read_svarint()?;
531                let nanos = self.read_svarint()?;
532                validate::validate_duration(seconds, nanos)?;
533                TpackValue::Duration(Duration { seconds, nanos })
534            }
535            TypeDescriptor::BigInt => TpackValue::BigInt(self.read_svarint()?),
536            TypeDescriptor::BigUInt => TpackValue::BigUInt(self.read_uvarint()?),
537            TypeDescriptor::CalendarInterval => {
538                let months = self.read_svarint()?;
539                let days = self.read_svarint()?;
540                let nanos = self.read_svarint()?;
541                TpackValue::CalendarInterval(CalendarInterval {
542                    months,
543                    days,
544                    nanos,
545                })
546            }
547            TypeDescriptor::Struct(fields) => {
548                let mut values = Vec::with_capacity(fields.len());
549                for field in fields {
550                    let value = self
551                        .decode_value_for(&field.ty, depth + 1)
552                        .map_err(|err| err.at_field(field.name.clone()))?;
553                    values.push((field.id, value));
554                }
555                TpackValue::Struct(values)
556            }
557            TypeDescriptor::List { max_count, element } => {
558                let count = self.read_count("list count")?;
559                validate::validate_count("list count", count, *max_count, &self.options.limits)?;
560                let mut values = Vec::with_capacity(count);
561                for index in 0..count {
562                    let value = self
563                        .decode_value_for(element, depth + 1)
564                        .map_err(|err| err.at_index(index))?;
565                    values.push(value);
566                }
567                TpackValue::List(values)
568            }
569            TypeDescriptor::Map {
570                max_count,
571                key,
572                value,
573            } => {
574                let count = self.read_count("map count")?;
575                validate::validate_count("map count", count, *max_count, &self.options.limits)?;
576                let mut entries = Vec::with_capacity(count);
577                let mut seen_key_bytes = if self.options.canonical.is_strict() {
578                    None
579                } else {
580                    Some(BTreeSet::new())
581                };
582                let mut last_key_bytes: Option<&'de [u8]> = None;
583                for _ in 0..count {
584                    let key_start = self.pos;
585                    let key_value = self.decode_value_for(key, depth + 1)?;
586                    let raw_key_bytes = &self.input[key_start..self.pos];
587                    validate::reject_nan_map_key(&key_value)?;
588                    if self.options.canonical.is_strict() {
589                        // Strict canonical input means the bytes just consumed are
590                        // already the canonical key representation. Compare slices
591                        // directly instead of re-encoding every key into a Vec.
592                        if let Some(previous) = last_key_bytes {
593                            match previous.cmp(raw_key_bytes) {
594                                Ordering::Less => {}
595                                Ordering::Equal => {
596                                    return Err(Error::invalid("duplicate map key"));
597                                }
598                                Ordering::Greater => {
599                                    return Err(Error::invalid("non-canonical map key order"));
600                                }
601                            }
602                        }
603                        last_key_bytes = Some(raw_key_bytes);
604                    }
605                    if !self.options.canonical.is_strict() {
606                        let canonical_key = encode::value(
607                            key,
608                            &key_value,
609                            EncodeOptions {
610                                canonical: CanonicalMode::Strict,
611                                limits: self.options.limits,
612                            },
613                        )?;
614                        if !seen_key_bytes
615                            .as_mut()
616                            .expect("non-strict mode allocates a map-key set")
617                            .insert(canonical_key)
618                        {
619                            return Err(Error::invalid("duplicate map key"));
620                        }
621                    }
622                    let value = self.decode_value_for(value, depth + 1)?;
623                    entries.push(ValueMapEntry {
624                        key: key_value,
625                        value,
626                    });
627                }
628                TpackValue::Map(entries)
629            }
630            TypeDescriptor::Union(variants) => {
631                let index = self.read_uvarint()?;
632                let variant = variants
633                    .get(usize::try_from(index).map_err(|_| Error::limit("variant index"))?)
634                    .ok_or(Error::invalid("union variant index out of range"))?;
635                let value = self.decode_value_for(&variant.ty, depth + 1)?;
636                TpackValue::Union {
637                    index,
638                    value: Box::new(value),
639                }
640            }
641            TypeDescriptor::Enum(symbols) => {
642                let index = self.read_uvarint()?;
643                symbols
644                    .get(usize::try_from(index).map_err(|_| Error::limit("enum index"))?)
645                    .ok_or(Error::invalid("enum symbol index out of range"))?;
646                TpackValue::Enum(index)
647            }
648            TypeDescriptor::Optional(inner) => match self.read_u8()? {
649                0 => TpackValue::Optional(None),
650                1 => TpackValue::Optional(Some(Box::new(self.decode_value_for(inner, depth + 1)?))),
651                _ => return Err(Error::invalid("invalid optional presence marker")),
652            },
653            TypeDescriptor::Extension { .. } => {
654                let bytes = self.read_extension_component()?;
655                TpackValue::Extension(Cow::Borrowed(bytes))
656            }
657        };
658        Ok(value)
659    }
660
661    fn read_u8(&mut self) -> Result<u8> {
662        let byte = *self
663            .input
664            .get(self.pos)
665            .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
666        self.pos += 1;
667        Ok(byte)
668    }
669
670    fn read_i8(&mut self) -> Result<i8> {
671        Ok(i8::from_be_bytes([self.read_u8()?]))
672    }
673
674    fn read_array<const N: usize>(&mut self) -> Result<[u8; N]> {
675        let bytes = self.read_bytes(N)?;
676        let mut out = [0u8; N];
677        out.copy_from_slice(bytes);
678        Ok(out)
679    }
680
681    fn read_bytes(&mut self, len: usize) -> Result<&'de [u8]> {
682        let end = self
683            .pos
684            .checked_add(len)
685            .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
686        let bytes = self
687            .input
688            .get(self.pos..end)
689            .ok_or(Error::new(ErrorKind::UnexpectedEof))?;
690        self.pos = end;
691        Ok(bytes)
692    }
693
694    fn read_uvarint(&mut self) -> Result<u64> {
695        // The common case is a one-byte length/id/count. Keep it on a tiny
696        // predictable path and push overflow/canonical checks to the cold loop.
697        if let Some(&byte) = self.input.get(self.pos) {
698            if byte < 0x80 {
699                self.pos += 1;
700                return Ok(u64::from(byte));
701            }
702        }
703        self.read_uvarint_slow()
704    }
705
706    #[cold]
707    fn read_uvarint_slow(&mut self) -> Result<u64> {
708        let start = self.pos;
709        let mut value = 0u64;
710        for i in 0..self.options.limits.max_varint_bytes {
711            let byte = self.read_u8()?;
712            let payload = u64::from(byte & 0x7F);
713            if i == 9 && payload > 1 {
714                return Err(Error::new(ErrorKind::VarintOverflow));
715            }
716            value |= payload << (7 * i);
717            if byte & 0x80 == 0 {
718                let encoded_len = self.pos - start;
719                if self.options.canonical.is_strict() && encoded_len != wire::uvarint_len(value) {
720                    return Err(Error::new(ErrorKind::OverlongVarint));
721                }
722                return Ok(value);
723            }
724        }
725        Err(Error::new(ErrorKind::VarintOverflow))
726    }
727
728    fn read_svarint(&mut self) -> Result<i64> {
729        let raw = self.read_uvarint()?;
730        Ok(((raw >> 1) as i64) ^ (-((raw & 1) as i64)))
731    }
732
733    fn read_len(&mut self, name: &'static str) -> Result<usize> {
734        usize::try_from(self.read_uvarint()?).map_err(|_| Error::limit(name))
735    }
736
737    fn read_count(&mut self, name: &'static str) -> Result<usize> {
738        usize::try_from(self.read_uvarint()?).map_err(|_| Error::limit(name))
739    }
740
741    fn read_text_owned(&mut self) -> Result<String> {
742        Ok(String::from(self.read_text_borrowed(None)?))
743    }
744
745    fn read_text_borrowed(&mut self, schema_max: Option<u64>) -> Result<&'de str> {
746        let bytes = self.read_limited_component(
747            "string length",
748            schema_max,
749            self.options.limits.max_string_len,
750        )?;
751        Ok(core::str::from_utf8(bytes)?)
752    }
753
754    fn read_bytes_owned(&mut self, limit: usize) -> Result<Vec<u8>> {
755        Ok(self
756            .read_limited_component("byte string length", None, limit)?
757            .to_vec())
758    }
759
760    fn read_byte_component(&mut self, schema_max: Option<u64>) -> Result<&'de [u8]> {
761        self.read_limited_component(
762            "byte string length",
763            schema_max,
764            self.options.limits.max_bytes_len,
765        )
766    }
767
768    fn read_extension_component(&mut self) -> Result<&'de [u8]> {
769        self.read_limited_component(
770            "extension payload size",
771            None,
772            self.options.limits.max_extension_len,
773        )
774    }
775
776    fn read_limited_component(
777        &mut self,
778        limit_name: &'static str,
779        schema_max: Option<u64>,
780        max_len: usize,
781    ) -> Result<&'de [u8]> {
782        let len = self.read_len(limit_name)?;
783        let limit = schema_max
784            .and_then(|max| usize::try_from(max).ok())
785            .unwrap_or(max_len)
786            .min(max_len);
787        if len > limit {
788            return Err(Error::limit(limit_name));
789        }
790        self.read_bytes(len)
791    }
792}
793
794pub struct Encoder {
795    out: Vec<u8>,
796    options: EncodeOptions,
797}
798
799impl Encoder {
800    pub fn new() -> Self {
801        Self::with_options(EncodeOptions::default())
802    }
803
804    pub fn with_options(options: EncodeOptions) -> Self {
805        Self {
806            out: Vec::new(),
807            options,
808        }
809    }
810
811    pub fn into_vec(self) -> Vec<u8> {
812        self.out
813    }
814
815    pub fn encode_message(
816        &mut self,
817        schema: &Schema,
818        value: &TpackValue<'_>,
819        mode: EnvelopeMode,
820        schema_id: Option<&[u8]>,
821    ) -> Result<()> {
822        let schema_bytes = encode::schema(schema, self.options)?;
823        self.out.extend_from_slice(&MAGIC);
824        self.out.push(VERSION);
825        self.out.push(mode.tag());
826        match mode {
827            EnvelopeMode::FullSchema => {
828                wire::write_uvarint(&mut self.out, schema_bytes.len() as u64);
829                self.out.extend_from_slice(&schema_bytes);
830            }
831            EnvelopeMode::FullSchemaWithId => {
832                let schema_id = schema_id.unwrap_or(&[]);
833                if schema_id.len() > self.options.limits.max_schema_id_len {
834                    return Err(Error::new(ErrorKind::InvalidSchemaId));
835                }
836                wire::write_uvarint(&mut self.out, schema_id.len() as u64);
837                self.out.extend_from_slice(schema_id);
838                wire::write_uvarint(&mut self.out, schema_bytes.len() as u64);
839                self.out.extend_from_slice(&schema_bytes);
840            }
841            EnvelopeMode::SchemaRef => {
842                let schema_id = schema_id.ok_or(Error::new(ErrorKind::InvalidSchemaId))?;
843                if schema_id.is_empty() || schema_id.len() > self.options.limits.max_schema_id_len {
844                    return Err(Error::new(ErrorKind::InvalidSchemaId));
845                }
846                wire::write_uvarint(&mut self.out, schema_id.len() as u64);
847                self.out.extend_from_slice(schema_id);
848            }
849        }
850        encode::ValueEncoder::new(&mut self.out, self.options).write_value(&schema.root, value)?;
851        Ok(())
852    }
853
854    pub fn encode_schema(&mut self, schema: &Schema) -> Result<()> {
855        let schema_bytes = encode::schema(schema, self.options)?;
856        self.out.extend_from_slice(&schema_bytes);
857        Ok(())
858    }
859
860    pub fn encode_value(&mut self, schema: &Schema, value: &TpackValue<'_>) -> Result<()> {
861        encode::ValueEncoder::new(&mut self.out, self.options).write_value(&schema.root, value)
862    }
863}
864
865impl Default for Encoder {
866    fn default() -> Self {
867        Self::new()
868    }
869}
870
871pub fn decode_message(input: &[u8]) -> Result<Message<'_>> {
872    Decoder::new(input).decode_message()
873}
874
875pub fn encode_message(
876    schema: &Schema,
877    value: &TpackValue<'_>,
878    mode: EnvelopeMode,
879    schema_id: Option<&[u8]>,
880) -> Result<Vec<u8>> {
881    let mut encoder = Encoder::new();
882    encoder.encode_message(schema, value, mode, schema_id)?;
883    Ok(encoder.into_vec())
884}
885
886pub fn encode_schema(schema: &Schema) -> Result<Vec<u8>> {
887    encode::schema(schema, EncodeOptions::default())
888}
889
890pub fn encode_value(schema: &Schema, value: &TpackValue<'_>) -> Result<Vec<u8>> {
891    encode::value(&schema.root, value, EncodeOptions::default())
892}
893
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::{borrow::Cow, vec};

    /// Expected encoding of `flat_value()` under `flat_schema()` in
    /// `EnvelopeMode::FullSchema`. The stream opens with the `TPAK` magic,
    /// the version byte, the envelope mode tag and the schema length varint.
    const FLAT_EXAMPLE: &[u8] = &[
        0x54, 0x50, 0x41, 0x4B, 0x01, 0x00, 0x28, 0x20, 0x05, 0x01, 0x02, 0x69, 0x64, 0x00,
        0x0E, 0x40, 0x02, 0x05, 0x70, 0x72, 0x69, 0x63, 0x65, 0x00, 0x0D, 0x12, 0x04, 0x03,
        0x03, 0x74, 0x61, 0x78, 0x00, 0x0C, 0x04, 0x03, 0x71, 0x74, 0x79, 0x00, 0x04, 0x05,
        0x02, 0x74, 0x73, 0x00, 0x05, 0x08, 0x70, 0x72, 0x6F, 0x64, 0x5F, 0x30, 0x30, 0x31,
        0xB8, 0x99, 0xEE, 0x02, 0x06, 0xBA, 0xD6, 0x01, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00,
        0x00, 0x00, 0x66, 0x38, 0xD2, 0xC0,
    ];

    /// Five-field struct schema shared by the round-trip and canonical tests.
    fn flat_schema() -> Schema {
        let fields = vec![
            Field::new(1, "id", TypeDescriptor::String { max_len: Some(64) }),
            Field::new(
                2,
                "price",
                TypeDescriptor::DecimalFixed {
                    precision: 18,
                    scale: 4,
                },
            ),
            Field::new(3, "tax", TypeDescriptor::Decimal),
            Field::new(4, "qty", TypeDescriptor::I32),
            Field::new(5, "ts", TypeDescriptor::I64),
        ];
        Schema::new(TypeDescriptor::Struct(fields))
    }

    /// Value matching `flat_schema()`, one entry per field id.
    fn flat_value<'a>() -> TpackValue<'a> {
        let tax = Decimal {
            scale: 3,
            coefficient: 13_725,
        };
        let fields = vec![
            (1, TpackValue::String(Cow::Borrowed("prod_001"))),
            (2, TpackValue::DecimalFixed(2_999_900)),
            (3, TpackValue::Decimal(tax)),
            (4, TpackValue::I32(10)),
            (5, TpackValue::I64(1_715_000_000)),
        ];
        TpackValue::Struct(fields)
    }

    /// Owned copy of the reference encoding, so tests can mutate it freely.
    fn flat_example_bytes() -> Vec<u8> {
        FLAT_EXAMPLE.to_vec()
    }

    /// Encoding must match the reference bytes exactly, and decoding those
    /// bytes must yield the original schema and value.
    #[test]
    fn draft_flat_record_roundtrips_exactly() {
        let schema = flat_schema();
        let value = flat_value();

        let encoded =
            encode_message(&schema, &value, EnvelopeMode::FullSchema, None).expect("encode");
        assert_eq!(encoded, flat_example_bytes());

        let decoded = decode_message(&encoded).expect("decode");
        assert_eq!(decoded.schema.as_ref(), &schema);
        assert_eq!(decoded.value, value);
    }

    /// Strict canonical decoding must reject a schema length that is padded
    /// to two varint bytes.
    #[test]
    fn canonical_rejects_overlong_varint() {
        let mut bytes = flat_example_bytes();
        // Re-encode the one-byte schema length 0x28 as the overlong pair
        // 0xA8 0x00: same value, one byte longer than necessary.
        bytes[6] = 0xA8;
        bytes.insert(7, 0x00);

        let options = DecodeOptions {
            canonical: CanonicalMode::Strict,
            ..DecodeOptions::default()
        };
        let mut decoder = Decoder::with_options(&bytes, options);
        let err = decoder.decode_message().unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::OverlongVarint));
    }

    /// Encoding a map value that repeats a key must fail.
    #[test]
    fn rejects_duplicate_map_keys() {
        let schema = Schema::new(TypeDescriptor::Map {
            max_count: None,
            key: Box::new(TypeDescriptor::String { max_len: None }),
            value: Box::new(TypeDescriptor::I32),
        });
        // Both entries share the key "a" and differ only in their value.
        let entry = |number| ValueMapEntry {
            key: TpackValue::String(Cow::Borrowed("a")),
            value: TpackValue::I32(number),
        };
        let value = TpackValue::Map(vec![entry(1), entry(2)]);

        let outcome = encode_message(&schema, &value, EnvelopeMode::FullSchema, None);
        assert!(outcome.is_err());
    }

    /// A schema whose serialized form is one byte over `max_schema_len` must
    /// be rejected with `SchemaLengthExceeded`.
    #[test]
    fn encode_schema_helper_rejects_oversized_serialized_schema() {
        let schema = Schema::new(TypeDescriptor::Struct(vec![Field::new(
            1,
            "schema_name",
            TypeDescriptor::Null,
        )]));
        let schema_len = encode::schema(&schema, EncodeOptions::default())
            .expect("encode schema")
            .len();

        // Allow exactly one byte less than the schema actually needs.
        let limits = Limits {
            max_schema_len: schema_len - 1,
            ..Limits::default()
        };
        let options = EncodeOptions {
            limits,
            ..EncodeOptions::default()
        };

        let err = encode::schema(&schema, options).unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::SchemaLengthExceeded));
    }
}