spacetimedb_sats/algebraic_value/de.rs

1use crate::array_value::{ArrayValueIntoIter, ArrayValueIterCloned};
2use crate::{de, AlgebraicValue, SumValue};
3use crate::{i256, u256};
4use derive_more::From;
5
6/// An implementation of [`Deserializer`](de::Deserializer)
7/// where the input of deserialization is an `AlgebraicValue`.
8#[repr(transparent)]
9#[derive(From)]
10pub struct ValueDeserializer {
11    /// The value to deserialize to some `T`.
12    val: AlgebraicValue,
13}
14
15impl ValueDeserializer {
16    /// Returns a `ValueDeserializer` with `val` as the input for deserialization.
17    pub fn new(val: AlgebraicValue) -> Self {
18        Self { val }
19    }
20
21    /// Converts `&AlgebraicValue` to `&ValueDeserializer`.
22    pub fn from_ref(val: &AlgebraicValue) -> &Self {
23        // SAFETY: The conversion is OK due to `repr(transparent)`.
24        unsafe { &*(val as *const AlgebraicValue as *const ValueDeserializer) }
25    }
26}
27
/// The ways that deserializing from an `AlgebraicValue` can fail.
#[derive(Debug)]
pub enum ValueDeserializeError {
    /// The value's actual type did not match the type being deserialized.
    MismatchedType,
    /// A free-form error message, such as one built by [`de::Error::custom`].
    Custom(String),
}
36
37impl de::Error for ValueDeserializeError {
38    fn custom(msg: impl std::fmt::Display) -> Self {
39        Self::Custom(msg.to_string())
40    }
41}
42
43/// Turns any error into `ValueDeserializeError::MismatchedType`.
44fn map_err<T, E>(res: Result<T, E>) -> Result<T, ValueDeserializeError> {
45    res.map_err(|_| ValueDeserializeError::MismatchedType)
46}
47
48/// Turns any option into `ValueDeserializeError::MismatchedType`.
49fn ok_or<T>(res: Option<T>) -> Result<T, ValueDeserializeError> {
50    res.ok_or(ValueDeserializeError::MismatchedType)
51}
52
53impl<'de> de::Deserializer<'de> for ValueDeserializer {
54    type Error = ValueDeserializeError;
55
56    fn deserialize_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
57        let vals = map_err(self.val.into_product())?.into_iter();
58        visitor.visit_seq_product(ProductAccess { vals })
59    }
60
61    fn deserialize_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
62        let sum = map_err(self.val.into_sum())?;
63        visitor.visit_sum(SumAccess { sum })
64    }
65
66    fn deserialize_bool(self) -> Result<bool, Self::Error> {
67        map_err(self.val.into_bool())
68    }
69
70    fn deserialize_u8(self) -> Result<u8, Self::Error> {
71        map_err(self.val.into_u8())
72    }
73
74    fn deserialize_u16(self) -> Result<u16, Self::Error> {
75        map_err(self.val.into_u16())
76    }
77
78    fn deserialize_u32(self) -> Result<u32, Self::Error> {
79        map_err(self.val.into_u32())
80    }
81
82    fn deserialize_u64(self) -> Result<u64, Self::Error> {
83        map_err(self.val.into_u64())
84    }
85
86    fn deserialize_u128(self) -> Result<u128, Self::Error> {
87        map_err(self.val.into_u128().map(|x| x.0))
88    }
89
90    fn deserialize_u256(self) -> Result<u256, Self::Error> {
91        map_err(self.val.into_u256().map(|x| *x))
92    }
93
94    fn deserialize_i8(self) -> Result<i8, Self::Error> {
95        map_err(self.val.into_i8())
96    }
97
98    fn deserialize_i16(self) -> Result<i16, Self::Error> {
99        map_err(self.val.into_i16())
100    }
101
102    fn deserialize_i32(self) -> Result<i32, Self::Error> {
103        map_err(self.val.into_i32())
104    }
105
106    fn deserialize_i64(self) -> Result<i64, Self::Error> {
107        map_err(self.val.into_i64())
108    }
109
110    fn deserialize_i128(self) -> Result<i128, Self::Error> {
111        map_err(self.val.into_i128().map(|x| x.0))
112    }
113
114    fn deserialize_i256(self) -> Result<i256, Self::Error> {
115        map_err(self.val.into_i256().map(|x| *x))
116    }
117
118    fn deserialize_f32(self) -> Result<f32, Self::Error> {
119        map_err(self.val.into_f32().map(f32::from))
120    }
121
122    fn deserialize_f64(self) -> Result<f64, Self::Error> {
123        map_err(self.val.into_f64().map(f64::from))
124    }
125
126    fn deserialize_str<V: de::SliceVisitor<'de, str>>(self, visitor: V) -> Result<V::Output, Self::Error> {
127        visitor.visit_owned(map_err(self.val.into_string().map(Into::into))?)
128    }
129
130    fn deserialize_bytes<V: de::SliceVisitor<'de, [u8]>>(self, visitor: V) -> Result<V::Output, Self::Error> {
131        visitor.visit_owned(map_err(self.val.into_bytes().map(Vec::from))?)
132    }
133
134    fn deserialize_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
135        self,
136        visitor: V,
137        seed: T,
138    ) -> Result<V::Output, Self::Error> {
139        let iter = map_err(self.val.into_array())?.into_iter();
140        visitor.visit(ArrayAccess { iter, seed })
141    }
142}
143
144/// Defines deserialization for [`ValueDeserializer`] where product elements are in the input.
145struct ProductAccess {
146    /// The element values of the product as an iterator of owned values.
147    vals: std::vec::IntoIter<AlgebraicValue>,
148}
149
150impl<'de> de::SeqProductAccess<'de> for ProductAccess {
151    type Error = ValueDeserializeError;
152
153    fn next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Output>, Self::Error> {
154        self.vals
155            .next()
156            .map(|val| seed.deserialize(ValueDeserializer { val }))
157            .transpose()
158    }
159}
160
161/// Defines deserialization for [`ValueDeserializer`] where a sum value is in the input.
162#[repr(transparent)]
163struct SumAccess {
164    /// The input sum value to deserialize.
165    sum: SumValue,
166}
167
168impl SumAccess {
169    /// Converts `&SumValue` to `&SumAccess`.
170    fn from_ref(sum: &SumValue) -> &Self {
171        // SAFETY: `repr(transparent)` allows this.
172        unsafe { &*(sum as *const SumValue as *const SumAccess) }
173    }
174}
175
176impl<'de> de::SumAccess<'de> for SumAccess {
177    type Error = ValueDeserializeError;
178
179    type Variant = ValueDeserializer;
180
181    fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
182        let tag = visitor.visit_tag(self.sum.tag)?;
183        let val = *self.sum.value;
184        Ok((tag, ValueDeserializer { val }))
185    }
186}
187
188impl<'de> de::VariantAccess<'de> for ValueDeserializer {
189    type Error = ValueDeserializeError;
190
191    fn deserialize_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
192        seed.deserialize(self)
193    }
194}
195
196/// Defines deserialization for [`ValueDeserializer`] where an array value is in the input.
197struct ArrayAccess<T> {
198    /// The elements of the array as an iterator of owned elements.
199    iter: ArrayValueIntoIter,
200    /// A seed value provided by the caller of
201    /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed).
202    seed: T,
203}
204
205impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for ArrayAccess<T> {
206    type Element = T::Output;
207    type Error = ValueDeserializeError;
208
209    fn next_element(&mut self) -> Result<Option<Self::Element>, Self::Error> {
210        self.iter
211            .next()
212            .map(|val| self.seed.clone().deserialize(ValueDeserializer { val }))
213            .transpose()
214    }
215}
216
217impl<'de> de::Deserializer<'de> for &'de ValueDeserializer {
218    type Error = ValueDeserializeError;
219
220    fn deserialize_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
221        let vals = ok_or(self.val.as_product())?.elements.iter();
222        visitor.visit_seq_product(RefProductAccess { vals })
223    }
224
225    fn deserialize_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
226        let sum = ok_or(self.val.as_sum())?;
227        visitor.visit_sum(SumAccess::from_ref(sum))
228    }
229
230    fn deserialize_bool(self) -> Result<bool, Self::Error> {
231        ok_or(self.val.as_bool().copied())
232    }
233    fn deserialize_u8(self) -> Result<u8, Self::Error> {
234        ok_or(self.val.as_u8().copied())
235    }
236    fn deserialize_u16(self) -> Result<u16, Self::Error> {
237        ok_or(self.val.as_u16().copied())
238    }
239    fn deserialize_u32(self) -> Result<u32, Self::Error> {
240        ok_or(self.val.as_u32().copied())
241    }
242    fn deserialize_u64(self) -> Result<u64, Self::Error> {
243        ok_or(self.val.as_u64().copied())
244    }
245    fn deserialize_u128(self) -> Result<u128, Self::Error> {
246        ok_or(self.val.as_u128().copied().map(|x| x.0))
247    }
248    fn deserialize_u256(self) -> Result<u256, Self::Error> {
249        ok_or(self.val.as_u256().map(|x| **x))
250    }
251    fn deserialize_i8(self) -> Result<i8, Self::Error> {
252        ok_or(self.val.as_i8().copied())
253    }
254    fn deserialize_i16(self) -> Result<i16, Self::Error> {
255        ok_or(self.val.as_i16().copied())
256    }
257    fn deserialize_i32(self) -> Result<i32, Self::Error> {
258        ok_or(self.val.as_i32().copied())
259    }
260    fn deserialize_i64(self) -> Result<i64, Self::Error> {
261        ok_or(self.val.as_i64().copied())
262    }
263    fn deserialize_i128(self) -> Result<i128, Self::Error> {
264        ok_or(self.val.as_i128().copied().map(|x| x.0))
265    }
266    fn deserialize_i256(self) -> Result<i256, Self::Error> {
267        ok_or(self.val.as_i256().map(|x| **x))
268    }
269    fn deserialize_f32(self) -> Result<f32, Self::Error> {
270        ok_or(self.val.as_f32().copied().map(f32::from))
271    }
272    fn deserialize_f64(self) -> Result<f64, Self::Error> {
273        ok_or(self.val.as_f64().copied().map(f64::from))
274    }
275
276    fn deserialize_str<V: de::SliceVisitor<'de, str>>(self, visitor: V) -> Result<V::Output, Self::Error> {
277        visitor.visit_borrowed(ok_or(self.val.as_string())?)
278    }
279
280    fn deserialize_bytes<V: de::SliceVisitor<'de, [u8]>>(self, visitor: V) -> Result<V::Output, Self::Error> {
281        visitor.visit_borrowed(ok_or(self.val.as_bytes())?)
282    }
283
284    fn deserialize_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
285        self,
286        visitor: V,
287        seed: T,
288    ) -> Result<V::Output, Self::Error> {
289        let iter = ok_or(self.val.as_array())?.iter_cloned();
290        visitor.visit(RefArrayAccess { iter, seed })
291    }
292}
293
294/// Defines deserialization for [`&'de ValueDeserializer`] where product elements are in the input.
295struct RefProductAccess<'a> {
296    /// The element values of the product as an iterator of borrowed values.
297    vals: std::slice::Iter<'a, AlgebraicValue>,
298}
299
300impl<'de> de::SeqProductAccess<'de> for RefProductAccess<'de> {
301    type Error = ValueDeserializeError;
302
303    fn next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Output>, Self::Error> {
304        self.vals
305            .next()
306            .map(|val| seed.deserialize(ValueDeserializer::from_ref(val)))
307            .transpose()
308    }
309}
310
311impl<'de> de::SumAccess<'de> for &'de SumAccess {
312    type Error = ValueDeserializeError;
313
314    type Variant = &'de ValueDeserializer;
315
316    fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
317        let tag = visitor.visit_tag(self.sum.tag)?;
318        Ok((tag, ValueDeserializer::from_ref(&self.sum.value)))
319    }
320}
321
322impl<'de> de::VariantAccess<'de> for &'de ValueDeserializer {
323    type Error = ValueDeserializeError;
324
325    fn deserialize_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
326        seed.deserialize(self)
327    }
328}
329
330/// Defines deserialization for [`&'de ValueDeserializer`] where an array value is in the input.
331struct RefArrayAccess<'a, T> {
332    // TODO: idk this kinda sucks
333    /// The elements of the array as an iterator of cloned elements.
334    iter: ArrayValueIterCloned<'a>,
335    /// A seed value provided by the caller of
336    /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed).
337    seed: T,
338}
339
340impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for RefArrayAccess<'de, T> {
341    type Element = T::Output;
342    type Error = ValueDeserializeError;
343
344    fn next_element(&mut self) -> Result<Option<Self::Element>, Self::Error> {
345        self.iter
346            .next()
347            .map(|val| self.seed.clone().deserialize(ValueDeserializer { val }))
348            .transpose()
349    }
350}