tower_sesh/value/
de.rs

1// Adapted from https://github.com/serde-rs/json.
2
3use serde::{
4    de::{EnumAccess, Expected, IntoDeserializer, SeqAccess, Unexpected, VariantAccess, Visitor},
5    Deserialize,
6};
7
8use super::{error::Error, number::Number, Value};
9
10impl<'de> Deserialize<'de> for Value {
11    #[inline]
12    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
13    where
14        D: serde::Deserializer<'de>,
15    {
16        struct ValueVisitor;
17
18        impl<'de> Visitor<'de> for ValueVisitor {
19            type Value = Value;
20
21            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
22                formatter.write_str("a valid session value")
23            }
24
25            #[inline]
26            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E> {
27                Ok(Value::Bool(v))
28            }
29
30            #[inline]
31            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E> {
32                Ok(Value::Number(v.into()))
33            }
34
35            fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
36            where
37                E: serde::de::Error,
38            {
39                let de = serde::de::value::I128Deserializer::new(v);
40                Number::deserialize(de).map(Value::Number)
41            }
42
43            #[inline]
44            fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E> {
45                Ok(Value::Number(v.into()))
46            }
47
48            fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
49            where
50                E: serde::de::Error,
51            {
52                let de = serde::de::value::U128Deserializer::new(v);
53                Number::deserialize(de).map(Value::Number)
54            }
55
56            fn visit_f32<E>(self, v: f32) -> Result<Self::Value, E>
57            where
58                E: serde::de::Error,
59            {
60                Number::from_f32(v).map(Value::Number).ok_or_else(|| {
61                    serde::de::Error::custom("float must be finite (got NaN or +/-inf)")
62                })
63            }
64
65            fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
66            where
67                E: serde::de::Error,
68            {
69                Number::from_f64(v).map(Value::Number).ok_or_else(|| {
70                    serde::de::Error::custom("float must be finite (got NaN or +/-inf)")
71                })
72            }
73
74            #[inline]
75            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
76            where
77                E: serde::de::Error,
78            {
79                self.visit_string(String::from(v))
80            }
81
82            #[inline]
83            fn visit_string<E>(self, v: String) -> Result<Self::Value, E> {
84                Ok(Value::String(v))
85            }
86
87            #[inline]
88            fn visit_none<E>(self) -> Result<Self::Value, E> {
89                Ok(Value::Null)
90            }
91
92            #[inline]
93            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
94            where
95                D: serde::Deserializer<'de>,
96            {
97                Deserialize::deserialize(deserializer)
98            }
99
100            #[inline]
101            fn visit_unit<E>(self) -> Result<Self::Value, E> {
102                Ok(Value::Null)
103            }
104
105            #[inline]
106            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
107            where
108                A: serde::de::SeqAccess<'de>,
109            {
110                let mut vec = Vec::new();
111
112                while let Some(elem) = seq.next_element()? {
113                    vec.push(elem);
114                }
115
116                Ok(Value::Array(vec))
117            }
118
119            fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
120            where
121                A: serde::de::MapAccess<'de>,
122            {
123                todo!("map deserialization")
124            }
125        }
126
127        deserializer.deserialize_any(ValueVisitor)
128    }
129}
130
/// Expands to one numeric `deserialize_*` method of a `Deserializer` impl.
///
/// Only `Value::Number` is accepted; the number itself picks the `visit_*` call through its
/// own `deserialize_any`. Every other variant produces an `invalid_type` error.
macro_rules! deserialize_number {
    ($method:ident) => {
        fn $method<V>(self, visitor: V) -> Result<V::Value, Error>
        where
            V: Visitor<'de>,
        {
            if let Value::Number(n) = self {
                n.deserialize_any(visitor)
            } else {
                Err(self.invalid_type(&visitor))
            }
        }
    };
}
144
145fn visit_array<'de, V>(array: Vec<Value>, visitor: V) -> Result<V::Value, Error>
146where
147    V: Visitor<'de>,
148{
149    let len = array.len();
150    let mut deserializer = SeqDeserializer::new(array);
151    let seq = visitor.visit_seq(&mut deserializer)?;
152    let remaining = deserializer.iter.len();
153    if remaining == 0 {
154        Ok(seq)
155    } else {
156        Err(serde::de::Error::invalid_length(
157            len,
158            &"fewer elements in sequence",
159        ))
160    }
161}
162
163impl<'de> serde::Deserializer<'de> for Value {
164    type Error = Error;
165
166    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
167    where
168        V: Visitor<'de>,
169    {
170        match self {
171            Value::Null => visitor.visit_unit(),
172            Value::Bool(b) => visitor.visit_bool(b),
173            Value::Number(n) => n.deserialize_any(visitor),
174            Value::String(s) => visitor.visit_string(s),
175            Value::ByteArray(b) => visitor.visit_byte_buf(b),
176            Value::Array(v) => visit_array(v, visitor),
177            Value::Map(m) => todo!("map deserialization"),
178        }
179    }
180
181    deserialize_number!(deserialize_i8);
182    deserialize_number!(deserialize_i16);
183    deserialize_number!(deserialize_i32);
184    deserialize_number!(deserialize_i64);
185    deserialize_number!(deserialize_i128);
186    deserialize_number!(deserialize_u8);
187    deserialize_number!(deserialize_u16);
188    deserialize_number!(deserialize_u32);
189    deserialize_number!(deserialize_u64);
190    deserialize_number!(deserialize_u128);
191    deserialize_number!(deserialize_f32);
192    deserialize_number!(deserialize_f64);
193
194    #[inline]
195    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
196    where
197        V: Visitor<'de>,
198    {
199        match self {
200            Value::Null => visitor.visit_none(),
201            _ => visitor.visit_some(self),
202        }
203    }
204
205    #[inline]
206    fn deserialize_enum<V>(
207        self,
208        name: &'static str,
209        variants: &'static [&'static str],
210        visitor: V,
211    ) -> Result<V::Value, Self::Error>
212    where
213        V: Visitor<'de>,
214    {
215        todo!("map deserialization")
216    }
217
218    #[inline]
219    fn deserialize_newtype_struct<V>(
220        self,
221        _name: &'static str,
222        visitor: V,
223    ) -> Result<V::Value, Self::Error>
224    where
225        V: Visitor<'de>,
226    {
227        visitor.visit_newtype_struct(self)
228    }
229
230    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: Visitor<'de>,
233    {
234        match self {
235            Value::Bool(v) => visitor.visit_bool(v),
236            _ => Err(self.invalid_type(&visitor)),
237        }
238    }
239
240    #[inline]
241    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
242    where
243        V: Visitor<'de>,
244    {
245        self.deserialize_string(visitor)
246    }
247
248    #[inline]
249    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
250    where
251        V: Visitor<'de>,
252    {
253        self.deserialize_string(visitor)
254    }
255
256    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
257    where
258        V: Visitor<'de>,
259    {
260        match self {
261            Value::String(v) => visitor.visit_string(v),
262            _ => Err(self.invalid_type(&visitor)),
263        }
264    }
265
266    #[inline]
267    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
268    where
269        V: Visitor<'de>,
270    {
271        self.deserialize_byte_buf(visitor)
272    }
273
274    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
275    where
276        V: Visitor<'de>,
277    {
278        match self {
279            Value::String(v) => visitor.visit_string(v),
280            Value::ByteArray(v) => visitor.visit_byte_buf(v),
281            Value::Array(v) => visit_array(v, visitor),
282            _ => Err(self.invalid_type(&visitor)),
283        }
284    }
285
286    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
287    where
288        V: Visitor<'de>,
289    {
290        match self {
291            Value::Null => visitor.visit_unit(),
292            _ => Err(self.invalid_type(&visitor)),
293        }
294    }
295
296    #[inline]
297    fn deserialize_unit_struct<V>(
298        self,
299        _name: &'static str,
300        visitor: V,
301    ) -> Result<V::Value, Self::Error>
302    where
303        V: Visitor<'de>,
304    {
305        self.deserialize_unit(visitor)
306    }
307
308    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
309    where
310        V: Visitor<'de>,
311    {
312        match self {
313            Value::Array(v) => visit_array(v, visitor),
314            _ => Err(self.invalid_type(&visitor)),
315        }
316    }
317
318    #[inline]
319    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
320    where
321        V: Visitor<'de>,
322    {
323        self.deserialize_seq(visitor)
324    }
325
326    #[inline]
327    fn deserialize_tuple_struct<V>(
328        self,
329        _name: &'static str,
330        _len: usize,
331        visitor: V,
332    ) -> Result<V::Value, Self::Error>
333    where
334        V: Visitor<'de>,
335    {
336        self.deserialize_seq(visitor)
337    }
338
339    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
340    where
341        V: Visitor<'de>,
342    {
343        todo!("map deserialization")
344    }
345
346    fn deserialize_struct<V>(
347        self,
348        name: &'static str,
349        fields: &'static [&'static str],
350        visitor: V,
351    ) -> Result<V::Value, Self::Error>
352    where
353        V: Visitor<'de>,
354    {
355        todo!("map deserialization")
356    }
357
358    #[inline]
359    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
360    where
361        V: Visitor<'de>,
362    {
363        self.deserialize_string(visitor)
364    }
365
366    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
367    where
368        V: Visitor<'de>,
369    {
370        drop(self);
371        visitor.visit_unit()
372    }
373}
374
375struct SeqDeserializer {
376    iter: std::vec::IntoIter<Value>,
377}
378
379impl SeqDeserializer {
380    fn new(vec: Vec<Value>) -> Self {
381        SeqDeserializer {
382            iter: vec.into_iter(),
383        }
384    }
385}
386
387impl<'de> SeqAccess<'de> for SeqDeserializer {
388    type Error = Error;
389
390    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
391    where
392        T: serde::de::DeserializeSeed<'de>,
393    {
394        match self.iter.next() {
395            Some(value) => seed.deserialize(value).map(Some),
396            None => Ok(None),
397        }
398    }
399
400    fn size_hint(&self) -> Option<usize> {
401        size_hint_helper(&self.iter)
402    }
403}
404
405fn visit_array_ref<'de, V>(array: &'de [Value], visitor: V) -> Result<V::Value, Error>
406where
407    V: Visitor<'de>,
408{
409    let len = array.len();
410    let mut deserializer = SeqRefDeserializer::new(array);
411    let seq = visitor.visit_seq(&mut deserializer)?;
412    let remaining = deserializer.iter.len();
413    if remaining == 0 {
414        Ok(seq)
415    } else {
416        Err(serde::de::Error::invalid_length(
417            len,
418            &"fewer elements in array",
419        ))
420    }
421}
422
423impl<'de> serde::Deserializer<'de> for &'de Value {
424    type Error = Error;
425
426    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
427    where
428        V: Visitor<'de>,
429    {
430        match self {
431            Value::Null => visitor.visit_unit(),
432            Value::Bool(b) => visitor.visit_bool(*b),
433            Value::Number(n) => n.deserialize_any(visitor),
434            Value::String(s) => visitor.visit_borrowed_str(s),
435            Value::ByteArray(b) => visitor.visit_borrowed_bytes(b),
436            Value::Array(v) => visit_array_ref(v, visitor),
437            Value::Map(m) => todo!("map deserialization"),
438        }
439    }
440
441    deserialize_number!(deserialize_i8);
442    deserialize_number!(deserialize_i16);
443    deserialize_number!(deserialize_i32);
444    deserialize_number!(deserialize_i64);
445    deserialize_number!(deserialize_i128);
446    deserialize_number!(deserialize_u8);
447    deserialize_number!(deserialize_u16);
448    deserialize_number!(deserialize_u32);
449    deserialize_number!(deserialize_u64);
450    deserialize_number!(deserialize_u128);
451    deserialize_number!(deserialize_f32);
452    deserialize_number!(deserialize_f64);
453
454    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
455    where
456        V: Visitor<'de>,
457    {
458        match *self {
459            Value::Null => visitor.visit_none(),
460            _ => visitor.visit_some(self),
461        }
462    }
463
464    fn deserialize_enum<V>(
465        self,
466        name: &'static str,
467        variants: &'static [&'static str],
468        visitor: V,
469    ) -> Result<V::Value, Self::Error>
470    where
471        V: Visitor<'de>,
472    {
473        match self {
474            Value::Map(value) => todo!("map deserialization"),
475            Value::String(variant) => visitor.visit_enum(EnumRefDeserializer {
476                variant,
477                value: None,
478            }),
479            other => Err(serde::de::Error::invalid_type(
480                other.unexpected(),
481                &"string or map",
482            )),
483        }
484    }
485
486    fn deserialize_newtype_struct<V>(
487        self,
488        _name: &'static str,
489        visitor: V,
490    ) -> Result<V::Value, Self::Error>
491    where
492        V: Visitor<'de>,
493    {
494        visitor.visit_newtype_struct(self)
495    }
496
497    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
498    where
499        V: Visitor<'de>,
500    {
501        match *self {
502            Value::Bool(v) => visitor.visit_bool(v),
503            _ => Err(self.invalid_type(&visitor)),
504        }
505    }
506
507    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
508    where
509        V: Visitor<'de>,
510    {
511        self.deserialize_str(visitor)
512    }
513
514    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
515    where
516        V: Visitor<'de>,
517    {
518        match self {
519            Value::String(v) => visitor.visit_borrowed_str(v),
520            _ => Err(self.invalid_type(&visitor)),
521        }
522    }
523
524    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
525    where
526        V: Visitor<'de>,
527    {
528        self.deserialize_str(visitor)
529    }
530
531    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
532    where
533        V: Visitor<'de>,
534    {
535        match self {
536            Value::String(v) => visitor.visit_borrowed_str(v),
537            Value::ByteArray(v) => visitor.visit_borrowed_bytes(v),
538            Value::Array(v) => visit_array_ref(v, visitor),
539            _ => Err(self.invalid_type(&visitor)),
540        }
541    }
542
543    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
544    where
545        V: Visitor<'de>,
546    {
547        self.deserialize_bytes(visitor)
548    }
549
550    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
551    where
552        V: Visitor<'de>,
553    {
554        match *self {
555            Value::Null => visitor.visit_unit(),
556            _ => Err(self.invalid_type(&visitor)),
557        }
558    }
559
560    fn deserialize_unit_struct<V>(
561        self,
562        _name: &'static str,
563        visitor: V,
564    ) -> Result<V::Value, Self::Error>
565    where
566        V: Visitor<'de>,
567    {
568        self.deserialize_unit(visitor)
569    }
570
571    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
572    where
573        V: Visitor<'de>,
574    {
575        match self {
576            Value::Array(v) => visit_array_ref(v, visitor),
577            _ => Err(self.invalid_type(&visitor)),
578        }
579    }
580
581    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
582    where
583        V: Visitor<'de>,
584    {
585        self.deserialize_seq(visitor)
586    }
587
588    fn deserialize_tuple_struct<V>(
589        self,
590        _name: &'static str,
591        _len: usize,
592        visitor: V,
593    ) -> Result<V::Value, Self::Error>
594    where
595        V: Visitor<'de>,
596    {
597        self.deserialize_seq(visitor)
598    }
599
600    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
601    where
602        V: Visitor<'de>,
603    {
604        match self {
605            Value::Map(v) => todo!("map deserialization"),
606            _ => Err(self.invalid_type(&visitor)),
607        }
608    }
609
610    fn deserialize_struct<V>(
611        self,
612        _name: &'static str,
613        _fields: &'static [&'static str],
614        visitor: V,
615    ) -> Result<V::Value, Self::Error>
616    where
617        V: Visitor<'de>,
618    {
619        match self {
620            Value::Array(v) => visit_array_ref(v, visitor),
621            Value::Map(v) => todo!("map deserialization"),
622            _ => Err(self.invalid_type(&visitor)),
623        }
624    }
625
626    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
627    where
628        V: Visitor<'de>,
629    {
630        self.deserialize_str(visitor)
631    }
632
633    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
634    where
635        V: Visitor<'de>,
636    {
637        visitor.visit_unit()
638    }
639}
640
641struct EnumRefDeserializer<'de> {
642    variant: &'de str,
643    value: Option<&'de Value>,
644}
645
646impl<'de> EnumAccess<'de> for EnumRefDeserializer<'de> {
647    type Error = Error;
648    type Variant = VariantRefDeserializer<'de>;
649
650    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Error>
651    where
652        V: serde::de::DeserializeSeed<'de>,
653    {
654        let variant = self.variant.into_deserializer();
655        let visitor = VariantRefDeserializer { value: self.value };
656        seed.deserialize(variant).map(|v| (v, visitor))
657    }
658}
659
660struct VariantRefDeserializer<'de> {
661    value: Option<&'de Value>,
662}
663
664impl<'de> VariantAccess<'de> for VariantRefDeserializer<'de> {
665    type Error = Error;
666
667    fn unit_variant(self) -> Result<(), Self::Error> {
668        match self.value {
669            Some(value) => Deserialize::deserialize(value),
670            None => Ok(()),
671        }
672    }
673
674    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
675    where
676        T: serde::de::DeserializeSeed<'de>,
677    {
678        match self.value {
679            Some(value) => seed.deserialize(value),
680            None => Err(serde::de::Error::invalid_type(
681                Unexpected::UnitVariant,
682                &"newtype variant",
683            )),
684        }
685    }
686
687    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
688    where
689        V: Visitor<'de>,
690    {
691        match self.value {
692            Some(Value::Array(v)) => {
693                if v.is_empty() {
694                    visitor.visit_unit()
695                } else {
696                    visit_array_ref(v, visitor)
697                }
698            }
699            Some(other) => Err(serde::de::Error::invalid_type(
700                other.unexpected(),
701                &"tuple variant",
702            )),
703            None => Err(serde::de::Error::invalid_type(
704                Unexpected::UnitVariant,
705                &"tuple variant",
706            )),
707        }
708    }
709
710    fn struct_variant<V>(
711        self,
712        _fields: &'static [&'static str],
713        visitor: V,
714    ) -> Result<V::Value, Self::Error>
715    where
716        V: Visitor<'de>,
717    {
718        match self.value {
719            Some(Value::Map(m)) => todo!("map deserialization"),
720            Some(other) => Err(serde::de::Error::invalid_type(
721                other.unexpected(),
722                &"struct variant",
723            )),
724            None => Err(serde::de::Error::invalid_type(
725                Unexpected::UnitVariant,
726                &"struct variant",
727            )),
728        }
729    }
730}
731
732struct SeqRefDeserializer<'de> {
733    iter: std::slice::Iter<'de, Value>,
734}
735
736impl<'de> SeqRefDeserializer<'de> {
737    fn new(slice: &'de [Value]) -> Self {
738        SeqRefDeserializer { iter: slice.iter() }
739    }
740}
741
742impl<'de> SeqAccess<'de> for SeqRefDeserializer<'de> {
743    type Error = Error;
744
745    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
746    where
747        T: serde::de::DeserializeSeed<'de>,
748    {
749        match self.iter.next() {
750            Some(value) => seed.deserialize(value).map(Some),
751            None => Ok(None),
752        }
753    }
754
755    fn size_hint(&self) -> Option<usize> {
756        size_hint_helper(&self.iter)
757    }
758}
759
760impl<'de> IntoDeserializer<'de, Error> for Value {
761    type Deserializer = Self;
762
763    fn into_deserializer(self) -> Self::Deserializer {
764        self
765    }
766}
767
768impl<'de> IntoDeserializer<'de, Error> for &'de Value {
769    type Deserializer = Self;
770
771    fn into_deserializer(self) -> Self::Deserializer {
772        self
773    }
774}
775
776impl Value {
777    #[cold]
778    fn invalid_type<E>(&self, exp: &dyn Expected) -> E
779    where
780        E: serde::de::Error,
781    {
782        serde::de::Error::invalid_type(self.unexpected(), exp)
783    }
784
785    #[cold]
786    fn unexpected(&self) -> Unexpected {
787        match self {
788            Value::Null => Unexpected::Unit,
789            Value::Bool(b) => Unexpected::Bool(*b),
790            Value::Number(n) => n.unexpected(),
791            Value::String(s) => Unexpected::Str(s),
792            Value::ByteArray(b) => Unexpected::Bytes(b),
793            Value::Array(_) => Unexpected::Seq,
794            Value::Map(_) => Unexpected::Map,
795        }
796    }
797}
798
/// Returns the iterator's remaining length only when its `size_hint` is exact
/// (lower bound equals a known upper bound); otherwise `None`.
fn size_hint_helper<I>(iter: &I) -> Option<usize>
where
    I: Iterator,
{
    let (lower, upper) = iter.size_hint();
    upper.filter(|&exact| exact == lower)
}