Skip to main content

serde_firestore_value/de/
deserializer.rs

1use crate::de::GoogleFirestoreFunctionMapAccess;
2use crate::de::GoogleFirestorePipelineMapAccess;
3use crate::de::GoogleTypeLatLngMapAccess;
4use crate::de::ProstTypesTimestampMapAccess;
5use crate::de::firestore_enum_deserializer::FirestoreEnumDeserializer;
6use crate::google::firestore::v1::{Value, value::ValueType};
7use crate::{
8    Error, FieldReference, Function, LatLng, Pipeline, Reference, Timestamp, error::ErrorCode,
9    value_ext::ValueExt,
10};
11
12/// A Deserializer type which implements [`serde::Deserializer`] for [`Value`].
13#[derive(Debug)]
14pub struct Deserializer<'a> {
15    value: &'a Value,
16}
17
18impl<'de> Deserializer<'de> {
19    /// Creates a new [`Deserializer`].
20    pub fn new(value: &'de Value) -> Self {
21        Self { value }
22    }
23}
24
25impl<'a> serde::Deserializer<'a> for Deserializer<'a> {
26    type Error = Error;
27
28    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
29    where
30        V: serde::de::Visitor<'a>,
31    {
32        match self.value.value_type {
33            Some(ref value_type) => match value_type {
34                ValueType::NullValue(_) => visitor.visit_unit(),
35                ValueType::BooleanValue(v) => visitor.visit_bool(*v),
36                ValueType::IntegerValue(v) => visitor.visit_i64(*v),
37                ValueType::DoubleValue(v) => visitor.visit_f64(*v),
38                ValueType::TimestampValue(v) => {
39                    visitor.visit_map(ProstTypesTimestampMapAccess::new(v))
40                }
41                ValueType::StringValue(v) => visitor.visit_str(v),
42                ValueType::BytesValue(v) => visitor.visit_bytes(v),
43                ValueType::ReferenceValue(v) => {
44                    visitor.visit_map(serde::de::value::MapDeserializer::new(std::iter::once((
45                        crate::Reference::NAME,
46                        serde::de::value::StrDeserializer::new(v),
47                    ))))
48                }
49                ValueType::GeoPointValue(v) => visitor.visit_map(GoogleTypeLatLngMapAccess::new(v)),
50                ValueType::ArrayValue(v) => visitor.visit_seq(
51                    serde::de::value::SeqDeserializer::new(v.values.iter().map(Deserializer::new)),
52                ),
53                ValueType::MapValue(map) => {
54                    visitor.visit_map(serde::de::value::MapDeserializer::new(
55                        map.fields
56                            .iter()
57                            .map(|(k, v)| (k.as_str(), Deserializer::new(v))),
58                    ))
59                }
60                ValueType::FieldReferenceValue(v) => {
61                    visitor.visit_map(serde::de::value::MapDeserializer::new(std::iter::once((
62                        crate::FieldReference::NAME,
63                        serde::de::value::StrDeserializer::new(v),
64                    ))))
65                }
66                ValueType::FunctionValue(v) => {
67                    visitor.visit_map(GoogleFirestoreFunctionMapAccess::new(v))
68                }
69                ValueType::PipelineValue(v) => {
70                    visitor.visit_map(GoogleFirestorePipelineMapAccess::new(v))
71                }
72            },
73            None => Err(Error::from(ErrorCode::ValueTypeMustBeSome)),
74        }
75    }
76
77    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
78    where
79        V: serde::de::Visitor<'a>,
80    {
81        let value = self.value.as_boolean()?;
82        visitor.visit_bool(value)
83    }
84
85    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
86    where
87        V: serde::de::Visitor<'a>,
88    {
89        let value = self.value.as_integer()?;
90        visitor.visit_i8(i8::try_from(value).map_err(|_| Error::from(ErrorCode::I8OutOfRange))?)
91    }
92
93    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
94    where
95        V: serde::de::Visitor<'a>,
96    {
97        let value = self.value.as_integer()?;
98        visitor.visit_i16(i16::try_from(value).map_err(|_| Error::from(ErrorCode::I16OutOfRange))?)
99    }
100
101    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
102    where
103        V: serde::de::Visitor<'a>,
104    {
105        let value = self.value.as_integer()?;
106        visitor.visit_i32(i32::try_from(value).map_err(|_| Error::from(ErrorCode::I32OutOfRange))?)
107    }
108
109    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
110    where
111        V: serde::de::Visitor<'a>,
112    {
113        let value = self.value.as_integer()?;
114        visitor.visit_i64(value)
115    }
116
117    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
118    where
119        V: serde::de::Visitor<'a>,
120    {
121        let value = self.value.as_integer()?;
122        visitor.visit_u8(u8::try_from(value).map_err(|_| Error::from(ErrorCode::U8OutOfRange))?)
123    }
124
125    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126    where
127        V: serde::de::Visitor<'a>,
128    {
129        let value = self.value.as_integer()?;
130        visitor.visit_u16(u16::try_from(value).map_err(|_| Error::from(ErrorCode::U16OutOfRange))?)
131    }
132
133    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
134    where
135        V: serde::de::Visitor<'a>,
136    {
137        let value = self.value.as_integer()?;
138        visitor.visit_u32(u32::try_from(value).map_err(|_| Error::from(ErrorCode::U32OutOfRange))?)
139    }
140
141    fn deserialize_u64<V>(self, _: V) -> Result<V::Value, Self::Error>
142    where
143        V: serde::de::Visitor<'a>,
144    {
145        Err(Error::from(ErrorCode::U64IsNotSupported))
146    }
147
148    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
149    where
150        V: serde::de::Visitor<'a>,
151    {
152        let value = self.value.as_double()?;
153        visitor.visit_f32(value as f32)
154    }
155
156    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
157    where
158        V: serde::de::Visitor<'a>,
159    {
160        let value = self.value.as_double()?;
161        visitor.visit_f64(value)
162    }
163
164    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
165    where
166        V: serde::de::Visitor<'a>,
167    {
168        let value = self.value.as_string()?;
169        let mut chars = value.chars();
170        match (chars.next(), chars.next()) {
171            (None, None) => Err(Error::from(ErrorCode::StringIsEmpty)),
172            (None, Some(_)) => unreachable!(),
173            (Some(c), None) => visitor.visit_char(c),
174            (Some(_), Some(_)) => Err(Error::from(ErrorCode::TooManyChars)),
175        }
176    }
177
178    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
179    where
180        V: serde::de::Visitor<'a>,
181    {
182        let value = self.value.as_string()?;
183        visitor.visit_str(value)
184    }
185
186    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
187    where
188        V: serde::de::Visitor<'a>,
189    {
190        self.deserialize_str(visitor)
191    }
192
193    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
194    where
195        V: serde::de::Visitor<'a>,
196    {
197        let value = self.value.as_bytes()?;
198        visitor.visit_bytes(value)
199    }
200
201    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202    where
203        V: serde::de::Visitor<'a>,
204    {
205        let value = self.value.as_bytes()?;
206        visitor.visit_byte_buf(value.to_vec())
207    }
208
209    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: serde::de::Visitor<'a>,
212    {
213        match self.value.value_type()? {
214            ValueType::NullValue(_) => visitor.visit_none(),
215            _ => visitor.visit_some(self),
216        }
217    }
218
219    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
220    where
221        V: serde::de::Visitor<'a>,
222    {
223        self.value.as_null()?;
224        visitor.visit_unit()
225    }
226
227    fn deserialize_unit_struct<V>(
228        self,
229        _name: &'static str,
230        visitor: V,
231    ) -> Result<V::Value, Self::Error>
232    where
233        V: serde::de::Visitor<'a>,
234    {
235        self.deserialize_unit(visitor)
236    }
237
238    fn deserialize_newtype_struct<V>(
239        self,
240        name: &'static str,
241        visitor: V,
242    ) -> Result<V::Value, Self::Error>
243    where
244        V: serde::de::Visitor<'a>,
245    {
246        if name == FieldReference::NAME {
247            visitor.visit_newtype_struct(serde::de::value::StrDeserializer::new(
248                self.value.as_field_reference_value_as_string()?,
249            ))
250        } else if name == Reference::NAME {
251            visitor.visit_newtype_struct(serde::de::value::StrDeserializer::new(
252                self.value.as_reference_value_as_string()?,
253            ))
254        } else {
255            visitor.visit_newtype_struct(self)
256        }
257    }
258
259    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
260    where
261        V: serde::de::Visitor<'a>,
262    {
263        visitor.visit_seq(serde::de::value::SeqDeserializer::new(
264            self.value.as_values()?.iter().map(Deserializer::new),
265        ))
266    }
267
268    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
269    where
270        V: serde::de::Visitor<'a>,
271    {
272        visitor.visit_seq(serde::de::value::SeqDeserializer::new(
273            self.value.as_values()?.iter().map(Deserializer::new),
274        ))
275    }
276
277    fn deserialize_tuple_struct<V>(
278        self,
279        _name: &'static str,
280        _len: usize,
281        visitor: V,
282    ) -> Result<V::Value, Self::Error>
283    where
284        V: serde::de::Visitor<'a>,
285    {
286        visitor.visit_seq(serde::de::value::SeqDeserializer::new(
287            self.value.as_values()?.iter().map(Deserializer::new),
288        ))
289    }
290
291    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
292    where
293        V: serde::de::Visitor<'a>,
294    {
295        visitor.visit_map(serde::de::value::MapDeserializer::new(
296            self.value
297                .as_fields()?
298                .iter()
299                .map(|(k, v)| (k.as_str(), Deserializer::new(v))),
300        ))
301    }
302
303    fn deserialize_struct<V>(
304        self,
305        name: &'static str,
306        fields: &'static [&'static str],
307        visitor: V,
308    ) -> Result<V::Value, Self::Error>
309    where
310        V: serde::de::Visitor<'a>,
311    {
312        if name == Function::NAME {
313            visitor.visit_map(GoogleFirestoreFunctionMapAccess::new(
314                self.value.as_function()?,
315            ))
316        } else if name == LatLng::NAME {
317            visitor.visit_map(GoogleTypeLatLngMapAccess::new(self.value.as_lat_lng()?))
318        } else if name == Pipeline::NAME {
319            visitor.visit_map(GoogleFirestorePipelineMapAccess::new(
320                self.value.as_pipeline()?,
321            ))
322        } else if name == Timestamp::NAME {
323            visitor.visit_map(ProstTypesTimestampMapAccess::new(
324                self.value.as_timestamp()?,
325            ))
326        } else {
327            visitor.visit_map(serde::de::value::MapDeserializer::new(
328                self.value
329                    .as_fields()?
330                    .iter()
331                    .filter(|(k, _)| fields.contains(&k.as_str()))
332                    .map(|(k, v)| (k.as_str(), Deserializer::new(v))),
333            ))
334        }
335    }
336
337    fn deserialize_enum<V>(
338        self,
339        _name: &'static str,
340        variants: &'static [&'static str],
341        visitor: V,
342    ) -> Result<V::Value, Self::Error>
343    where
344        V: serde::de::Visitor<'a>,
345    {
346        visitor.visit_enum(FirestoreEnumDeserializer::new(self.value, variants)?)
347    }
348
349    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
350    where
351        V: serde::de::Visitor<'a>,
352    {
353        match self.value.value_type()? {
354            ValueType::StringValue(s) => visitor.visit_str(s.as_str()),
355            ValueType::MapValue(_) => {
356                let (variant, _) = self.value.as_variant_value()?;
357                visitor.visit_str(variant.as_str())
358            }
359            _ => todo!(),
360        }
361    }
362
363    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
364    where
365        V: serde::de::Visitor<'a>,
366    {
367        visitor.visit_unit()
368    }
369}
370
371impl<'de> serde::de::IntoDeserializer<'de, Error> for Deserializer<'de> {
372    type Deserializer = Self;
373
374    fn into_deserializer(self) -> Self::Deserializer {
375        self
376    }
377}