Skip to main content

spacetimedb_sats/algebraic_value/
de.rs

1use crate::array_value::{ArrayValueIntoIter, ArrayValueIterCloned};
2use crate::{de, AlgebraicValue, ProductValue, SumValue};
3use crate::{i256, u256};
4use derive_more::From;
5
6/// An implementation of [`Deserializer`](de::Deserializer)
7/// where the input of deserialization is an `AlgebraicValue`.
8#[repr(transparent)]
9#[derive(From)]
10pub struct ValueDeserializer {
11    /// The value to deserialize to some `T`.
12    val: AlgebraicValue,
13}
14
15impl ValueDeserializer {
16    /// Returns a `ValueDeserializer` with `val` as the input for deserialization.
17    pub fn new(val: AlgebraicValue) -> Self {
18        Self { val }
19    }
20
21    /// Converts `&AlgebraicValue` to `&ValueDeserializer`.
22    pub fn from_ref(val: &AlgebraicValue) -> &Self {
23        // SAFETY: The conversion is OK due to `repr(transparent)`.
24        unsafe { &*(val as *const AlgebraicValue as *const ValueDeserializer) }
25    }
26
27    pub fn from_product_ref(prod: &ProductValue) -> RefProductAccess<'_> {
28        let vals = prod.elements.iter();
29        RefProductAccess { vals }
30    }
31}
32
/// Errors that can occur when deserializing the `AlgebraicValue`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so that it can be
/// printed and propagated through `?` into `Box<dyn std::error::Error>` and
/// similar error containers.
#[derive(Debug)]
pub enum ValueDeserializeError {
    /// The input type does not match the target type.
    MismatchedType,
    /// An unstructured error message.
    Custom(String),
}

impl std::fmt::Display for ValueDeserializeError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MismatchedType => f.write_str("the input value does not match the target type"),
            // Custom messages are rendered verbatim.
            Self::Custom(msg) => f.write_str(msg),
        }
    }
}

impl std::error::Error for ValueDeserializeError {}
41
42impl de::Error for ValueDeserializeError {
43    fn custom(msg: impl std::fmt::Display) -> Self {
44        Self::Custom(msg.to_string())
45    }
46}
47
48/// Turns any error into `ValueDeserializeError::MismatchedType`.
49fn map_err<T, E>(res: Result<T, E>) -> Result<T, ValueDeserializeError> {
50    res.map_err(|_| ValueDeserializeError::MismatchedType)
51}
52
53/// Turns any option into `ValueDeserializeError::MismatchedType`.
54fn ok_or<T>(res: Option<T>) -> Result<T, ValueDeserializeError> {
55    res.ok_or(ValueDeserializeError::MismatchedType)
56}
57
58impl<'de> de::Deserializer<'de> for ValueDeserializer {
59    type Error = ValueDeserializeError;
60
61    fn deserialize_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
62        let vals = map_err(self.val.into_product())?.into_iter();
63        visitor.visit_seq_product(ProductAccess { vals })
64    }
65
66    fn validate_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<(), Self::Error> {
67        let vals = map_err(self.val.into_product())?.into_iter();
68        visitor.validate_seq_product(ProductAccess { vals })
69    }
70
71    fn deserialize_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
72        let sum = map_err(self.val.into_sum())?;
73        visitor.visit_sum(SumAccess { sum })
74    }
75
76    fn validate_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<(), Self::Error> {
77        let sum = map_err(self.val.into_sum())?;
78        visitor.validate_sum(SumAccess { sum })
79    }
80
81    fn deserialize_bool(self) -> Result<bool, Self::Error> {
82        map_err(self.val.into_bool())
83    }
84
85    fn deserialize_u8(self) -> Result<u8, Self::Error> {
86        map_err(self.val.into_u8())
87    }
88
89    fn deserialize_u16(self) -> Result<u16, Self::Error> {
90        map_err(self.val.into_u16())
91    }
92
93    fn deserialize_u32(self) -> Result<u32, Self::Error> {
94        map_err(self.val.into_u32())
95    }
96
97    fn deserialize_u64(self) -> Result<u64, Self::Error> {
98        map_err(self.val.into_u64())
99    }
100
101    fn deserialize_u128(self) -> Result<u128, Self::Error> {
102        map_err(self.val.into_u128().map(|x| x.0))
103    }
104
105    fn deserialize_u256(self) -> Result<u256, Self::Error> {
106        map_err(self.val.into_u256().map(|x| *x))
107    }
108
109    fn deserialize_i8(self) -> Result<i8, Self::Error> {
110        map_err(self.val.into_i8())
111    }
112
113    fn deserialize_i16(self) -> Result<i16, Self::Error> {
114        map_err(self.val.into_i16())
115    }
116
117    fn deserialize_i32(self) -> Result<i32, Self::Error> {
118        map_err(self.val.into_i32())
119    }
120
121    fn deserialize_i64(self) -> Result<i64, Self::Error> {
122        map_err(self.val.into_i64())
123    }
124
125    fn deserialize_i128(self) -> Result<i128, Self::Error> {
126        map_err(self.val.into_i128().map(|x| x.0))
127    }
128
129    fn deserialize_i256(self) -> Result<i256, Self::Error> {
130        map_err(self.val.into_i256().map(|x| *x))
131    }
132
133    fn deserialize_f32(self) -> Result<f32, Self::Error> {
134        map_err(self.val.into_f32().map(f32::from))
135    }
136
137    fn deserialize_f64(self) -> Result<f64, Self::Error> {
138        map_err(self.val.into_f64().map(f64::from))
139    }
140
141    fn deserialize_str<V: de::SliceVisitor<'de, str>>(self, visitor: V) -> Result<V::Output, Self::Error> {
142        visitor.visit_owned(map_err(self.val.into_string().map(Into::into))?)
143    }
144
145    fn deserialize_bytes<V: de::SliceVisitor<'de, [u8]>>(self, visitor: V) -> Result<V::Output, Self::Error> {
146        visitor.visit_owned(map_err(self.val.into_bytes().map(Vec::from))?)
147    }
148
149    fn deserialize_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
150        self,
151        visitor: V,
152        seed: T,
153    ) -> Result<V::Output, Self::Error> {
154        let iter = map_err(self.val.into_array())?.into_iter();
155        visitor.visit(ArrayAccess { iter, seed })
156    }
157
158    fn validate_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
159        self,
160        visitor: V,
161        seed: T,
162    ) -> Result<(), Self::Error> {
163        let iter = map_err(self.val.into_array())?.into_iter();
164        visitor.validate(ArrayAccess { iter, seed })
165    }
166}
167
168/// Defines deserialization for [`ValueDeserializer`] where product elements are in the input.
169struct ProductAccess {
170    /// The element values of the product as an iterator of owned values.
171    vals: std::vec::IntoIter<AlgebraicValue>,
172}
173
174impl<'de> de::SeqProductAccess<'de> for ProductAccess {
175    type Error = ValueDeserializeError;
176
177    fn next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Output>, Self::Error> {
178        self.vals
179            .next()
180            .map(|val| seed.deserialize(ValueDeserializer { val }))
181            .transpose()
182    }
183
184    fn validate_next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<()>, Self::Error> {
185        self.vals
186            .next()
187            .map(|val| seed.validate(ValueDeserializer { val }))
188            .transpose()
189    }
190}
191
192/// Defines deserialization for [`ValueDeserializer`] where a sum value is in the input.
193#[repr(transparent)]
194struct SumAccess {
195    /// The input sum value to deserialize.
196    sum: SumValue,
197}
198
199impl SumAccess {
200    /// Converts `&SumValue` to `&SumAccess`.
201    fn from_ref(sum: &SumValue) -> &Self {
202        // SAFETY: `repr(transparent)` allows this.
203        unsafe { &*(sum as *const SumValue as *const SumAccess) }
204    }
205}
206
207impl<'de> de::SumAccess<'de> for SumAccess {
208    type Error = ValueDeserializeError;
209
210    type Variant = ValueDeserializer;
211
212    fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
213        let tag = visitor.visit_tag(self.sum.tag)?;
214        let val = *self.sum.value;
215        Ok((tag, ValueDeserializer { val }))
216    }
217}
218
219impl<'de> de::VariantAccess<'de> for ValueDeserializer {
220    type Error = ValueDeserializeError;
221
222    fn deserialize_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
223        seed.deserialize(self)
224    }
225
226    fn validate_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<(), Self::Error> {
227        seed.validate(self)
228    }
229}
230
231/// Defines deserialization for [`ValueDeserializer`] where an array value is in the input.
232struct ArrayAccess<T> {
233    /// The elements of the array as an iterator of owned elements.
234    iter: ArrayValueIntoIter,
235    /// A seed value provided by the caller of
236    /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed).
237    seed: T,
238}
239
240impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for ArrayAccess<T> {
241    type Element = T::Output;
242    type Error = ValueDeserializeError;
243
244    fn next_element(&mut self) -> Result<Option<Self::Element>, Self::Error> {
245        self.iter
246            .next()
247            .map(|val| self.seed.clone().deserialize(ValueDeserializer { val }))
248            .transpose()
249    }
250
251    fn validate_next_element(&mut self) -> Result<Option<()>, Self::Error> {
252        self.iter
253            .next()
254            .map(|val| self.seed.clone().validate(ValueDeserializer { val }))
255            .transpose()
256    }
257}
258
259impl<'de> de::Deserializer<'de> for &'de ValueDeserializer {
260    type Error = ValueDeserializeError;
261
262    fn deserialize_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
263        let vals = ok_or(self.val.as_product())?.elements.iter();
264        visitor.visit_seq_product(RefProductAccess { vals })
265    }
266
267    fn validate_product<V: de::ProductVisitor<'de>>(self, visitor: V) -> Result<(), Self::Error> {
268        let vals = ok_or(self.val.as_product())?.elements.iter();
269        visitor.validate_seq_product(RefProductAccess { vals })
270    }
271
272    fn deserialize_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<V::Output, Self::Error> {
273        let sum = ok_or(self.val.as_sum())?;
274        visitor.visit_sum(SumAccess::from_ref(sum))
275    }
276
277    fn validate_sum<V: de::SumVisitor<'de>>(self, visitor: V) -> Result<(), Self::Error> {
278        let sum = ok_or(self.val.as_sum())?;
279        visitor.validate_sum(SumAccess::from_ref(sum))
280    }
281
282    fn deserialize_bool(self) -> Result<bool, Self::Error> {
283        ok_or(self.val.as_bool().copied())
284    }
285    fn deserialize_u8(self) -> Result<u8, Self::Error> {
286        ok_or(self.val.as_u8().copied())
287    }
288    fn deserialize_u16(self) -> Result<u16, Self::Error> {
289        ok_or(self.val.as_u16().copied())
290    }
291    fn deserialize_u32(self) -> Result<u32, Self::Error> {
292        ok_or(self.val.as_u32().copied())
293    }
294    fn deserialize_u64(self) -> Result<u64, Self::Error> {
295        ok_or(self.val.as_u64().copied())
296    }
297    fn deserialize_u128(self) -> Result<u128, Self::Error> {
298        ok_or(self.val.as_u128().copied().map(|x| x.0))
299    }
300    fn deserialize_u256(self) -> Result<u256, Self::Error> {
301        ok_or(self.val.as_u256().map(|x| **x))
302    }
303    fn deserialize_i8(self) -> Result<i8, Self::Error> {
304        ok_or(self.val.as_i8().copied())
305    }
306    fn deserialize_i16(self) -> Result<i16, Self::Error> {
307        ok_or(self.val.as_i16().copied())
308    }
309    fn deserialize_i32(self) -> Result<i32, Self::Error> {
310        ok_or(self.val.as_i32().copied())
311    }
312    fn deserialize_i64(self) -> Result<i64, Self::Error> {
313        ok_or(self.val.as_i64().copied())
314    }
315    fn deserialize_i128(self) -> Result<i128, Self::Error> {
316        ok_or(self.val.as_i128().copied().map(|x| x.0))
317    }
318    fn deserialize_i256(self) -> Result<i256, Self::Error> {
319        ok_or(self.val.as_i256().map(|x| **x))
320    }
321    fn deserialize_f32(self) -> Result<f32, Self::Error> {
322        ok_or(self.val.as_f32().copied().map(f32::from))
323    }
324    fn deserialize_f64(self) -> Result<f64, Self::Error> {
325        ok_or(self.val.as_f64().copied().map(f64::from))
326    }
327
328    fn deserialize_str<V: de::SliceVisitor<'de, str>>(self, visitor: V) -> Result<V::Output, Self::Error> {
329        visitor.visit_borrowed(ok_or(self.val.as_string())?)
330    }
331
332    fn deserialize_bytes<V: de::SliceVisitor<'de, [u8]>>(self, visitor: V) -> Result<V::Output, Self::Error> {
333        visitor.visit_borrowed(ok_or(self.val.as_bytes())?)
334    }
335
336    fn deserialize_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
337        self,
338        visitor: V,
339        seed: T,
340    ) -> Result<V::Output, Self::Error> {
341        let iter = ok_or(self.val.as_array())?.iter_cloned();
342        visitor.visit(RefArrayAccess { iter, seed })
343    }
344
345    fn validate_array_seed<V: de::ArrayVisitor<'de, T::Output>, T: de::DeserializeSeed<'de> + Clone>(
346        self,
347        visitor: V,
348        seed: T,
349    ) -> Result<(), Self::Error> {
350        let iter = ok_or(self.val.as_array())?.iter_cloned();
351        visitor.validate(RefArrayAccess { iter, seed })
352    }
353}
354
355/// Defines deserialization for [`&'de ValueDeserializer`] where product elements are in the input.
356pub struct RefProductAccess<'a> {
357    /// The element values of the product as an iterator of borrowed values.
358    vals: std::slice::Iter<'a, AlgebraicValue>,
359}
360
361impl<'de> de::SeqProductAccess<'de> for RefProductAccess<'de> {
362    type Error = ValueDeserializeError;
363
364    fn next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<T::Output>, Self::Error> {
365        self.vals
366            .next()
367            .map(|val| seed.deserialize(ValueDeserializer::from_ref(val)))
368            .transpose()
369    }
370
371    fn validate_next_element_seed<T: de::DeserializeSeed<'de>>(&mut self, seed: T) -> Result<Option<()>, Self::Error> {
372        self.vals
373            .next()
374            .map(|val| seed.validate(ValueDeserializer::from_ref(val)))
375            .transpose()
376    }
377}
378
379impl<'de> de::SumAccess<'de> for &'de SumAccess {
380    type Error = ValueDeserializeError;
381
382    type Variant = &'de ValueDeserializer;
383
384    fn variant<V: de::VariantVisitor<'de>>(self, visitor: V) -> Result<(V::Output, Self::Variant), Self::Error> {
385        let tag = visitor.visit_tag(self.sum.tag)?;
386        Ok((tag, ValueDeserializer::from_ref(&self.sum.value)))
387    }
388}
389
390impl<'de> de::VariantAccess<'de> for &'de ValueDeserializer {
391    type Error = ValueDeserializeError;
392
393    fn deserialize_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<T::Output, Self::Error> {
394        seed.deserialize(self)
395    }
396
397    fn validate_seed<T: de::DeserializeSeed<'de>>(self, seed: T) -> Result<(), Self::Error> {
398        seed.validate(self)
399    }
400}
401
402/// Defines deserialization for [`&'de ValueDeserializer`] where an array value is in the input.
403struct RefArrayAccess<'a, T> {
404    // TODO: idk this kinda sucks
405    /// The elements of the array as an iterator of cloned elements.
406    iter: ArrayValueIterCloned<'a>,
407    /// A seed value provided by the caller of
408    /// [`deserialize_array_seed`](de::Deserializer::deserialize_array_seed).
409    seed: T,
410}
411
412impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for RefArrayAccess<'de, T> {
413    type Element = T::Output;
414    type Error = ValueDeserializeError;
415
416    fn next_element(&mut self) -> Result<Option<Self::Element>, Self::Error> {
417        self.iter
418            .next()
419            .map(|val| self.seed.clone().deserialize(ValueDeserializer { val }))
420            .transpose()
421    }
422
423    fn validate_next_element(&mut self) -> Result<Option<()>, Self::Error> {
424        self.iter
425            .next()
426            .map(|val| self.seed.clone().validate(ValueDeserializer { val }))
427            .transpose()
428    }
429}