Skip to main content

spacetimedb_sats/de/
impls.rs

1use super::{
2    BasicSmallVecVisitor, BasicVecVisitor, Deserialize, DeserializeSeed, Deserializer, Error, FieldNameVisitor,
3    ProductKind, ProductVisitor, SeqProductAccess, SliceVisitor, SumAccess, SumVisitor, VariantAccess, VariantVisitor,
4};
5use crate::{
6    de::{array_validate, array_visit, ArrayAccess, ArrayVisitor, BasicArrayVisitor, GrowingVec},
7    AlgebraicType, AlgebraicValue, ArrayType, ArrayValue, ProductType, ProductTypeElement, ProductValue, SumType,
8    SumValue, WithTypespace, F32, F64,
9};
10use crate::{i256, u256};
11use core::{iter, marker::PhantomData, ops::Bound};
12use lean_string::LeanString;
13use smallvec::SmallVec;
14use spacetimedb_primitives::{ColId, ColList};
15use std::{borrow::Cow, rc::Rc, sync::Arc};
16
/// Implements [`Deserialize`] for a type in a simplified manner.
///
/// The macro takes, in order:
/// - a bracketed list of generic parameters (the `'de` lifetime is always prepended),
/// - an optional `where [...]` list of extra bounds, emitted as the `where` clause
///   of the generated `impl`,
/// - the type to implement [`Deserialize`] for,
/// - `de => body`: the body of `deserialize`, with `de` bound to the `Deserializer<'de>`,
/// - optionally `, de => body`: the body of `validate`, with `de` bound the same way.
///   When omitted, the trait's default `validate` is used.
///
/// An example:
/// ```ignore
/// impl_deserialize!(
/// //     Type parameters  Optional where  Impl type
/// //            v               v             v
/// //   ----------------  --------------- ----------
///     [T: Deserialize<'de>] where [T: Copy] std::rc::Rc<T>,
/// //  The `deserialize` implementation where `de` is the `Deserializer<'de>`
/// //  and the expression right of `=>` is the body of `deserialize`.
///     de => T::deserialize(de).map(std::rc::Rc::new)
/// );
/// ```
#[macro_export]
macro_rules! impl_deserialize {
    (
        [$($generics:tt)*] $(where [$($wc:tt)*])? $typ:ty,
        $de:ident => $body:expr
        $(, $validate_de:ident => $validate:expr)?
    ) => {
        // Forward the optional `where [...]` bounds onto the generated impl so it is
        // only as general as the caller asked for.
        impl<'de, $($generics)*> $crate::de::Deserialize<'de> for $typ
        $(where $($wc)*)?
        {
            fn deserialize<D: $crate::de::Deserializer<'de>>($de: D) -> Result<Self, D::Error> { $body }
            $(
                fn validate<D: $crate::de::Deserializer<'de>>($validate_de: D) -> Result<(), D::Error> { $validate }
            )?
        }
    };
}
46
47/// Implements [`Deserialize`] for a primitive type.
48///
49/// The `$method` is a parameterless method on `deserializer` to call.
50macro_rules! impl_prim {
51    ($(($prim:ty, $method:ident))*) => {
52        $(impl_deserialize!([] $prim, de => de.$method());)*
53    };
54}
55
56impl_prim! {
57    (bool, deserialize_bool)
58    /*(u8, deserialize_u8)*/ (u16, deserialize_u16) (u32, deserialize_u32) (u64, deserialize_u64) (u128, deserialize_u128) (u256, deserialize_u256)
59    (i8, deserialize_i8)     (i16, deserialize_i16) (i32, deserialize_i32) (i64, deserialize_i64) (i128, deserialize_i128) (i256, deserialize_i256)
60    (f32, deserialize_f32) (f64, deserialize_f64)
61}
62
63struct TupleVisitor<A>(PhantomData<A>);
64#[derive(Copy, Clone)]
65struct TupleNameVisitorMax(usize);
66
67impl FieldNameVisitor<'_> for TupleNameVisitorMax {
68    // The index of the field name.
69    type Output = usize;
70
71    fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
72        iter::repeat_n(None, self.0)
73    }
74
75    fn kind(&self) -> ProductKind {
76        ProductKind::Normal
77    }
78
79    fn visit<E: Error>(self, name: &str) -> Result<Self::Output, E> {
80        let err = || Error::unknown_field_name(name, &self);
81        // Convert `name` to an index.
82        let Ok(index) = name.parse() else {
83            return Err(err());
84        };
85        // Confirm that the index exists or error.
86        if index < self.0 {
87            Ok(index)
88        } else {
89            Err(err())
90        }
91    }
92
93    fn visit_seq(self, index: usize) -> Self::Output {
94        // Assert that the index exists.
95        assert!(index < self.0);
96        index
97    }
98}
99
/// Implements [`ProductVisitor`] for `TupleVisitor<(T0, ..., Tn)>` and
/// [`Deserialize`] for the tuple `(T0, ..., Tn)` itself.
///
/// Each argument is `TypeParam => index`, where `index` is the literal position
/// of that element in the tuple. The same identifier is reused as the local
/// variable holding that element while it is being read.
macro_rules! impl_deserialize_tuple {
    ($($ty_name:ident => $const_val:literal),*) => {
        impl<'de, $($ty_name: Deserialize<'de>),*> ProductVisitor<'de> for TupleVisitor<($($ty_name,)*)> {
            type Output = ($($ty_name,)*);
            fn product_name(&self) -> Option<&str> { None }
            fn product_len(&self) -> usize { crate::count!($($ty_name)*) }
            fn visit_seq_product<A: SeqProductAccess<'de>>(self, mut _prod: A) -> Result<Self::Output, A::Error> {
                // Read the elements in order. If element `$const_val` is missing,
                // report a product length of `$const_val`.
                $(
                    #[allow(non_snake_case)]
                    let $ty_name = _prod
                        .next_element()?
                        .ok_or_else(|| Error::invalid_product_length($const_val, &self))?;
                )*

                Ok(($($ty_name,)*))
            }
            fn validate_seq_product<A: SeqProductAccess<'de>>(self, mut _prod: A) -> Result<(), A::Error> {
                // Same as `visit_seq_product`, but each element is only validated.
                $(
                    #[allow(non_snake_case)]
                    _prod
                        .validate_next_element_seed(PhantomData::<$ty_name>)?
                        .ok_or_else(|| Error::invalid_product_length($const_val, &self))?;
                )*

                Ok(())
            }
            fn visit_named_product<A: super::NamedProductAccess<'de>>(self, mut prod: A) -> Result<Self::Output, A::Error> {
                // One slot per element, filled as fields arrive in any order.
                $(
                    #[allow(non_snake_case)]
                    let mut $ty_name = None;
                )*

                let visit = TupleNameVisitorMax(self.product_len());
                while let Some(index) = prod.get_field_ident(visit)? {
                    match index {
                        $($const_val => {
                            if $ty_name.is_some() {
                                return Err(A::Error::duplicate_field($const_val, None, &self))
                            }
                            $ty_name = Some(prod.get_field_value()?);
                        })*
                        // `TupleNameVisitorMax` rejects indices >= `product_len()`, so this arm
                        // should not be hit for non-empty tuples; it is the only arm for `()`.
                        index => return Err(Error::invalid_product_length(index, &self)),
                    }
                }
                // Every slot must have been filled.
                Ok(($(
                    $ty_name.ok_or_else(|| A::Error::missing_field($const_val, None, &self))?,
                )*))
            }
            fn validate_named_product<A: super::NamedProductAccess<'de>>(self, mut prod: A) -> Result<(), A::Error> {
                // One "seen" flag per element.
                $(
                    #[allow(non_snake_case)]
                    let mut $ty_name = false;
                )*

                let visit = TupleNameVisitorMax(self.product_len());
                while let Some(index) = prod.get_field_ident(visit)? {
                    match index {
                        $($const_val => {
                            if $ty_name {
                                return Err(A::Error::duplicate_field($const_val, None, &self))
                            }
                            prod.validate_field_value::<$ty_name>()?;
                            $ty_name = true;
                        })*
                        index => return Err(Error::invalid_product_length(index, &self)),
                    }
                }

                $(
                    if !$ty_name {
                        return Err(A::Error::missing_field($const_val, None, &self))
                    }
                )*

                Ok(())
            }
        }

        // Validation goes through `validate_product`, so the `validate_*_product`
        // methods above are actually used instead of a full deserialize.
        impl_deserialize!([$($ty_name: Deserialize<'de>),*] ($($ty_name,)*), de => {
            de.deserialize_product(TupleVisitor::<($($ty_name,)*)>(PhantomData))
        }, de => {
            de.validate_product(TupleVisitor::<($($ty_name,)*)>(PhantomData))
        });
    };
}
183
// `Deserialize` for tuples of arity 0 (the unit type `()`) through 12.
impl_deserialize_tuple!();
impl_deserialize_tuple!(T0 => 0);
impl_deserialize_tuple!(T0 => 0, T1 => 1);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6, T7 => 7);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6, T7 => 7, T8 => 8);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6, T7 => 7, T8 => 8, T9 => 9);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6, T7 => 7, T8 => 8, T9 => 9, T10 => 10);
impl_deserialize_tuple!(T0 => 0, T1 => 1, T2 => 2, T3 => 3, T4 => 4, T5 => 5, T6 => 6, T7 => 7, T8 => 8, T9 => 9, T10 => 10, T11 => 11);
197
198impl<'de> Deserialize<'de> for u8 {
199    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
200        deserializer.deserialize_u8()
201    }
202
203    // Specialize `Vec<u8>` deserialization.
204    // This is more likely to compile down to a `memcpy`.
205    fn __deserialize_vec<D: Deserializer<'de>>(deserializer: D) -> Result<Vec<Self>, D::Error> {
206        deserializer.deserialize_bytes(OwnedSliceVisitor)
207    }
208
209    fn __deserialize_array<D: Deserializer<'de>, const N: usize>(deserializer: D) -> Result<[Self; N], D::Error> {
210        deserializer.deserialize_bytes(ByteArrayVisitor)
211    }
212}
213
214impl_deserialize!([] F32, de => f32::deserialize(de).map(Into::into));
215impl_deserialize!([] F64, de => f64::deserialize(de).map(Into::into));
216impl_deserialize!(
217    [] String,
218    de => de.deserialize_str(OwnedSliceVisitor),
219    de => <&str>::validate(de)
220);
221impl_deserialize!(
222    [] LeanString,
223    de => <Cow<'_, str>>::deserialize(de).map(|s| (&*s).into()),
224    de => <&str>::validate(de)
225);
226impl_deserialize!(
227    [T: Deserialize<'de>] Vec<T>,
228    de => T::__deserialize_vec(de),
229    de => de.validate_array_seed(BasicVecVisitor, PhantomData::<T>)
230);
231impl_deserialize!(
232    [T: Deserialize<'de>, const N: usize] SmallVec<[T; N]>,
233    de => de.deserialize_array(BasicSmallVecVisitor),
234    de => de.validate_array_seed(BasicVecVisitor, PhantomData::<T>)
235);
236impl_deserialize!(
237    [T: Deserialize<'de>, const N: usize] [T; N],
238    de => T::__deserialize_array(de),
239    de => de.validate_array_seed(BasicArrayVisitor::<N>, PhantomData::<T>)
240);
241impl_deserialize!(
242    [] Box<str>,
243    de => String::deserialize(de).map(|s| s.into_boxed_str()),
244    de => String::validate(de)
245);
246impl_deserialize!(
247    [T: Deserialize<'de>] Box<[T]>,
248    de => Vec::deserialize(de).map(|s| s.into_boxed_slice()),
249    de => Vec::<T>::validate(de)
250);
251impl_deserialize!(
252    [T: Deserialize<'de>] Rc<[T]>,
253    de => Vec::deserialize(de).map(|s| s.into()),
254    de => Vec::<T>::validate(de)
255);
256impl_deserialize!(
257    [T: Deserialize<'de>] Arc<[T]>,
258    de => Vec::deserialize(de).map(|s| s.into()),
259    de => Vec::<T>::validate(de)
260);
261
262/// The visitor merely valiates the slice.
263struct ValidatingSliceVisitor;
264
265impl<T: ToOwned + ?Sized> SliceVisitor<'_, T> for ValidatingSliceVisitor {
266    type Output = ();
267
268    fn visit<E: Error>(self, _: &T) -> Result<Self::Output, E> {
269        Ok(())
270    }
271}
272
273/// The visitor converts the slice to its owned version.
274struct OwnedSliceVisitor;
275
276impl<T: ToOwned + ?Sized> SliceVisitor<'_, T> for OwnedSliceVisitor {
277    type Output = T::Owned;
278
279    fn visit<E: Error>(self, slice: &T) -> Result<Self::Output, E> {
280        Ok(slice.to_owned())
281    }
282
283    fn visit_owned<E: Error>(self, buf: T::Owned) -> Result<Self::Output, E> {
284        Ok(buf)
285    }
286}
287
288/// The visitor will convert the byte slice to `[u8; N]`.
289///
290/// When `slice.len() != N` an error will be raised.
291struct ByteArrayVisitor<const N: usize>;
292
293impl<const N: usize> SliceVisitor<'_, [u8]> for ByteArrayVisitor<N> {
294    type Output = [u8; N];
295
296    fn visit<E: Error>(self, slice: &[u8]) -> Result<Self::Output, E> {
297        slice.try_into().map_err(|_| {
298            Error::custom(if slice.len() > N {
299                "too many elements for array"
300            } else {
301                "too few elements for array"
302            })
303        })
304    }
305}
306
307impl_deserialize!(
308    [] &'de str,
309    de => de.deserialize_str_slice(),
310    de => de.deserialize_str(ValidatingSliceVisitor)
311);
312impl_deserialize!(
313    [] &'de [u8],
314    de => de.deserialize_bytes(BorrowedSliceVisitor),
315    de => de.deserialize_bytes(ValidatingSliceVisitor)
316);
317
318/// The visitor returns the slice as-is and borrowed.
319pub(crate) struct BorrowedSliceVisitor;
320
321impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for BorrowedSliceVisitor {
322    type Output = &'de T;
323
324    fn visit<E: Error>(self, _: &T) -> Result<Self::Output, E> {
325        Err(E::custom("expected *borrowed* slice"))
326    }
327
328    fn visit_borrowed<E: Error>(self, borrowed_slice: &'de T) -> Result<Self::Output, E> {
329        Ok(borrowed_slice)
330    }
331}
332
333impl_deserialize!(
334    [] Cow<'de, str>,
335    de => de.deserialize_str(CowSliceVisitor),
336    de => <&str>::validate(de)
337);
338impl_deserialize!(
339    [] Cow<'de, [u8]>,
340    de => de.deserialize_bytes(CowSliceVisitor),
341    de => <&[u8]>::validate(de)
342);
343
344/// The visitor works with either owned or borrowed versions to produce `Cow<'de, T>`.
345struct CowSliceVisitor;
346
347impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for CowSliceVisitor {
348    type Output = Cow<'de, T>;
349
350    fn visit<E: Error>(self, slice: &T) -> Result<Self::Output, E> {
351        self.visit_owned(slice.to_owned())
352    }
353
354    fn visit_owned<E: Error>(self, buf: <T as ToOwned>::Owned) -> Result<Self::Output, E> {
355        Ok(Cow::Owned(buf))
356    }
357
358    fn visit_borrowed<E: Error>(self, borrowed_slice: &'de T) -> Result<Self::Output, E> {
359        Ok(Cow::Borrowed(borrowed_slice))
360    }
361}
362
363impl_deserialize!(
364    [T: Deserialize<'de>] Box<T>,
365    de => T::deserialize(de).map(Box::new),
366    de => T::validate(de)
367);
368impl_deserialize!([T: Deserialize<'de>] Option<T>, de => de.deserialize_sum(OptionVisitor(PhantomData)));
369
370/// The visitor deserializes an `Option<T>`.
371struct OptionVisitor<T>(PhantomData<T>);
372
373impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor<T> {
374    type Output = Option<T>;
375
376    fn sum_name(&self) -> Option<&str> {
377        Some("option")
378    }
379
380    fn is_option(&self) -> bool {
381        true
382    }
383
384    fn visit_sum<A: SumAccess<'de>>(self, data: A) -> Result<Self::Output, A::Error> {
385        // Determine the variant.
386        let (some, data) = data.variant(self)?;
387
388        // Deserialize contents for it.
389        Ok(if some {
390            Some(data.deserialize()?)
391        } else {
392            data.deserialize::<()>()?;
393            None
394        })
395    }
396
397    fn validate_sum<A: SumAccess<'de>>(self, data: A) -> Result<(), A::Error> {
398        // Determine the variant.
399        let (some, data) = data.variant(self)?;
400
401        // Validate contents for it.
402        if some {
403            data.validate::<T>()
404        } else {
405            data.validate::<()>()
406        }
407    }
408}
409
410impl<'de, T: Deserialize<'de>> VariantVisitor<'de> for OptionVisitor<T> {
411    type Output = bool;
412
413    fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
414        ["some", "none"].into_iter()
415    }
416
417    fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
418        match tag {
419            0 => Ok(true),
420            1 => Ok(false),
421            _ => Err(E::unknown_variant_tag(tag, &self)),
422        }
423    }
424
425    fn visit_name<E: Error>(self, name: &str) -> Result<Self::Output, E> {
426        match name {
427            "some" => Ok(true),
428            "none" => Ok(false),
429            _ => Err(E::unknown_variant_name(name, &self)),
430        }
431    }
432}
433
434impl_deserialize!([T: Deserialize<'de>, E: Deserialize<'de>] Result<T, E>, de =>
435    de.deserialize_sum(ResultVisitor(PhantomData))
436);
437
438/// Visitor to deserialize a `Result<T, E>`.
439struct ResultVisitor<T, E>(PhantomData<(T, E)>);
440
441/// Variant determined by the [`VariantVisitor`] for `Result<T, E>`.
442enum ResultVariant {
443    Ok,
444    Err,
445}
446
447impl<'de, T: Deserialize<'de>, E: Deserialize<'de>> SumVisitor<'de> for ResultVisitor<T, E> {
448    type Output = Result<T, E>;
449
450    fn sum_name(&self) -> Option<&str> {
451        Some("result")
452    }
453
454    fn is_option(&self) -> bool {
455        false
456    }
457
458    fn visit_sum<A: SumAccess<'de>>(self, data: A) -> Result<Self::Output, A::Error> {
459        let (variant, data) = data.variant(self)?;
460        Ok(match variant {
461            ResultVariant::Ok => Ok(data.deserialize()?),
462            ResultVariant::Err => Err(data.deserialize()?),
463        })
464    }
465
466    fn validate_sum<A: SumAccess<'de>>(self, data: A) -> Result<(), A::Error> {
467        let (variant, data) = data.variant(self)?;
468        match variant {
469            ResultVariant::Ok => data.validate::<T>(),
470            ResultVariant::Err => data.validate::<E>(),
471        }
472    }
473}
474
475impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor<'de> for ResultVisitor<T, U> {
476    type Output = ResultVariant;
477
478    fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
479        ["ok", "err"].into_iter()
480    }
481
482    fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
483        match tag {
484            0 => Ok(ResultVariant::Ok),
485            1 => Ok(ResultVariant::Err),
486            _ => Err(E::unknown_variant_tag(tag, &self)),
487        }
488    }
489
490    fn visit_name<E: Error>(self, name: &str) -> Result<Self::Output, E> {
491        match name {
492            "ok" => Ok(ResultVariant::Ok),
493            "err" => Ok(ResultVariant::Err),
494            _ => Err(E::unknown_variant_name(name, &self)),
495        }
496    }
497}
498
499impl_deserialize!([T: Deserialize<'de>] Bound<T>, de => WithBound(PhantomData).deserialize(de));
500
501/// The visitor deserializes a `Bound<T>`.
502#[derive(Clone, Copy)]
503pub struct WithBound<S>(pub S);
504
505impl<'de, S: Copy + DeserializeSeed<'de>> DeserializeSeed<'de> for WithBound<S> {
506    type Output = Bound<S::Output>;
507
508    fn deserialize<D: Deserializer<'de>>(self, de: D) -> Result<Self::Output, D::Error> {
509        de.deserialize_sum(BoundVisitor(self.0))
510    }
511}
512
513/// The visitor deserializes a `Bound<T>`.
514struct BoundVisitor<S>(S);
515
516/// Variant determined by the [`BoundVisitor`] for `Bound<T>`.
517enum BoundVariant {
518    Included,
519    Excluded,
520    Unbounded,
521}
522
523impl<'de, S: Copy + DeserializeSeed<'de>> SumVisitor<'de> for BoundVisitor<S> {
524    type Output = Bound<S::Output>;
525
526    fn sum_name(&self) -> Option<&str> {
527        Some("bound")
528    }
529
530    fn visit_sum<A: SumAccess<'de>>(self, data: A) -> Result<Self::Output, A::Error> {
531        // Determine the variant.
532        let this = self.0;
533        let (variant, data) = data.variant(self)?;
534
535        // Deserialize contents for it.
536        match variant {
537            BoundVariant::Included => data.deserialize_seed(this).map(Bound::Included),
538            BoundVariant::Excluded => data.deserialize_seed(this).map(Bound::Excluded),
539            BoundVariant::Unbounded => data.deserialize::<()>().map(|_| Bound::Unbounded),
540        }
541    }
542
543    fn validate_sum<A: SumAccess<'de>>(self, data: A) -> Result<(), A::Error> {
544        // Determine the variant.
545        let this = self.0;
546        let (variant, data) = data.variant(self)?;
547
548        // Validate contents for it.
549        match variant {
550            BoundVariant::Included | BoundVariant::Excluded => data.validate_seed(this),
551            BoundVariant::Unbounded => data.validate::<()>(),
552        }
553    }
554}
555
556impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor<'de> for BoundVisitor<T> {
557    type Output = BoundVariant;
558
559    fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
560        ["included", "excluded", "unbounded"].into_iter()
561    }
562
563    fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
564        match tag {
565            0 => Ok(BoundVariant::Included),
566            1 => Ok(BoundVariant::Excluded),
567            // if this ever changes, edit crates/bindings/src/table.rs
568            2 => Ok(BoundVariant::Unbounded),
569            _ => Err(E::unknown_variant_tag(tag, &self)),
570        }
571    }
572
573    fn visit_name<E: Error>(self, name: &str) -> Result<Self::Output, E> {
574        match name {
575            "included" => Ok(BoundVariant::Included),
576            "excluded" => Ok(BoundVariant::Excluded),
577            "unbounded" => Ok(BoundVariant::Unbounded),
578            _ => Err(E::unknown_variant_name(name, &self)),
579        }
580    }
581}
582
583impl<'de> DeserializeSeed<'de> for WithTypespace<'_, AlgebraicType> {
584    type Output = AlgebraicValue;
585
586    fn deserialize<D: Deserializer<'de>>(self, de: D) -> Result<Self::Output, D::Error> {
587        match self.ty() {
588            AlgebraicType::Ref(r) => self.resolve(*r).deserialize(de),
589            AlgebraicType::Sum(sum) => self.with(sum).deserialize(de).map(Into::into),
590            AlgebraicType::Product(prod) => self.with(prod).deserialize(de).map(Into::into),
591            AlgebraicType::Array(ty) => self.with(ty).deserialize(de).map(Into::into),
592            AlgebraicType::Bool => bool::deserialize(de).map(Into::into),
593            AlgebraicType::I8 => i8::deserialize(de).map(Into::into),
594            AlgebraicType::U8 => u8::deserialize(de).map(Into::into),
595            AlgebraicType::I16 => i16::deserialize(de).map(Into::into),
596            AlgebraicType::U16 => u16::deserialize(de).map(Into::into),
597            AlgebraicType::I32 => i32::deserialize(de).map(Into::into),
598            AlgebraicType::U32 => u32::deserialize(de).map(Into::into),
599            AlgebraicType::I64 => i64::deserialize(de).map(Into::into),
600            AlgebraicType::U64 => u64::deserialize(de).map(Into::into),
601            AlgebraicType::I128 => i128::deserialize(de).map(Into::into),
602            AlgebraicType::U128 => u128::deserialize(de).map(Into::into),
603            AlgebraicType::I256 => i256::deserialize(de).map(Into::into),
604            AlgebraicType::U256 => u256::deserialize(de).map(Into::into),
605            AlgebraicType::F32 => f32::deserialize(de).map(Into::into),
606            AlgebraicType::F64 => f64::deserialize(de).map(Into::into),
607            AlgebraicType::String => <Box<str>>::deserialize(de).map(Into::into),
608        }
609    }
610
611    fn validate<D: Deserializer<'de>>(self, de: D) -> Result<(), D::Error> {
612        match self.ty() {
613            AlgebraicType::Ref(r) => self.resolve(*r).validate(de),
614            AlgebraicType::Sum(sum) => self.with(sum).validate(de),
615            AlgebraicType::Product(prod) => self.with(prod).validate(de),
616            AlgebraicType::Array(ty) => self.with(ty).validate(de),
617            AlgebraicType::Bool => bool::validate(de),
618            AlgebraicType::I8 => i8::validate(de),
619            AlgebraicType::U8 => u8::validate(de),
620            AlgebraicType::I16 => i16::validate(de),
621            AlgebraicType::U16 => u16::validate(de),
622            AlgebraicType::I32 => i32::validate(de),
623            AlgebraicType::U32 => u32::validate(de),
624            AlgebraicType::I64 => i64::validate(de),
625            AlgebraicType::U64 => u64::validate(de),
626            AlgebraicType::I128 => i128::validate(de),
627            AlgebraicType::U128 => u128::validate(de),
628            AlgebraicType::I256 => i256::validate(de),
629            AlgebraicType::U256 => u256::validate(de),
630            AlgebraicType::F32 => f32::validate(de),
631            AlgebraicType::F64 => f64::validate(de),
632            AlgebraicType::String => <&str>::validate(de),
633        }
634    }
635}
636
637impl<'de> DeserializeSeed<'de> for WithTypespace<'_, SumType> {
638    type Output = SumValue;
639
640    fn deserialize<D: Deserializer<'de>>(self, deserializer: D) -> Result<Self::Output, D::Error> {
641        deserializer.deserialize_sum(self)
642    }
643
644    fn validate<D: Deserializer<'de>>(self, deserializer: D) -> Result<(), D::Error> {
645        deserializer.validate_sum(self)
646    }
647}
648
649impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> {
650    type Output = SumValue;
651
652    fn sum_name(&self) -> Option<&str> {
653        None
654    }
655
656    fn is_option(&self) -> bool {
657        self.ty().as_option().is_some()
658    }
659
660    fn visit_sum<A: SumAccess<'de>>(self, data: A) -> Result<Self::Output, A::Error> {
661        let (tag, data) = data.variant(self)?;
662        // Find the variant type by `tag`.
663        let variant_ty = self.map(|ty| &ty.variants[tag as usize].algebraic_type);
664
665        let value = Box::new(data.deserialize_seed(variant_ty)?);
666        Ok(SumValue { tag, value })
667    }
668
669    fn validate_sum<A: SumAccess<'de>>(self, data: A) -> Result<(), A::Error> {
670        let (tag, data) = data.variant(self)?;
671        // Find the variant type by `tag`.
672        let variant_ty = self.map(|ty| &ty.variants[tag as usize].algebraic_type);
673
674        data.validate_seed(variant_ty)
675    }
676}
677
678impl VariantVisitor<'_> for WithTypespace<'_, SumType> {
679    type Output = u8;
680
681    fn variant_names(&self) -> impl '_ + Iterator<Item = &str> {
682        // Provide the names known from the `SumType`.
683        self.ty().variants.iter().filter_map(|v| v.name().map(|n| &**n))
684    }
685
686    fn visit_tag<E: Error>(self, tag: u8) -> Result<Self::Output, E> {
687        // Verify that tag identifies a valid variant in `SumType`.
688        self.ty()
689            .variants
690            .get(tag as usize)
691            .ok_or_else(|| E::unknown_variant_tag(tag, &self))?;
692
693        Ok(tag)
694    }
695
696    fn visit_name<E: Error>(self, name: &str) -> Result<Self::Output, E> {
697        // Translate the variant `name` to its tag.
698        self.ty()
699            .variants
700            .iter()
701            .position(|var| var.has_name(name))
702            .map(|pos| pos as u8)
703            .ok_or_else(|| E::unknown_variant_name(name, &self))
704    }
705}
706
707impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ProductType> {
708    type Output = ProductValue;
709
710    fn deserialize<D: Deserializer<'de>>(self, deserializer: D) -> Result<Self::Output, D::Error> {
711        deserializer.deserialize_product(self.map(|pt| &*pt.elements))
712    }
713
714    fn validate<D: Deserializer<'de>>(self, deserializer: D) -> Result<(), D::Error> {
715        deserializer.validate_product(self.map(|pt| &*pt.elements))
716    }
717}
718
719impl<'de> DeserializeSeed<'de> for WithTypespace<'_, [ProductTypeElement]> {
720    type Output = ProductValue;
721
722    fn deserialize<D: Deserializer<'de>>(self, deserializer: D) -> Result<Self::Output, D::Error> {
723        deserializer.deserialize_product(self)
724    }
725
726    fn validate<D: Deserializer<'de>>(self, deserializer: D) -> Result<(), D::Error> {
727        deserializer.validate_product(self)
728    }
729}
730
731impl<'de> ProductVisitor<'de> for WithTypespace<'_, [ProductTypeElement]> {
732    type Output = ProductValue;
733
734    fn product_name(&self) -> Option<&str> {
735        None
736    }
737    fn product_len(&self) -> usize {
738        self.ty().len()
739    }
740
741    fn visit_seq_product<A: SeqProductAccess<'de>>(self, tup: A) -> Result<Self::Output, A::Error> {
742        visit_seq_product(self, &self, tup)
743    }
744
745    fn validate_seq_product<A: SeqProductAccess<'de>>(self, prod: A) -> Result<(), A::Error> {
746        validate_seq_product(self, &self, prod)
747    }
748
749    fn visit_named_product<A: super::NamedProductAccess<'de>>(self, tup: A) -> Result<Self::Output, A::Error> {
750        visit_named_product(self, &self, tup)
751    }
752
753    fn validate_named_product<A: super::NamedProductAccess<'de>>(self, prod: A) -> Result<(), A::Error> {
754        validate_named_product(self, &self, prod)
755    }
756}
757
758impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ArrayType> {
759    type Output = ArrayValue;
760
761    fn deserialize<D: Deserializer<'de>>(self, deserializer: D) -> Result<Self::Output, D::Error> {
762        /// Deserialize a vector and `map` it to the appropriate `ArrayValue` variant.
763        fn de_array<'de, D: Deserializer<'de>, T: Deserialize<'de>>(
764            de: D,
765            map: impl FnOnce(Box<[T]>) -> ArrayValue,
766        ) -> Result<ArrayValue, D::Error> {
767            de.deserialize_array(BasicVecVisitor).map(<Box<[_]>>::from).map(map)
768        }
769
770        let mut ty = &*self.ty().elem_ty;
771
772        // Loop, resolving `Ref`s, until we reach a non-`Ref` type.
773        loop {
774            break match ty {
775                AlgebraicType::Ref(r) => {
776                    // The only arm that will loop.
777                    ty = self.resolve(*r).ty();
778                    continue;
779                }
780                AlgebraicType::Sum(ty) => deserializer
781                    .deserialize_array_seed(BasicVecVisitor, self.with(ty))
782                    .map(<Box<[_]>>::from)
783                    .map(ArrayValue::Sum),
784                AlgebraicType::Product(ty) => deserializer
785                    .deserialize_array_seed(BasicVecVisitor, self.with(ty))
786                    .map(<Box<[_]>>::from)
787                    .map(ArrayValue::Product),
788                AlgebraicType::Array(ty) => deserializer
789                    .deserialize_array_seed(BasicVecVisitor, self.with(ty))
790                    .map(<Box<[_]>>::from)
791                    .map(ArrayValue::Array),
792                &AlgebraicType::Bool => de_array(deserializer, ArrayValue::Bool),
793                &AlgebraicType::I8 => de_array(deserializer, ArrayValue::I8),
794                &AlgebraicType::U8 => deserializer
795                    .deserialize_bytes(OwnedSliceVisitor)
796                    .map(<Box<[_]>>::from)
797                    .map(ArrayValue::U8),
798                &AlgebraicType::I16 => de_array(deserializer, ArrayValue::I16),
799                &AlgebraicType::U16 => de_array(deserializer, ArrayValue::U16),
800                &AlgebraicType::I32 => de_array(deserializer, ArrayValue::I32),
801                &AlgebraicType::U32 => de_array(deserializer, ArrayValue::U32),
802                &AlgebraicType::I64 => de_array(deserializer, ArrayValue::I64),
803                &AlgebraicType::U64 => de_array(deserializer, ArrayValue::U64),
804                &AlgebraicType::I128 => de_array(deserializer, ArrayValue::I128),
805                &AlgebraicType::U128 => de_array(deserializer, ArrayValue::U128),
806                &AlgebraicType::I256 => de_array(deserializer, ArrayValue::I256),
807                &AlgebraicType::U256 => de_array(deserializer, ArrayValue::U256),
808                &AlgebraicType::F32 => de_array(deserializer, ArrayValue::F32),
809                &AlgebraicType::F64 => de_array(deserializer, ArrayValue::F64),
810                &AlgebraicType::String => de_array(deserializer, ArrayValue::String),
811            };
812        }
813    }
814
815    fn validate<D: Deserializer<'de>>(self, deserializer: D) -> Result<(), D::Error> {
816        /// Validate a vector for the appropriate `ArrayValue` variant.
817        fn val_array<'de, D: Deserializer<'de>, T: Deserialize<'de>>(de: D) -> Result<(), D::Error> {
818            de.validate_array_seed(BasicVecVisitor, PhantomData::<T>)
819        }
820
821        let mut ty = &*self.ty().elem_ty;
822
823        // Loop, resolving `Ref`s, until we reach a non-`Ref` type.
824        loop {
825            break match ty {
826                AlgebraicType::Ref(r) => {
827                    // The only arm that will loop.
828                    ty = self.resolve(*r).ty();
829                    continue;
830                }
831                AlgebraicType::Sum(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)),
832                AlgebraicType::Product(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)),
833                AlgebraicType::Array(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)),
834                &AlgebraicType::Bool => val_array::<_, bool>(deserializer),
835                &AlgebraicType::I8 => val_array::<_, i8>(deserializer),
836                &AlgebraicType::U8 => val_array::<_, u8>(deserializer),
837                &AlgebraicType::I16 => val_array::<_, i16>(deserializer),
838                &AlgebraicType::U16 => val_array::<_, u16>(deserializer),
839                &AlgebraicType::I32 => val_array::<_, i32>(deserializer),
840                &AlgebraicType::U32 => val_array::<_, u32>(deserializer),
841                &AlgebraicType::I64 => val_array::<_, i64>(deserializer),
842                &AlgebraicType::U64 => val_array::<_, u64>(deserializer),
843                &AlgebraicType::I128 => val_array::<_, i128>(deserializer),
844                &AlgebraicType::U128 => val_array::<_, u128>(deserializer),
845                &AlgebraicType::I256 => val_array::<_, i256>(deserializer),
846                &AlgebraicType::U256 => val_array::<_, u256>(deserializer),
847                &AlgebraicType::F32 => val_array::<_, f32>(deserializer),
848                &AlgebraicType::F64 => val_array::<_, f64>(deserializer),
849                &AlgebraicType::String => val_array::<_, String>(deserializer),
850            };
851        }
852    }
853}
854
855/// Deserialize, provided the fields' types, a product value with unnamed fields.
856pub fn visit_seq_product<'de, A: SeqProductAccess<'de>>(
857    elems: WithTypespace<[ProductTypeElement]>,
858    visitor: &impl ProductVisitor<'de>,
859    mut tup: A,
860) -> Result<ProductValue, A::Error> {
861    let elements = elems.ty().iter().enumerate().map(|(i, el)| {
862        tup.next_element_seed(elems.with(&el.algebraic_type))?
863            .ok_or_else(|| Error::invalid_product_length(i, visitor))
864    });
865    let elements = elements.collect::<Result<_, _>>()?;
866    Ok(ProductValue { elements })
867}
868
869/// Validate, provided the fields' types, a product value with unnamed fields.
870pub fn validate_seq_product<'de, A: SeqProductAccess<'de>>(
871    elems: WithTypespace<[ProductTypeElement]>,
872    visitor: &impl ProductVisitor<'de>,
873    mut tup: A,
874) -> Result<(), A::Error> {
875    for (i, el) in elems.ty().iter().enumerate() {
876        tup.validate_next_element_seed(elems.with(&el.algebraic_type))?
877            .ok_or_else(|| Error::invalid_product_length(i, visitor))?;
878    }
879    Ok(())
880}
881
882/// Deserialize, provided the fields' types, a product value with named fields.
883pub fn visit_named_product<'de, A: super::NamedProductAccess<'de>>(
884    elems_tys: WithTypespace<[ProductTypeElement]>,
885    visitor: &impl ProductVisitor<'de>,
886    mut tup: A,
887) -> Result<ProductValue, A::Error> {
888    let elems = elems_tys.ty();
889    let mut elements = vec![None; elems.len()];
890    let kind = visitor.product_kind();
891
892    // Deserialize a product value corresponding to each product type field.
893    // This is worst case quadratic in complexity
894    // as fields can be specified out of order (value side) compared to `elems` (type side).
895    for _ in 0..elems.len() {
896        // Deserialize a field name, match against the element types.
897        let index = tup.get_field_ident(TupleNameVisitor { elems, kind })?.ok_or_else(|| {
898            // Couldn't deserialize a field name.
899            // Find the first field name we haven't filled an element for.
900            let missing = elements.iter().position(|field| field.is_none()).unwrap();
901            let field_name = elems[missing].name().map(|n| &**n);
902            Error::missing_field(missing, field_name, visitor)
903        })?;
904
905        let element = &elems[index];
906
907        // By index we can select which element to deserialize a value for.
908        let slot = &mut elements[index];
909        if slot.is_some() {
910            return Err(Error::duplicate_field(index, element.name().map(|n| &**n), visitor));
911        }
912
913        // Deserialize the value for this field's type.
914        *slot = Some(tup.get_field_value_seed(elems_tys.with(&element.algebraic_type))?);
915    }
916
917    // Get rid of the `Option<_>` layer.
918    let elements = elements
919        .into_iter()
920        // We reached here, so we know nothing was missing, i.e., `None`.
921        .map(|x| x.unwrap_or_else(|| unreachable!("visit_named_product")))
922        .collect();
923
924    Ok(ProductValue { elements })
925}
926
927/// Validate, provided the fields' types, a product value with named fields.
928pub fn validate_named_product<'de, A: super::NamedProductAccess<'de>>(
929    elems_tys: WithTypespace<[ProductTypeElement]>,
930    visitor: &impl ProductVisitor<'de>,
931    mut tup: A,
932) -> Result<(), A::Error> {
933    let elems = elems_tys.ty();
934    // TODO(perf): replace with bitset.
935    let mut elements = vec![false; elems.len()];
936    let kind = visitor.product_kind();
937
938    // Deserialize a product value corresponding to each product type field.
939    // This is worst case quadratic in complexity
940    // as fields can be specified out of order (value side) compared to `elems` (type side).
941    for _ in 0..elems.len() {
942        // Deserialize a field name, match against the element types.
943        let index = tup.get_field_ident(TupleNameVisitor { elems, kind })?.ok_or_else(|| {
944            // Couldn't deserialize a field name.
945            // Find the first field name we haven't filled an element for.
946            let missing = elements.iter().position(|&field| !field).unwrap();
947            let field_name = elems[missing].name().map(|n| &**n);
948            Error::missing_field(missing, field_name, visitor)
949        })?;
950
951        let element = &elems[index];
952
953        // By index we can select which element to deserialize a value for.
954        let slot = &mut elements[index];
955        if *slot {
956            return Err(Error::duplicate_field(index, element.name().map(|n| &**n), visitor));
957        }
958
959        // Deserialize the value for this field's type.
960        tup.validate_field_value_seed(elems_tys.with(&element.algebraic_type))?;
961        *slot = true;
962    }
963
964    Ok(())
965}
966
967/// A visitor for extracting indices of field names in the elements of a [`ProductType`].
968struct TupleNameVisitor<'a> {
969    /// The elements of a product type, in order.
970    elems: &'a [ProductTypeElement],
971    /// The kind of product this is.
972    kind: ProductKind,
973}
974
975impl FieldNameVisitor<'_> for TupleNameVisitor<'_> {
976    // The index of the field name.
977    type Output = usize;
978
979    fn field_names(&self) -> impl '_ + Iterator<Item = Option<&str>> {
980        self.elems.iter().map(|f| f.name().map(|n| &**n))
981    }
982
983    fn kind(&self) -> ProductKind {
984        self.kind
985    }
986
987    fn visit<E: Error>(self, name: &str) -> Result<Self::Output, E> {
988        // Finds the index of a field with `name`.
989        self.elems
990            .iter()
991            .position(|f| f.has_name(name))
992            .ok_or_else(|| Error::unknown_field_name(name, &self))
993    }
994
995    fn visit_seq(self, index: usize) -> Self::Output {
996        // Confirm that the index exists.
997        self.elems
998            .get(index)
999            .expect("`index` should exist when `visit_seq` is called");
1000
1001        index
1002    }
1003}
1004
1005impl_deserialize!([] spacetimedb_primitives::ArgId, de => u64::deserialize(de).map(Self));
1006impl_deserialize!([] spacetimedb_primitives::TableId, de => u32::deserialize(de).map(Self));
1007impl_deserialize!([] spacetimedb_primitives::ViewId, de => u32::deserialize(de).map(Self));
1008impl_deserialize!([] spacetimedb_primitives::SequenceId, de => u32::deserialize(de).map(Self));
1009impl_deserialize!([] spacetimedb_primitives::IndexId, de => u32::deserialize(de).map(Self));
1010impl_deserialize!([] spacetimedb_primitives::ConstraintId, de => u32::deserialize(de).map(Self));
1011impl_deserialize!([] spacetimedb_primitives::ColId, de => u16::deserialize(de).map(Self));
1012impl_deserialize!([] spacetimedb_primitives::ScheduleId, de => u32::deserialize(de).map(Self));
1013
1014impl GrowingVec<ColId> for ColList {
1015    fn try_with_capacity<E: Error>(cap: usize) -> Result<Self, E> {
1016        Ok(Self::with_capacity(cap as u16))
1017    }
1018    fn push(&mut self, elem: ColId) {
1019        self.push(elem);
1020    }
1021}
1022impl_deserialize!([] spacetimedb_primitives::ColList, de => {
1023    struct ColListVisitor;
1024    impl<'de> ArrayVisitor<'de, ColId> for ColListVisitor {
1025        type Output = ColList;
1026
1027        fn visit<A: ArrayAccess<'de, Element = ColId>>(self, vec: A) -> Result<Self::Output, A::Error> {
1028            array_visit(vec)
1029        }
1030
1031        fn validate<A: ArrayAccess<'de, Element = ColId>>(self, vec: A) -> Result<(), A::Error> {
1032            array_validate(vec)
1033        }
1034    }
1035    de.deserialize_array(ColListVisitor)
1036});
// `ColSet` is read as a `ColList` and then converted with `Into`.
// Validation likewise delegates to `ColList::validate`.
impl_deserialize!(
    [] spacetimedb_primitives::ColSet,
    de => ColList::deserialize(de).map(Into::into),
    de => ColList::validate(de)
);

// A BLAKE3 hash is read as a fixed-size array of `blake3::OUT_LEN` bytes,
// then wrapped with `blake3::Hash::from_bytes`.
#[cfg(feature = "blake3")]
impl_deserialize!([] blake3::Hash, de => <[u8; blake3::OUT_LEN]>::deserialize(de).map(blake3::Hash::from_bytes));

// `Bytes` is deserialized through an owned `Vec<u8>` and converted with `Into`,
// so the data is copied out of the input. Validation goes through `<&[u8]>::validate`.
// NOTE(review): confirm `<&[u8]>::validate` accepts exactly the inputs that
// `<Vec<u8>>::deserialize` does for every deserializer.
// TODO(perf): integrate Bytes with Deserializer to reduce copying
impl_deserialize!(
    [] bytes::Bytes,
    de => <Vec<u8>>::deserialize(de).map(Into::into),
    de => <&[u8]>::validate(de)
);

// `ByteString` is deserialized through an owned `String` and converted with `Into`.
// Validation goes through `<&str>::validate`.
#[cfg(feature = "bytestring")]
impl_deserialize!(
    [] bytestring::ByteString,
    de => <String>::deserialize(de).map(Into::into),
    de => <&str>::validate(de)
);
1059
#[cfg(test)]
mod test {
    use crate::{
        algebraic_value::{de::ValueDeserializer, ser::value_serialize},
        bsatn,
        serde::SerdeWrapper,
        Deserialize, Serialize,
    };
    use core::fmt::Debug;

    /// Asserts that `original` survives three round trips unchanged:
    /// through BSATN bytes, through an in-memory `AlgebraicValue`, and through
    /// JSON via `SerdeWrapper`.
    fn assert_roundtrips<T>(original: T)
    where
        T: Serialize + for<'de> Deserialize<'de> + Eq + Debug,
    {
        // BSATN bytes.
        let encoded = bsatn::to_vec(&original).unwrap();
        let from_bsatn: T = bsatn::from_slice(&encoded).unwrap();
        assert_eq!(original, from_bsatn);

        // `AlgebraicValue`.
        let value = value_serialize(&original);
        let from_value = T::deserialize(ValueDeserializer::new(value)).unwrap();
        assert_eq!(original, from_value);

        // JSON text via serde.
        let json = serde_json::to_string(SerdeWrapper::from_ref(&original)).unwrap();
        let SerdeWrapper(from_json) = serde_json::from_str::<SerdeWrapper<T>>(&json).unwrap();
        assert_eq!(original, from_json);
    }

    #[test]
    fn roundtrip_tuples_in_different_data_formats() {
        assert_roundtrips(());
        assert_roundtrips((true,));
        assert_roundtrips((1337u64, false));
        assert_roundtrips(((7331u64, false), 42u32, 24u8));
    }
}