serde_redis/decode.rs

1use redis::Value;
2use serde::{self, de};
3use std::borrow::Cow;
4use std::fmt::{self, Display};
5use std::iter::Peekable;
6use std::{error, num, str, string, vec};
7
8use crate::cow_iter::CowIter;
9
/// Error that can be produced during deserialization.
///
/// Covers serde-level failures (unknown/missing/duplicate fields, unexpected
/// variants, custom messages) as well as failures turning Redis payloads into
/// Rust values (UTF-8 decoding and number parsing).
#[derive(Debug)]
pub enum Error {
    /// Free-form message, as produced by `de::Error::custom`.
    Custom(String),
    /// No more Redis values were available to read.
    EndOfStream,
    /// Enum variant name outside the expected list.
    UnknownVariant(String, &'static [&'static str]),
    /// Struct field name outside the expected list.
    UnknownField(String, &'static [&'static str]),
    /// A required struct field was absent.
    MissingField(&'static str),
    /// A struct field appeared more than once.
    DuplicateField(&'static str),
    /// The requested kind of deserialization is not supported.
    DeserializeNotSupported,
    /// A Redis value had a different shape than the target type expects.
    WrongValue(String),
    /// Borrowed bytes were not valid UTF-8.
    StrFromUtf8(str::Utf8Error),
    /// Owned bytes were not valid UTF-8.
    StringFromUtf8(string::FromUtf8Error),
    /// Text could not be parsed as an integer.
    ParseInt(num::ParseIntError),
    /// Text could not be parsed as a floating-point number.
    ParseFloat(num::ParseFloatError),
}

impl Error {
    /// Builds an [`Error::WrongValue`] from any message convertible into a `String`.
    pub fn wrong_value<S: Into<String>>(msg: S) -> Error {
        let text: String = msg.into();
        Self::WrongValue(text)
    }
}

/// Shorthand for results whose error type is this module's [`Error`].
pub type Result<T> = ::std::result::Result<T, Error>;
37
38impl error::Error for Error {
39    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
40        match *self {
41            Error::StrFromUtf8(ref err) => Some(err),
42            Error::StringFromUtf8(ref err) => Some(err),
43            Error::ParseInt(ref err) => Some(err),
44            Error::ParseFloat(ref err) => Some(err),
45            _ => None,
46        }
47    }
48}
49
50impl fmt::Display for Error {
51    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52        match *self {
53            Error::Custom(ref reason) => write!(f, "CustomError({})", reason),
54            Error::EndOfStream => write!(f, "Reached end of stream"),
55            Error::UnknownVariant(ref variant, ref expected) => write!(
56                f,
57                "unexpected variant \"{}\"; expected {:?}",
58                variant, expected
59            ),
60            Error::UnknownField(ref field, ref expected) => {
61                write!(f, "unexpected field \"{}\"; expected {:?}", field, expected)
62            }
63            Error::MissingField(ref field) => write!(f, "missing field {:?}", field),
64            Error::DuplicateField(ref field) => write!(f, "duplicate field {:?}", field),
65            Error::DeserializeNotSupported => write!(f, "Deserialization option not supported"),
66            Error::WrongValue(ref value_type) => write!(f, "Got unexpected value: {}", value_type),
67            Error::StrFromUtf8(ref e) => write!(f, "{}", e),
68            Error::StringFromUtf8(ref e) => write!(f, "{}", e),
69            Error::ParseInt(ref e) => write!(f, "{}", e),
70            Error::ParseFloat(ref e) => write!(f, "{}", e),
71        }
72    }
73}
74
75impl de::Error for Error {
76    /// Raised when there is general error when deserializing a type.
77    fn custom<T: Display>(msg: T) -> Self {
78        Error::Custom(msg.to_string())
79    }
80
81    /// Raised when a `Deserialize` enum type received an unexpected variant.
82    fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
83        Error::UnknownVariant(variant.to_owned(), expected)
84    }
85
86    fn unknown_field(field: &str, expected: &'static [&'static str]) -> Error {
87        Error::UnknownField(field.to_owned(), expected)
88    }
89
90    fn missing_field(field: &'static str) -> Error {
91        Error::MissingField(field)
92    }
93
94    fn duplicate_field(field: &'static str) -> Error {
95        Error::DuplicateField(field)
96    }
97}
98
99impl From<str::Utf8Error> for Error {
100    fn from(err: str::Utf8Error) -> Error {
101        Error::StrFromUtf8(err)
102    }
103}
104
105impl From<string::FromUtf8Error> for Error {
106    fn from(err: string::FromUtf8Error) -> Error {
107        Error::StringFromUtf8(err)
108    }
109}
110
111impl From<num::ParseIntError> for Error {
112    fn from(err: num::ParseIntError) -> Error {
113        Error::ParseInt(err)
114    }
115}
116
117impl From<num::ParseFloatError> for Error {
118    fn from(err: num::ParseFloatError) -> Error {
119        Error::ParseFloat(err)
120    }
121}
122
123/// deserializes Redis `Value`s
124///
125/// Deserializes a sequence of redis values. In the case of a Bulk value (eg, a
126/// nested list), another deserializer is created for that sequence. The limit
127/// to nested sequences is proportional to the maximum stack depth in current
128/// machine.
129///
130/// If creating a Deserializer manually (ie not using `from_redis_value()`), the redis values must
131/// first be placed in a Vec.
132#[derive(Debug)]
133pub struct Deserializer<'a> {
134    values: Peekable<vec::IntoIter<Cow<'a, Value>>>,
135}
136
137pub trait AsValueVec<'a> {
138    fn as_value_vec(self) -> Vec<Cow<'a, Value>>;
139}
140
141impl<'a> AsValueVec<'a> for &'a Value {
142    #[inline]
143    fn as_value_vec(self) -> Vec<Cow<'a, Value>> {
144        vec![Cow::Borrowed(self)]
145    }
146}
147
148impl<'a> AsValueVec<'a> for Cow<'a, Value> {
149    #[inline]
150    fn as_value_vec(self) -> Vec<Cow<'a, Value>> {
151        vec![self]
152    }
153}
154
155impl AsValueVec<'static> for Value {
156    #[inline]
157    fn as_value_vec(self) -> Vec<Cow<'static, Value>> {
158        vec![Cow::Owned(self)]
159    }
160}
161
162impl<'a> AsValueVec<'a> for Vec<Cow<'a, Value>> {
163    #[inline]
164    fn as_value_vec(self) -> Vec<Cow<'a, Value>> {
165        self
166    }
167}
168
169impl<'a> Deserializer<'a> {
170    pub fn new<V>(values: V) -> Self
171    where
172        V: AsValueVec<'a>,
173    {
174        Deserializer {
175            values: values.as_value_vec().into_iter().peekable(),
176        }
177    }
178
179    /// Returns a reference to the next value
180    #[inline]
181    pub fn peek(&mut self) -> Option<&Value> {
182        let val = self.values.peek()?;
183
184        Some(val)
185    }
186
187    /// Return the next value
188    #[inline]
189    pub fn next(&mut self) -> Result<Cow<'a, Value>> {
190        match self.values.next() {
191            Some(value) => Ok(value),
192            None => Err(Error::EndOfStream),
193        }
194    }
195
196    pub fn next_bulk(&mut self) -> Result<Cow<'a, Vec<Value>>> {
197        match self.next()? {
198            Cow::Owned(Value::Bulk(values)) => Ok(Cow::Owned(values)),
199            Cow::Borrowed(Value::Bulk(values)) => Ok(Cow::Borrowed(values)),
200            v @ _ => Err(Error::wrong_value(format!("expected bulk but got {:?}", v))),
201        }
202    }
203
204    pub fn next_bytes(&mut self) -> Result<Cow<'a, Vec<u8>>> {
205        match self.next()? {
206            Cow::Owned(Value::Data(bytes)) => Ok(Cow::Owned(bytes)),
207            Cow::Borrowed(Value::Data(bytes)) => Ok(Cow::Borrowed(bytes)),
208            v => {
209                let msg = format!("Expected bytes, but got {:?}", v);
210                return Err(Error::wrong_value(msg));
211            }
212        }
213    }
214
215    pub fn read_string(&mut self) -> Result<Cow<'a, str>> {
216        let redis_value = self.next()?;
217        Ok(match redis_value {
218            Cow::Owned(Value::Data(bytes)) => Cow::Owned(String::from_utf8(bytes)?),
219            Cow::Borrowed(Value::Data(bytes)) => Cow::Borrowed(str::from_utf8(bytes)?),
220            _ => {
221                let msg = format!("Expected Data, got {:?}", &redis_value);
222                return Err(Error::wrong_value(msg));
223            }
224        })
225    }
226}
227
/// Generates a `deserialize_*` method for a numeric type.
///
/// `Data` payloads are decoded as UTF-8 and parsed with `str::parse`, which
/// already rejects out-of-range text. `Int` payloads (an `i64`) are checked
/// against the target type's bounds before the `as` cast, so e.g. `Int(300)`
/// for a `u8` or `Int(-1)` for a `u64` yields `Error::WrongValue` instead of
/// silently truncating or wrapping. Float targets accept every `i64`; the cast
/// only rounds.
macro_rules! impl_num {
    ($ty:ty, $deserialize_method:ident, $visitor_method:ident) => {
        #[inline]
        fn $deserialize_method<V>(mut self, visitor: V) -> Result<V::Value>
        where
            V: de::Visitor<'de>,
        {
            let redis_value = self.next()?;
            let value = match redis_value {
                Cow::Borrowed(Value::Data(bytes)) => {
                    let s = str::from_utf8(bytes)?;
                    s.parse::<$ty>()?
                }
                Cow::Owned(Value::Data(bytes)) => {
                    let s = String::from_utf8(bytes)?;
                    s.parse::<$ty>()?
                }
                Cow::Borrowed(&Value::Int(i)) | Cow::Owned(Value::Int(i)) => {
                    // `i128` holds every `i64` and every integer type's bounds.
                    // Float bounds saturate to i128::MIN/MAX when cast, so this
                    // check never rejects a float target.
                    let wide = i as i128;
                    let (min, max) = (<$ty>::MIN as i128, <$ty>::MAX as i128);
                    if wide < min || wide > max {
                        return Err(Error::wrong_value(format!(
                            "Int {} out of range for {}",
                            i,
                            stringify!($ty)
                        )));
                    }
                    i as $ty
                }
                _ => {
                    let msg = format!("Expected Data or Int, got {:?}", &redis_value);
                    return Err(Error::wrong_value(msg));
                }
            };

            visitor.$visitor_method(value)
        }
    };
}
257
/// Generates one `deserialize_*` method per listed name. Each generated
/// method simply forwards to `deserialize_str`, reading the value as a string.
macro_rules! default_deserialize {
    ($($method:ident)*) => {
        $(
            #[inline]
            fn $method<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
                self.deserialize_str(visitor)
            }
        )*
    };
}
270
271impl<'a, 'de> serde::Deserializer<'de> for Deserializer<'a> {
272    type Error = Error;
273
274    #[inline]
275    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value>
276    where
277        V: de::Visitor<'de>,
278    {
279        let buf = self.next_bytes()?;
280        match buf {
281            Cow::Borrowed(buf) => visitor.visit_bytes(buf),
282            Cow::Owned(buf) => visitor.visit_byte_buf(buf),
283        }
284    }
285
286    #[inline]
287    fn deserialize_string<V>(mut self, visitor: V) -> Result<V::Value>
288    where
289        V: de::Visitor<'de>,
290    {
291        let s = self.read_string()?;
292        match s {
293            Cow::Borrowed(s) => visitor.visit_str(s),
294            Cow::Owned(s) => visitor.visit_string(s),
295        }
296    }
297
298    #[inline]
299    fn deserialize_str<V>(mut self, visitor: V) -> Result<V::Value>
300    where
301        V: de::Visitor<'de>,
302    {
303        let s = self.read_string()?;
304        match s {
305            Cow::Borrowed(s) => visitor.visit_str(s),
306            Cow::Owned(s) => visitor.visit_string(s),
307        }
308    }
309
310    impl_num!(u8, deserialize_u8, visit_u8);
311    impl_num!(u16, deserialize_u16, visit_u16);
312    impl_num!(u32, deserialize_u32, visit_u32);
313    impl_num!(u64, deserialize_u64, visit_u64);
314
315    impl_num!(i8, deserialize_i8, visit_i8);
316    impl_num!(i16, deserialize_i16, visit_i16);
317    impl_num!(i32, deserialize_i32, visit_i32);
318    impl_num!(i64, deserialize_i64, visit_i64);
319
320    impl_num!(f32, deserialize_f32, visit_f32);
321    impl_num!(f64, deserialize_f64, visit_f64);
322
323    default_deserialize!(
324        deserialize_char
325        deserialize_unit
326    );
327
328    #[inline]
329    fn deserialize_bool<V>(mut self, visitor: V) -> Result<V::Value>
330    where
331        V: de::Visitor<'de>,
332    {
333        let s = self.read_string()?;
334
335        let b = match s.as_ref() {
336            "1" | "true" | "True" => true,
337            "0" | "false" | "False" => false,
338            _ => {
339                return Err(Error::WrongValue(format!(
340                    "Expected 1/0/true/false/True/False, got {}",
341                    s
342                )))
343            }
344        };
345
346        visitor.visit_bool(b)
347    }
348
349    #[inline]
350    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
351    where
352        V: de::Visitor<'de>,
353    {
354        self.deserialize_byte_buf(visitor)
355    }
356
357    #[inline]
358    fn deserialize_byte_buf<V>(mut self, visitor: V) -> Result<V::Value>
359    where
360        V: de::Visitor<'de>,
361    {
362        let bytes = self.next_bytes()?;
363        match bytes {
364            Cow::Borrowed(bytes) => visitor.visit_bytes(bytes),
365            Cow::Owned(bytes) => visitor.visit_byte_buf(bytes),
366        }
367    }
368
369    #[inline]
370    fn deserialize_tuple_struct<V>(
371        self,
372        _name: &'static str,
373        len: usize,
374        visitor: V,
375    ) -> Result<V::Value>
376    where
377        V: de::Visitor<'de>,
378    {
379        self.deserialize_tuple(len, visitor)
380    }
381
382    #[inline]
383    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
384    where
385        V: de::Visitor<'de>,
386    {
387        self.deserialize_seq(visitor)
388    }
389
390    #[inline]
391    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value>
392    where
393        V: de::Visitor<'de>,
394    {
395        let values = self.next_bulk()?;
396        visitor.visit_seq(SeqVisitor {
397            iter: CowIter::new(values),
398        })
399    }
400
401    #[inline]
402    fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value>
403    where
404        V: de::Visitor<'de>,
405    {
406        let values = self.next_bulk()?;
407        visitor.visit_map(MapVisitor {
408            iter: CowIter::new(values),
409        })
410    }
411
412    #[inline]
413    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
414    where
415        V: de::Visitor<'de>,
416    {
417        self.deserialize_unit(visitor)
418    }
419
420    #[inline]
421    fn deserialize_struct<V>(
422        self,
423        _name: &'static str,
424        _fields: &'static [&'static str],
425        visitor: V,
426    ) -> Result<V::Value>
427    where
428        V: de::Visitor<'de>,
429    {
430        self.deserialize_map(visitor)
431    }
432
433    #[inline]
434    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
435    where
436        V: de::Visitor<'de>,
437    {
438        self.deserialize_any(visitor)
439    }
440
441    #[inline]
442    fn deserialize_enum<V>(
443        mut self,
444        _enum: &'static str,
445        _variants: &'static [&'static str],
446        visitor: V,
447    ) -> Result<V::Value>
448    where
449        V: de::Visitor<'de>,
450    {
451        visitor.visit_enum(EnumVisitor {
452            variant: self.next()?,
453            content: Cow::Owned(Value::Nil),
454        })
455    }
456
457    #[inline]
458    fn deserialize_option<V>(mut self, visitor: V) -> Result<V::Value>
459    where
460        V: de::Visitor<'de>,
461    {
462        let maybe = match self.peek() {
463            Some(v) => match *v {
464                Value::Data(_) => Some(()),
465                Value::Int(_) => Some(()),
466                Value::Nil => None,
467                _ => {
468                    let msg = format!("Expected Data, Int, or Nil");
469                    return Err(Error::wrong_value(msg));
470                }
471            },
472            None => None,
473        };
474
475        if maybe.is_some() {
476            visitor.visit_some(self)
477        } else {
478            visitor.visit_none()
479        }
480    }
481
482    #[inline]
483    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
484    where
485        V: de::Visitor<'de>,
486    {
487        visitor.visit_newtype_struct(self)
488    }
489
490    #[inline]
491    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
492    where
493        V: de::Visitor<'de>,
494    {
495        self.deserialize_str(visitor)
496    }
497}
498
499struct SeqVisitor<'a> {
500    iter: CowIter<'a>,
501}
502
503impl<'a, 'de> de::SeqAccess<'de> for SeqVisitor<'a> {
504    type Error = Error;
505
506    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
507    where
508        T: de::DeserializeSeed<'de>,
509    {
510        match self.iter.next() {
511            Some(v) => seed.deserialize(Deserializer::new(v)).map(Some),
512            None => Ok(None),
513        }
514    }
515
516    fn size_hint(&self) -> Option<usize> {
517        self.iter.size_hint().1
518    }
519}
520
521struct MapVisitor<'a> {
522    iter: CowIter<'a>,
523}
524
525impl<'a, 'de> serde::de::MapAccess<'de> for MapVisitor<'a> {
526    type Error = Error;
527
528    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
529    where
530        K: de::DeserializeSeed<'de>,
531    {
532        match self.iter.next() {
533            Some(v) => seed.deserialize(Deserializer::new(v)).map(Some),
534            None => Ok(None),
535        }
536    }
537
538    #[inline]
539    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
540    where
541        V: de::DeserializeSeed<'de>,
542    {
543        match self.iter.next() {
544            Some(v) => seed.deserialize(Deserializer::new(v)),
545            None => Err(Error::EndOfStream),
546        }
547    }
548}
549
550struct VariantVisitor<'a> {
551    value: Cow<'a, Value>,
552}
553
554impl<'a, 'de> serde::de::VariantAccess<'de> for VariantVisitor<'a> {
555    type Error = Error;
556
557    fn unit_variant(self) -> Result<()> {
558        Ok(())
559    }
560
561    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
562    where
563        T: de::DeserializeSeed<'de>,
564    {
565        seed.deserialize(Deserializer::new(self.value))
566    }
567
568    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
569    where
570        V: de::Visitor<'de>,
571    {
572        use serde::Deserializer;
573        let deserializer = self::Deserializer::new(self.value);
574        deserializer.deserialize_any(visitor)
575    }
576
577    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
578    where
579        V: de::Visitor<'de>,
580    {
581        use serde::Deserializer;
582        let deserializer = self::Deserializer::new(self.value);
583        deserializer.deserialize_any(visitor)
584    }
585}
586
587struct EnumVisitor<'a> {
588    variant: Cow<'a, Value>,
589    content: Cow<'a, Value>,
590}
591
592impl<'a, 'de> de::EnumAccess<'de> for EnumVisitor<'a> {
593    type Error = Error;
594    type Variant = VariantVisitor<'a>;
595
596    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
597    where
598        V: de::DeserializeSeed<'de>,
599    {
600        Ok((
601            seed.deserialize(Deserializer::new(self.variant))?,
602            VariantVisitor {
603                value: self.content,
604            },
605        ))
606    }
607}