serde_pyobject/
de.rs

1use crate::error::{Error, Result};
2use pyo3::{types::*, Bound};
3use serde::{
4    de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
5    forward_to_deserialize_any, Deserialize, Deserializer,
6};
7
8/// Deserialize a Python object into Rust type `T: Deserialize`.
9///
10/// # Examples
11///
12/// ## primitive
13///
14/// ```
15/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
16/// use serde_pyobject::from_pyobject;
17///
18/// Python::attach(|py| {
19///     // integer
20///     let any: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
21///     let i: i32 = from_pyobject(any.into_bound(py)).unwrap();
22///     assert_eq!(i, 42);
23///
24///     // float
25///     let any: Py<PyAny> = 0.1_f32.into_bound_py_any(py).unwrap().unbind();
26///     let x: f32 = from_pyobject(any.into_bound(py)).unwrap();
27///     assert_eq!(x, 0.1);
28///
29///     // bool
30///     let any: Py<PyAny> = true.into_bound_py_any(py).unwrap().unbind();
31///     let x: bool = from_pyobject(any.into_bound(py)).unwrap();
32///     assert_eq!(x, true);
33/// });
34/// ```
35///
36/// ## option
37///
38/// ```
39/// use pyo3::{Python, Py, PyAny, IntoPyObjectExt};
40/// use serde_pyobject::from_pyobject;
41///
42/// Python::attach(|py| {
43///     let none = py.None();
44///     let option: Option<i32> = from_pyobject(none.into_bound(py)).unwrap();
45///     assert_eq!(option, None);
46///
47///     let py_int: Py<PyAny> = 42_i32.into_bound_py_any(py).unwrap().unbind();
48///     let i: Option<i32> = from_pyobject(py_int.into_bound(py)).unwrap();
49///     assert_eq!(i, Some(42));
50/// })
51/// ```
52///
53/// ## unit
54///
55/// ```
56/// use pyo3::{Python, types::PyTuple};
57/// use serde_pyobject::from_pyobject;
58///
59/// Python::attach(|py| {
60///     let py_unit = PyTuple::empty(py);
61///     let unit: () = from_pyobject(py_unit).unwrap();
62///     assert_eq!(unit, ());
63/// })
64/// ```
65///
66/// ## unit_struct
67///
68/// ```
69/// use serde::Deserialize;
70/// use pyo3::{Python, types::PyTuple};
71/// use serde_pyobject::from_pyobject;
72///
73/// #[derive(Debug, PartialEq, Deserialize)]
74/// struct UnitStruct;
75///
76/// Python::attach(|py| {
77///     let py_unit = PyTuple::empty(py);
78///     let unit: UnitStruct = from_pyobject(py_unit).unwrap();
79///     assert_eq!(unit, UnitStruct);
80/// })
81/// ```
82///
83/// ## unit variant
84///
85/// ```
86/// use serde::Deserialize;
87/// use pyo3::{Python, types::PyString};
88/// use serde_pyobject::from_pyobject;
89///
90/// #[derive(Debug, PartialEq, Deserialize)]
91/// enum E {
92///     A,
93///     B,
94/// }
95///
96/// Python::attach(|py| {
97///     let any = PyString::new(py, "A");
98///     let out: E = from_pyobject(any).unwrap();
99///     assert_eq!(out, E::A);
100/// })
101/// ```
102///
103/// ## newtype struct
104///
105/// ```
106/// use serde::Deserialize;
107/// use pyo3::{Python, Bound, PyAny, IntoPyObject};
108/// use serde_pyobject::from_pyobject;
109///
110/// #[derive(Debug, PartialEq, Deserialize)]
111/// struct NewTypeStruct(u8);
112///
113/// Python::attach(|py| {
114///     let any: Bound<PyAny> = 1_u32.into_pyobject(py).unwrap().into_any();
115///     let obj: NewTypeStruct = from_pyobject(any).unwrap();
116///     assert_eq!(obj, NewTypeStruct(1));
117/// });
118/// ```
119///
120/// ## newtype variant
121///
122/// ```
123/// use serde::Deserialize;
124/// use pyo3::Python;
125/// use serde_pyobject::{from_pyobject, pydict};
126///
127/// #[derive(Debug, PartialEq, Deserialize)]
128/// enum NewTypeVariant {
129///     N(u8),
130/// }
131///
132/// Python::attach(|py| {
133///     let dict = pydict! { py, "N" => 41 }.unwrap();
134///     let obj: NewTypeVariant = from_pyobject(dict).unwrap();
135///     assert_eq!(obj, NewTypeVariant::N(41));
136/// });
137/// ```
138///
139/// ## seq
140///
141/// ```
142/// use pyo3::Python;
143/// use serde_pyobject::{from_pyobject, pylist};
144///
145/// Python::attach(|py| {
146///     let list = pylist![py; 1, 2, 3].unwrap();
147///     let seq: Vec<i32> = from_pyobject(list).unwrap();
148///     assert_eq!(seq, vec![1, 2, 3]);
149/// });
150/// ```
151///
152/// ## tuple
153///
154/// ```
155/// use pyo3::{Python, types::PyTuple};
156/// use serde_pyobject::from_pyobject;
157///
158/// Python::attach(|py| {
159///     let tuple = PyTuple::new(py, &[1, 2, 3]).unwrap();
160///     let tuple: (i32, i32, i32) = from_pyobject(tuple).unwrap();
161///     assert_eq!(tuple, (1, 2, 3));
162/// });
163/// ```
164///
165/// ## tuple struct
166///
167/// ```
168/// use serde::Deserialize;
169/// use pyo3::{Python, IntoPyObject, types::PyTuple};
170/// use serde_pyobject::from_pyobject;
171///
172/// #[derive(Debug, PartialEq, Deserialize)]
173/// struct T(u8, String);
174///
175/// Python::attach(|py| {
176///     let tuple = PyTuple::new(py, &[1_u32.into_pyobject(py).unwrap().into_any(), "test".into_pyobject(py).unwrap().into_any()]).unwrap();
177///     let obj: T = from_pyobject(tuple).unwrap();
178///     assert_eq!(obj, T(1, "test".to_string()));
179/// });
180/// ```
181///
182/// ## tuple variant
183///
184/// ```
185/// use serde::Deserialize;
186/// use pyo3::Python;
187/// use serde_pyobject::{from_pyobject, pydict};
188///
189/// #[derive(Debug, PartialEq, Deserialize)]
190/// enum TupleVariant {
191///     T(u8, u8),
192/// }
193///
194/// Python::attach(|py| {
195///     let dict = pydict! { py, "T" => (1, 2) }.unwrap();
196///     let obj: TupleVariant = from_pyobject(dict).unwrap();
197///     assert_eq!(obj, TupleVariant::T(1, 2));
198/// });
199/// ```
200///
201/// ## map
202///
203/// ```
204/// use pyo3::Python;
205/// use serde_pyobject::{from_pyobject, pydict};
206/// use std::collections::BTreeMap;
207///
208/// Python::attach(|py| {
209///     let dict = pydict! { py,
210///         "a" => "hom",
211///         "b" => "test"
212///     }
213///     .unwrap();
214///     let map: BTreeMap<String, String> = from_pyobject(dict).unwrap();
215///     assert_eq!(map.get("a"), Some(&"hom".to_string()));
216///     assert_eq!(map.get("b"), Some(&"test".to_string()));
217/// });
218/// ```
219///
220/// ## struct
221///
222/// ```
223/// use serde::Deserialize;
224/// use pyo3::Python;
225/// use serde_pyobject::{from_pyobject, pydict};
226///
227/// #[derive(Debug, PartialEq, Deserialize)]
228/// struct A {
229///     a: i32,
230///     b: String,
231/// }
232///
233/// Python::attach(|py| {
234///     let dict = pydict! {
235///         "a" => 1,
236///         "b" => "test"
237///     }
238///     .unwrap();
239///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
240///     assert_eq!(
241///         a,
242///         A {
243///             a: 1,
244///             b: "test".to_string()
245///         }
246///     );
247/// });
248///
249/// Python::attach(|py| {
250///     let dict = pydict! {
251///         "A" => pydict! {
252///             "a" => 1,
253///             "b" => "test"
254///         }
255///         .unwrap()
256///     }
257///     .unwrap();
258///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
259///     assert_eq!(
260///         a,
261///         A {
262///             a: 1,
263///             b: "test".to_string()
264///         }
265///     );
266/// });
267/// ```
268///
269/// ## struct variant
270///
271/// ```
272/// use serde::Deserialize;
273/// use pyo3::Python;
274/// use serde_pyobject::{from_pyobject, pydict};
275///
276/// #[derive(Debug, PartialEq, Deserialize)]
277/// enum StructVariant {
278///     S { r: u8, g: u8, b: u8 },
279/// }
280///
281/// Python::attach(|py| {
282///     let dict = pydict! {
283///         py,
284///         "S" => pydict! {
285///             "r" => 1,
286///             "g" => 2,
287///             "b" => 3
288///         }.unwrap()
289///     }
290///     .unwrap();
291///     let obj: StructVariant = from_pyobject(dict).unwrap();
292///     assert_eq!(obj, StructVariant::S { r: 1, g: 2, b: 3 });
293/// });
294/// ```
295pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
296    let any = any.into_any();
297    T::deserialize(PyAnyDeserializer(any))
298}
299
300struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
301
302impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
303    type Error = Error;
304
305    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
306    where
307        V: Visitor<'de>,
308    {
309        if self.0.is_instance_of::<PyDict>() {
310            return visitor.visit_map(MapDeserializer::new(self.0.downcast()?));
311        }
312        if self.0.is_instance_of::<PyList>() {
313            return visitor.visit_seq(SeqDeserializer::from_list(self.0.downcast()?));
314        }
315        if self.0.is_instance_of::<PyTuple>() {
316            return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.downcast()?));
317        }
318        if self.0.is_instance_of::<PyString>() {
319            return visitor.visit_str(&self.0.extract::<String>()?);
320        }
321        if self.0.is_instance_of::<PyBool>() {
322            // must be match before PyLong
323            return visitor.visit_bool(self.0.extract()?);
324        }
325        if self.0.is_instance_of::<PyInt>() {
326            return visitor.visit_i64(self.0.extract()?);
327        }
328        if self.0.is_instance_of::<PyFloat>() {
329            return visitor.visit_f64(self.0.extract()?);
330        }
331        if self.0.hasattr("__dict__")? {
332            return visitor.visit_map(MapDeserializer::new(
333                self.0.getattr("__dict__")?.downcast()?,
334            ));
335        }
336        if self.0.hasattr("__slots__")? {
337            // __slots__ and __dict__ are mutually exclusive, see
338            // https://docs.python.org/3/reference/datamodel.html#slots
339            return visitor.visit_map(MapDeserializer::from_slots(&self.0)?);
340        }
341        if self.0.is_none() {
342            return visitor.visit_none();
343        }
344
345        unreachable!("Unsupported type: {}", self.0.get_type());
346    }
347
348    fn deserialize_struct<V: de::Visitor<'de>>(
349        self,
350        name: &'static str,
351        _fields: &'static [&'static str],
352        visitor: V,
353    ) -> Result<V::Value> {
354        // Nested dict `{ "A": { "a": 1, "b": 2 } }` is deserialized as `A { a: 1, b: 2 }`
355        if self.0.is_instance_of::<PyDict>() {
356            let dict: &Bound<PyDict> = self.0.downcast()?;
357            if let Some(inner) = dict.get_item(name)? {
358                if let Ok(inner) = inner.downcast() {
359                    return visitor.visit_map(MapDeserializer::new(inner));
360                }
361            }
362        }
363        // Default to `any` case
364        self.deserialize_any(visitor)
365    }
366
367    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
368        self,
369        _name: &'static str,
370        visitor: V,
371    ) -> Result<V::Value> {
372        visitor.visit_seq(SeqDeserializer {
373            seq_reversed: vec![self.0],
374        })
375    }
376
377    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
378        if self.0.is_none() {
379            visitor.visit_none()
380        } else {
381            visitor.visit_some(self)
382        }
383    }
384
385    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
386        if self.0.is(PyTuple::empty(self.0.py())) {
387            visitor.visit_unit()
388        } else {
389            self.deserialize_any(visitor)
390        }
391    }
392
393    fn deserialize_unit_struct<V: de::Visitor<'de>>(
394        self,
395        _name: &'static str,
396        visitor: V,
397    ) -> Result<V::Value> {
398        if self.0.is(PyTuple::empty(self.0.py())) {
399            visitor.visit_unit()
400        } else {
401            self.deserialize_any(visitor)
402        }
403    }
404
405    fn deserialize_enum<V: de::Visitor<'de>>(
406        self,
407        _name: &'static str,
408        _variants: &'static [&'static str],
409        visitor: V,
410    ) -> Result<V::Value> {
411        if self.0.is_instance_of::<PyString>() {
412            let variant: String = self.0.extract()?;
413            let py = self.0.py();
414            let none = py.None().into_bound(py);
415            return visitor.visit_enum(EnumDeserializer {
416                variant: &variant,
417                inner: none,
418            });
419        }
420        if self.0.is_instance_of::<PyDict>() {
421            let dict: &Bound<PyDict> = self.0.downcast()?;
422            if dict.len() == 1 {
423                let key = dict.keys().get_item(0).unwrap();
424                let value = dict.values().get_item(0).unwrap();
425                if key.is_instance_of::<PyString>() {
426                    let variant: String = key.extract()?;
427                    return visitor.visit_enum(EnumDeserializer {
428                        variant: &variant,
429                        inner: value,
430                    });
431                }
432            }
433        }
434        self.deserialize_any(visitor)
435    }
436
437    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
438        self,
439        name: &'static str,
440        _len: usize,
441        visitor: V,
442    ) -> Result<V::Value> {
443        if self.0.is_instance_of::<PyDict>() {
444            let dict: &Bound<PyDict> = self.0.downcast()?;
445            if let Some(value) = dict.get_item(name)? {
446                if value.is_instance_of::<PyTuple>() {
447                    let tuple: &Bound<PyTuple> = value.downcast()?;
448                    return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
449                }
450            }
451        }
452        self.deserialize_any(visitor)
453    }
454
455    forward_to_deserialize_any! {
456        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
457        bytes byte_buf seq tuple
458        map identifier ignored_any
459    }
460}
461
462struct SeqDeserializer<'py> {
463    seq_reversed: Vec<Bound<'py, PyAny>>,
464}
465
466impl<'py> SeqDeserializer<'py> {
467    fn from_list(list: &Bound<'py, PyList>) -> Self {
468        let mut seq_reversed = Vec::new();
469        for item in list.iter().rev() {
470            seq_reversed.push(item);
471        }
472        Self { seq_reversed }
473    }
474
475    fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
476        let mut seq_reversed = Vec::new();
477        for item in tuple.iter().rev() {
478            seq_reversed.push(item);
479        }
480        Self { seq_reversed }
481    }
482}
483
484impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
485    type Error = Error;
486    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
487    where
488        T: de::DeserializeSeed<'de>,
489    {
490        self.seq_reversed.pop().map_or(Ok(None), |value| {
491            let value = seed.deserialize(PyAnyDeserializer(value))?;
492            Ok(Some(value))
493        })
494    }
495}
496
497struct MapDeserializer<'py> {
498    keys: Vec<Bound<'py, PyAny>>,
499    values: Vec<Bound<'py, PyAny>>,
500}
501
502impl<'py> MapDeserializer<'py> {
503    fn new(dict: &Bound<'py, PyDict>) -> Self {
504        let mut keys = Vec::new();
505        let mut values = Vec::new();
506        for (key, value) in dict.iter() {
507            keys.push(key);
508            values.push(value);
509        }
510        Self { keys, values }
511    }
512
513    fn from_slots(obj: &Bound<'py, PyAny>) -> Result<Self> {
514        let mut keys = vec![];
515        let mut values = vec![];
516        for key in obj.getattr("__slots__")?.try_iter()? {
517            let key = key?;
518            keys.push(key.clone());
519            let v = obj.getattr(key.str()?)?;
520            values.push(v);
521        }
522        Ok(Self { keys, values })
523    }
524}
525
526impl<'de> MapAccess<'de> for MapDeserializer<'_> {
527    type Error = Error;
528
529    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
530    where
531        K: de::DeserializeSeed<'de>,
532    {
533        if let Some(key) = self.keys.pop() {
534            let key = seed.deserialize(PyAnyDeserializer(key))?;
535            Ok(Some(key))
536        } else {
537            Ok(None)
538        }
539    }
540
541    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
542    where
543        V: de::DeserializeSeed<'de>,
544    {
545        if let Some(value) = self.values.pop() {
546            let value = seed.deserialize(PyAnyDeserializer(value))?;
547            Ok(value)
548        } else {
549            unreachable!()
550        }
551    }
552}
553
554// this lifetime is technically no longer 'py
555struct EnumDeserializer<'py> {
556    variant: &'py str,
557    inner: Bound<'py, PyAny>,
558}
559
560impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
561    type Error = Error;
562    type Variant = Self;
563
564    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
565    where
566        V: de::DeserializeSeed<'de>,
567    {
568        Ok((
569            seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
570            self,
571        ))
572    }
573}
574
575impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
576    type Error = Error;
577
578    fn unit_variant(self) -> Result<()> {
579        Ok(())
580    }
581
582    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
583    where
584        T: de::DeserializeSeed<'de>,
585    {
586        seed.deserialize(PyAnyDeserializer(self.inner))
587    }
588
589    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
590    where
591        V: Visitor<'de>,
592    {
593        PyAnyDeserializer(self.inner).deserialize_seq(visitor)
594    }
595
596    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
597    where
598        V: Visitor<'de>,
599    {
600        PyAnyDeserializer(self.inner).deserialize_map(visitor)
601    }
602}