serde_pyobject/
de.rs

1use crate::error::{Error, Result};
2use pyo3::{types::*, Bound};
3use serde::{
4    de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
5    forward_to_deserialize_any, Deserialize, Deserializer,
6};
7
8/// Deserialize a Python object into Rust type `T: Deserialize`.
9///
10/// # Examples
11///
12/// ## primitive
13///
14/// ```
15/// use pyo3::{Python, Py, PyAny, IntoPy};
16/// use serde_pyobject::from_pyobject;
17///
18/// Python::with_gil(|py| {
19///     // integer
20///     let any: Py<PyAny> = 42.into_py(py);
21///     let i: i32 = from_pyobject(any.into_bound(py)).unwrap();
22///     assert_eq!(i, 42);
23///
24///     // float
25///     let any: Py<PyAny> = (0.1).into_py(py);
26///     let x: f32 = from_pyobject(any.into_bound(py)).unwrap();
27///     assert_eq!(x, 0.1);
28///
29///     // bool
30///     let any: Py<PyAny> = true.into_py(py);
31///     let x: bool = from_pyobject(any.into_bound(py)).unwrap();
32///     assert_eq!(x, true);
33/// });
34/// ```
35///
36/// ## option
37///
38/// ```
39/// use pyo3::{Python, Py, PyAny, IntoPy};
40/// use serde_pyobject::from_pyobject;
41///
42/// Python::with_gil(|py| {
43///     let none = py.None();
44///     let option: Option<i32> = from_pyobject(none.into_bound(py)).unwrap();
45///     assert_eq!(option, None);
46///
47///     let py_int: Py<PyAny> = 42.into_py(py);
48///     let i: Option<i32> = from_pyobject(py_int.into_bound(py)).unwrap();
49///     assert_eq!(i, Some(42));
50/// })
51/// ```
52///
53/// ## unit
54///
55/// ```
56/// use pyo3::{Python, types::PyTuple};
57/// use serde_pyobject::from_pyobject;
58///
59/// Python::with_gil(|py| {
60///     let py_unit = PyTuple::empty(py);
61///     let unit: () = from_pyobject(py_unit).unwrap();
62///     assert_eq!(unit, ());
63/// })
64/// ```
65///
66/// ## unit_struct
67///
68/// ```
69/// use serde::Deserialize;
70/// use pyo3::{Python, types::PyTuple};
71/// use serde_pyobject::from_pyobject;
72///
73/// #[derive(Debug, PartialEq, Deserialize)]
74/// struct UnitStruct;
75///
76/// Python::with_gil(|py| {
77///     let py_unit = PyTuple::empty(py);
78///     let unit: UnitStruct = from_pyobject(py_unit).unwrap();
79///     assert_eq!(unit, UnitStruct);
80/// })
81/// ```
82///
83/// ## unit variant
84///
85/// ```
86/// use serde::Deserialize;
87/// use pyo3::{Python, types::PyString};
88/// use serde_pyobject::from_pyobject;
89///
90/// #[derive(Debug, PartialEq, Deserialize)]
91/// enum E {
92///     A,
93///     B,
94/// }
95///
96/// Python::with_gil(|py| {
97///     let any = PyString::new_bound(py, "A");
98///     let out: E = from_pyobject(any).unwrap();
99///     assert_eq!(out, E::A);
100/// })
101/// ```
102///
103/// ## newtype struct
104///
105/// ```
106/// use serde::Deserialize;
107/// use pyo3::{Python, Bound, PyAny, IntoPy};
108/// use serde_pyobject::from_pyobject;
109///
110/// #[derive(Debug, PartialEq, Deserialize)]
111/// struct NewTypeStruct(u8);
112///
113/// Python::with_gil(|py| {
114///     let any: Bound<PyAny> = 1_u32.into_py(py).into_bound(py);
115///     let obj: NewTypeStruct = from_pyobject(any).unwrap();
116///     assert_eq!(obj, NewTypeStruct(1));
117/// });
118/// ```
119///
120/// ## newtype variant
121///
122/// ```
123/// use serde::Deserialize;
124/// use pyo3::Python;
125/// use serde_pyobject::{from_pyobject, pydict};
126///
127/// #[derive(Debug, PartialEq, Deserialize)]
128/// enum NewTypeVariant {
129///     N(u8),
130/// }
131///
132/// Python::with_gil(|py| {
133///     let dict = pydict! { py, "N" => 41 }.unwrap();
134///     let obj: NewTypeVariant = from_pyobject(dict).unwrap();
135///     assert_eq!(obj, NewTypeVariant::N(41));
136/// });
137/// ```
138///
139/// ## seq
140///
141/// ```
142/// use pyo3::Python;
143/// use serde_pyobject::{from_pyobject, pylist};
144///
145/// Python::with_gil(|py| {
146///     let list = pylist![py; 1, 2, 3].unwrap();
147///     let seq: Vec<i32> = from_pyobject(list).unwrap();
148///     assert_eq!(seq, vec![1, 2, 3]);
149/// });
150/// ```
151///
152/// ## tuple
153///
154/// ```
155/// use pyo3::{Python, types::PyTuple};
156/// use serde_pyobject::from_pyobject;
157///
158/// Python::with_gil(|py| {
159///     let tuple = PyTuple::new_bound(py, &[1, 2, 3]);
160///     let tuple: (i32, i32, i32) = from_pyobject(tuple).unwrap();
161///     assert_eq!(tuple, (1, 2, 3));
162/// });
163/// ```
164///
165/// ## tuple struct
166///
167/// ```
168/// use serde::Deserialize;
169/// use pyo3::{Python, IntoPy, types::PyTuple};
170/// use serde_pyobject::from_pyobject;
171///
172/// #[derive(Debug, PartialEq, Deserialize)]
173/// struct T(u8, String);
174///
175/// Python::with_gil(|py| {
176///     let tuple = PyTuple::new_bound(py, &[1_u32.into_py(py), "test".into_py(py)]);
177///     let obj: T = from_pyobject(tuple).unwrap();
178///     assert_eq!(obj, T(1, "test".to_string()));
179/// });
180/// ```
181///
182/// ## tuple variant
183///
184/// ```
185/// use serde::Deserialize;
186/// use pyo3::Python;
187/// use serde_pyobject::{from_pyobject, pydict};
188///
189/// #[derive(Debug, PartialEq, Deserialize)]
190/// enum TupleVariant {
191///     T(u8, u8),
192/// }
193///
194/// Python::with_gil(|py| {
195///     let dict = pydict! { py, "T" => (1, 2) }.unwrap();
196///     let obj: TupleVariant = from_pyobject(dict).unwrap();
197///     assert_eq!(obj, TupleVariant::T(1, 2));
198/// });
199/// ```
200///
201/// ## map
202///
203/// ```
204/// use pyo3::Python;
205/// use serde_pyobject::{from_pyobject, pydict};
206/// use std::collections::BTreeMap;
207///
208/// Python::with_gil(|py| {
209///     let dict = pydict! { py,
210///         "a" => "hom",
211///         "b" => "test"
212///     }
213///     .unwrap();
214///     let map: BTreeMap<String, String> = from_pyobject(dict).unwrap();
215///     assert_eq!(map.get("a"), Some(&"hom".to_string()));
216///     assert_eq!(map.get("b"), Some(&"test".to_string()));
217/// });
218/// ```
219///
220/// ## struct
221///
222/// ```
223/// use serde::Deserialize;
224/// use pyo3::Python;
225/// use serde_pyobject::{from_pyobject, pydict};
226///
227/// #[derive(Debug, PartialEq, Deserialize)]
228/// struct A {
229///     a: i32,
230///     b: String,
231/// }
232///
233/// Python::with_gil(|py| {
234///     let dict = pydict! {
235///         "a" => 1,
236///         "b" => "test"
237///     }
238///     .unwrap();
239///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
240///     assert_eq!(
241///         a,
242///         A {
243///             a: 1,
244///             b: "test".to_string()
245///         }
246///     );
247/// });
248///
249/// Python::with_gil(|py| {
250///     let dict = pydict! {
251///         "A" => pydict! {
252///             "a" => 1,
253///             "b" => "test"
254///         }
255///         .unwrap()
256///     }
257///     .unwrap();
258///     let a: A = from_pyobject(dict.into_bound(py)).unwrap();
259///     assert_eq!(
260///         a,
261///         A {
262///             a: 1,
263///             b: "test".to_string()
264///         }
265///     );
266/// });
267/// ```
268///
269/// ## struct variant
270///
271/// ```
272/// use serde::Deserialize;
273/// use pyo3::Python;
274/// use serde_pyobject::{from_pyobject, pydict};
275///
276/// #[derive(Debug, PartialEq, Deserialize)]
277/// enum StructVariant {
278///     S { r: u8, g: u8, b: u8 },
279/// }
280///
281/// Python::with_gil(|py| {
282///     let dict = pydict! {
283///         py,
284///         "S" => pydict! {
285///             "r" => 1,
286///             "g" => 2,
287///             "b" => 3
288///         }.unwrap()
289///     }
290///     .unwrap();
291///     let obj: StructVariant = from_pyobject(dict).unwrap();
292///     assert_eq!(obj, StructVariant::S { r: 1, g: 2, b: 3 });
293/// });
294/// ```
295pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
296    let any = any.into_any();
297    T::deserialize(PyAnyDeserializer(any))
298}
299
300struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
301
302impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
303    type Error = Error;
304
305    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
306    where
307        V: Visitor<'de>,
308    {
309        if self.0.is_instance_of::<PyDict>() {
310            return visitor.visit_map(MapDeserializer::new(self.0.downcast()?));
311        }
312        if self.0.is_instance_of::<PyList>() {
313            return visitor.visit_seq(SeqDeserializer::from_list(self.0.downcast()?));
314        }
315        if self.0.is_instance_of::<PyTuple>() {
316            return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.downcast()?));
317        }
318        if self.0.is_instance_of::<PyString>() {
319            return visitor.visit_str(&self.0.extract::<String>()?);
320        }
321        if self.0.is_instance_of::<PyBool>() {
322            // must be match before PyLong
323            return visitor.visit_bool(self.0.extract()?);
324        }
325        if self.0.is_instance_of::<PyInt>() {
326            return visitor.visit_i64(self.0.extract()?);
327        }
328        if self.0.is_instance_of::<PyFloat>() {
329            return visitor.visit_f64(self.0.extract()?);
330        }
331        if self.0.is_none() {
332            return visitor.visit_none();
333        }
334        unreachable!("Unsupported type: {}", self.0.get_type());
335    }
336
337    fn deserialize_struct<V: de::Visitor<'de>>(
338        self,
339        name: &'static str,
340        _fields: &'static [&'static str],
341        visitor: V,
342    ) -> Result<V::Value> {
343        // Nested dict `{ "A": { "a": 1, "b": 2 } }` is deserialized as `A { a: 1, b: 2 }`
344        if self.0.is_instance_of::<PyDict>() {
345            let dict: &Bound<PyDict> = self.0.downcast()?;
346            if let Some(inner) = dict.get_item(name)? {
347                if let Ok(inner) = inner.downcast() {
348                    return visitor.visit_map(MapDeserializer::new(inner));
349                }
350            }
351        }
352        // Default to `any` case
353        self.deserialize_any(visitor)
354    }
355
356    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
357        self,
358        _name: &'static str,
359        visitor: V,
360    ) -> Result<V::Value> {
361        visitor.visit_seq(SeqDeserializer {
362            seq_reversed: vec![self.0],
363        })
364    }
365
366    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
367        if self.0.is_none() {
368            visitor.visit_none()
369        } else {
370            visitor.visit_some(self)
371        }
372    }
373
374    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
375        if self.0.is(&PyTuple::empty(self.0.py())) {
376            visitor.visit_unit()
377        } else {
378            self.deserialize_any(visitor)
379        }
380    }
381
382    fn deserialize_unit_struct<V: de::Visitor<'de>>(
383        self,
384        _name: &'static str,
385        visitor: V,
386    ) -> Result<V::Value> {
387        if self.0.is(&PyTuple::empty(self.0.py())) {
388            visitor.visit_unit()
389        } else {
390            self.deserialize_any(visitor)
391        }
392    }
393
394    fn deserialize_enum<V: de::Visitor<'de>>(
395        self,
396        _name: &'static str,
397        _variants: &'static [&'static str],
398        visitor: V,
399    ) -> Result<V::Value> {
400        if self.0.is_instance_of::<PyString>() {
401            let variant: String = self.0.extract()?;
402            let py = self.0.py();
403            let none = py.None().into_bound(py);
404            return visitor.visit_enum(EnumDeserializer {
405                variant: &variant,
406                inner: none,
407            });
408        }
409        if self.0.is_instance_of::<PyDict>() {
410            let dict: &Bound<PyDict> = self.0.downcast()?;
411            if dict.len() == 1 {
412                let key = dict.keys().get_item(0).unwrap();
413                let value = dict.values().get_item(0).unwrap();
414                if key.is_instance_of::<PyString>() {
415                    let variant: String = key.extract()?;
416                    return visitor.visit_enum(EnumDeserializer {
417                        variant: &variant,
418                        inner: value,
419                    });
420                }
421            }
422        }
423        self.deserialize_any(visitor)
424    }
425
426    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
427        self,
428        name: &'static str,
429        _len: usize,
430        visitor: V,
431    ) -> Result<V::Value> {
432        if self.0.is_instance_of::<PyDict>() {
433            let dict: &Bound<PyDict> = self.0.downcast()?;
434            if let Some(value) = dict.get_item(name)? {
435                if value.is_instance_of::<PyTuple>() {
436                    let tuple: &Bound<PyTuple> = value.downcast()?;
437                    return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
438                }
439            }
440        }
441        self.deserialize_any(visitor)
442    }
443
444    forward_to_deserialize_any! {
445        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
446        bytes byte_buf seq tuple
447        map identifier ignored_any
448    }
449}
450
451struct SeqDeserializer<'py> {
452    seq_reversed: Vec<Bound<'py, PyAny>>,
453}
454
455impl<'py> SeqDeserializer<'py> {
456    fn from_list(list: &Bound<'py, PyList>) -> Self {
457        let mut seq_reversed = Vec::new();
458        for item in list.iter().rev() {
459            seq_reversed.push(item);
460        }
461        Self { seq_reversed }
462    }
463
464    fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
465        let mut seq_reversed = Vec::new();
466        for item in tuple.iter().rev() {
467            seq_reversed.push(item);
468        }
469        Self { seq_reversed }
470    }
471}
472
473impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
474    type Error = Error;
475    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
476    where
477        T: de::DeserializeSeed<'de>,
478    {
479        self.seq_reversed.pop().map_or(Ok(None), |value| {
480            let value = seed.deserialize(PyAnyDeserializer(value))?;
481            Ok(Some(value))
482        })
483    }
484}
485
486struct MapDeserializer<'py> {
487    keys: Vec<Bound<'py, PyAny>>,
488    values: Vec<Bound<'py, PyAny>>,
489}
490
491impl<'py> MapDeserializer<'py> {
492    fn new(dict: &Bound<'py, PyDict>) -> Self {
493        let mut keys = Vec::new();
494        let mut values = Vec::new();
495        for (key, value) in dict.iter() {
496            keys.push(key);
497            values.push(value);
498        }
499        Self { keys, values }
500    }
501}
502
503impl<'de> MapAccess<'de> for MapDeserializer<'_> {
504    type Error = Error;
505
506    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
507    where
508        K: de::DeserializeSeed<'de>,
509    {
510        if let Some(key) = self.keys.pop() {
511            let key = seed.deserialize(PyAnyDeserializer(key))?;
512            Ok(Some(key))
513        } else {
514            Ok(None)
515        }
516    }
517
518    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
519    where
520        V: de::DeserializeSeed<'de>,
521    {
522        if let Some(value) = self.values.pop() {
523            let value = seed.deserialize(PyAnyDeserializer(value))?;
524            Ok(value)
525        } else {
526            unreachable!()
527        }
528    }
529}
530
531// this lifetime is technically no longer 'py
532struct EnumDeserializer<'py> {
533    variant: &'py str,
534    inner: Bound<'py, PyAny>,
535}
536
537impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
538    type Error = Error;
539    type Variant = Self;
540
541    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
542    where
543        V: de::DeserializeSeed<'de>,
544    {
545        Ok((
546            seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
547            self,
548        ))
549    }
550}
551
552impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
553    type Error = Error;
554
555    fn unit_variant(self) -> Result<()> {
556        Ok(())
557    }
558
559    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
560    where
561        T: de::DeserializeSeed<'de>,
562    {
563        seed.deserialize(PyAnyDeserializer(self.inner))
564    }
565
566    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
567    where
568        V: Visitor<'de>,
569    {
570        PyAnyDeserializer(self.inner).deserialize_seq(visitor)
571    }
572
573    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
574    where
575        V: Visitor<'de>,
576    {
577        PyAnyDeserializer(self.inner).deserialize_map(visitor)
578    }
579}