// serde_pyobject/de.rs
1use crate::error::{Error, Result};
2use pyo3::{types::*, Bound};
3use serde::{
4 de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
5 forward_to_deserialize_any, Deserialize, Deserializer,
6};
7
8/// Deserialize a Python object into Rust type `T: Deserialize`.
9///
10/// # Examples
11///
12/// ## primitive
13///
14/// ```
15/// use pyo3::{Python, Py, PyAny, IntoPy};
16/// use serde_pyobject::from_pyobject;
17///
18/// Python::with_gil(|py| {
19/// // integer
20/// let any: Py<PyAny> = 42.into_py(py);
21/// let i: i32 = from_pyobject(any.into_bound(py)).unwrap();
22/// assert_eq!(i, 42);
23///
24/// // float
25/// let any: Py<PyAny> = (0.1).into_py(py);
26/// let x: f32 = from_pyobject(any.into_bound(py)).unwrap();
27/// assert_eq!(x, 0.1);
28///
29/// // bool
30/// let any: Py<PyAny> = true.into_py(py);
31/// let x: bool = from_pyobject(any.into_bound(py)).unwrap();
32/// assert_eq!(x, true);
33/// });
34/// ```
35///
36/// ## option
37///
38/// ```
39/// use pyo3::{Python, Py, PyAny, IntoPy};
40/// use serde_pyobject::from_pyobject;
41///
42/// Python::with_gil(|py| {
43/// let none = py.None();
44/// let option: Option<i32> = from_pyobject(none.into_bound(py)).unwrap();
45/// assert_eq!(option, None);
46///
47/// let py_int: Py<PyAny> = 42.into_py(py);
48/// let i: Option<i32> = from_pyobject(py_int.into_bound(py)).unwrap();
49/// assert_eq!(i, Some(42));
50/// })
51/// ```
52///
53/// ## unit
54///
55/// ```
56/// use pyo3::{Python, types::PyTuple};
57/// use serde_pyobject::from_pyobject;
58///
59/// Python::with_gil(|py| {
60/// let py_unit = PyTuple::empty(py);
61/// let unit: () = from_pyobject(py_unit).unwrap();
62/// assert_eq!(unit, ());
63/// })
64/// ```
65///
66/// ## unit_struct
67///
68/// ```
69/// use serde::Deserialize;
70/// use pyo3::{Python, types::PyTuple};
71/// use serde_pyobject::from_pyobject;
72///
73/// #[derive(Debug, PartialEq, Deserialize)]
74/// struct UnitStruct;
75///
76/// Python::with_gil(|py| {
77/// let py_unit = PyTuple::empty(py);
78/// let unit: UnitStruct = from_pyobject(py_unit).unwrap();
79/// assert_eq!(unit, UnitStruct);
80/// })
81/// ```
82///
83/// ## unit variant
84///
85/// ```
86/// use serde::Deserialize;
87/// use pyo3::{Python, types::PyString};
88/// use serde_pyobject::from_pyobject;
89///
90/// #[derive(Debug, PartialEq, Deserialize)]
91/// enum E {
92/// A,
93/// B,
94/// }
95///
96/// Python::with_gil(|py| {
97/// let any = PyString::new_bound(py, "A");
98/// let out: E = from_pyobject(any).unwrap();
99/// assert_eq!(out, E::A);
100/// })
101/// ```
102///
103/// ## newtype struct
104///
105/// ```
106/// use serde::Deserialize;
107/// use pyo3::{Python, Bound, PyAny, IntoPy};
108/// use serde_pyobject::from_pyobject;
109///
110/// #[derive(Debug, PartialEq, Deserialize)]
111/// struct NewTypeStruct(u8);
112///
113/// Python::with_gil(|py| {
114/// let any: Bound<PyAny> = 1_u32.into_py(py).into_bound(py);
115/// let obj: NewTypeStruct = from_pyobject(any).unwrap();
116/// assert_eq!(obj, NewTypeStruct(1));
117/// });
118/// ```
119///
120/// ## newtype variant
121///
122/// ```
123/// use serde::Deserialize;
124/// use pyo3::Python;
125/// use serde_pyobject::{from_pyobject, pydict};
126///
127/// #[derive(Debug, PartialEq, Deserialize)]
128/// enum NewTypeVariant {
129/// N(u8),
130/// }
131///
132/// Python::with_gil(|py| {
133/// let dict = pydict! { py, "N" => 41 }.unwrap();
134/// let obj: NewTypeVariant = from_pyobject(dict).unwrap();
135/// assert_eq!(obj, NewTypeVariant::N(41));
136/// });
137/// ```
138///
139/// ## seq
140///
141/// ```
142/// use pyo3::Python;
143/// use serde_pyobject::{from_pyobject, pylist};
144///
145/// Python::with_gil(|py| {
146/// let list = pylist![py; 1, 2, 3].unwrap();
147/// let seq: Vec<i32> = from_pyobject(list).unwrap();
148/// assert_eq!(seq, vec![1, 2, 3]);
149/// });
150/// ```
151///
152/// ## tuple
153///
154/// ```
155/// use pyo3::{Python, types::PyTuple};
156/// use serde_pyobject::from_pyobject;
157///
158/// Python::with_gil(|py| {
159/// let tuple = PyTuple::new_bound(py, &[1, 2, 3]);
160/// let tuple: (i32, i32, i32) = from_pyobject(tuple).unwrap();
161/// assert_eq!(tuple, (1, 2, 3));
162/// });
163/// ```
164///
165/// ## tuple struct
166///
167/// ```
168/// use serde::Deserialize;
169/// use pyo3::{Python, IntoPy, types::PyTuple};
170/// use serde_pyobject::from_pyobject;
171///
172/// #[derive(Debug, PartialEq, Deserialize)]
173/// struct T(u8, String);
174///
175/// Python::with_gil(|py| {
176/// let tuple = PyTuple::new_bound(py, &[1_u32.into_py(py), "test".into_py(py)]);
177/// let obj: T = from_pyobject(tuple).unwrap();
178/// assert_eq!(obj, T(1, "test".to_string()));
179/// });
180/// ```
181///
182/// ## tuple variant
183///
184/// ```
185/// use serde::Deserialize;
186/// use pyo3::Python;
187/// use serde_pyobject::{from_pyobject, pydict};
188///
189/// #[derive(Debug, PartialEq, Deserialize)]
190/// enum TupleVariant {
191/// T(u8, u8),
192/// }
193///
194/// Python::with_gil(|py| {
195/// let dict = pydict! { py, "T" => (1, 2) }.unwrap();
196/// let obj: TupleVariant = from_pyobject(dict).unwrap();
197/// assert_eq!(obj, TupleVariant::T(1, 2));
198/// });
199/// ```
200///
201/// ## map
202///
203/// ```
204/// use pyo3::Python;
205/// use serde_pyobject::{from_pyobject, pydict};
206/// use std::collections::BTreeMap;
207///
208/// Python::with_gil(|py| {
209/// let dict = pydict! { py,
210/// "a" => "hom",
211/// "b" => "test"
212/// }
213/// .unwrap();
214/// let map: BTreeMap<String, String> = from_pyobject(dict).unwrap();
215/// assert_eq!(map.get("a"), Some(&"hom".to_string()));
216/// assert_eq!(map.get("b"), Some(&"test".to_string()));
217/// });
218/// ```
219///
220/// ## struct
221///
222/// ```
223/// use serde::Deserialize;
224/// use pyo3::Python;
225/// use serde_pyobject::{from_pyobject, pydict};
226///
227/// #[derive(Debug, PartialEq, Deserialize)]
228/// struct A {
229/// a: i32,
230/// b: String,
231/// }
232///
233/// Python::with_gil(|py| {
234/// let dict = pydict! {
235/// "a" => 1,
236/// "b" => "test"
237/// }
238/// .unwrap();
239/// let a: A = from_pyobject(dict.into_bound(py)).unwrap();
240/// assert_eq!(
241/// a,
242/// A {
243/// a: 1,
244/// b: "test".to_string()
245/// }
246/// );
247/// });
248///
249/// Python::with_gil(|py| {
250/// let dict = pydict! {
251/// "A" => pydict! {
252/// "a" => 1,
253/// "b" => "test"
254/// }
255/// .unwrap()
256/// }
257/// .unwrap();
258/// let a: A = from_pyobject(dict.into_bound(py)).unwrap();
259/// assert_eq!(
260/// a,
261/// A {
262/// a: 1,
263/// b: "test".to_string()
264/// }
265/// );
266/// });
267/// ```
268///
269/// ## struct variant
270///
271/// ```
272/// use serde::Deserialize;
273/// use pyo3::Python;
274/// use serde_pyobject::{from_pyobject, pydict};
275///
276/// #[derive(Debug, PartialEq, Deserialize)]
277/// enum StructVariant {
278/// S { r: u8, g: u8, b: u8 },
279/// }
280///
281/// Python::with_gil(|py| {
282/// let dict = pydict! {
283/// py,
284/// "S" => pydict! {
285/// "r" => 1,
286/// "g" => 2,
287/// "b" => 3
288/// }.unwrap()
289/// }
290/// .unwrap();
291/// let obj: StructVariant = from_pyobject(dict).unwrap();
292/// assert_eq!(obj, StructVariant::S { r: 1, g: 2, b: 3 });
293/// });
294/// ```
295pub fn from_pyobject<'py, 'de, T: Deserialize<'de>, Any>(any: Bound<'py, Any>) -> Result<T> {
296 let any = any.into_any();
297 T::deserialize(PyAnyDeserializer(any))
298}
299
300struct PyAnyDeserializer<'py>(Bound<'py, PyAny>);
301
302impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
303 type Error = Error;
304
305 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
306 where
307 V: Visitor<'de>,
308 {
309 if self.0.is_instance_of::<PyDict>() {
310 return visitor.visit_map(MapDeserializer::new(self.0.downcast()?));
311 }
312 if self.0.is_instance_of::<PyList>() {
313 return visitor.visit_seq(SeqDeserializer::from_list(self.0.downcast()?));
314 }
315 if self.0.is_instance_of::<PyTuple>() {
316 return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.downcast()?));
317 }
318 if self.0.is_instance_of::<PyString>() {
319 return visitor.visit_str(&self.0.extract::<String>()?);
320 }
321 if self.0.is_instance_of::<PyBool>() {
322 // must be match before PyLong
323 return visitor.visit_bool(self.0.extract()?);
324 }
325 if self.0.is_instance_of::<PyInt>() {
326 return visitor.visit_i64(self.0.extract()?);
327 }
328 if self.0.is_instance_of::<PyFloat>() {
329 return visitor.visit_f64(self.0.extract()?);
330 }
331 if self.0.is_none() {
332 return visitor.visit_none();
333 }
334 unreachable!("Unsupported type: {}", self.0.get_type());
335 }
336
337 fn deserialize_struct<V: de::Visitor<'de>>(
338 self,
339 name: &'static str,
340 _fields: &'static [&'static str],
341 visitor: V,
342 ) -> Result<V::Value> {
343 // Nested dict `{ "A": { "a": 1, "b": 2 } }` is deserialized as `A { a: 1, b: 2 }`
344 if self.0.is_instance_of::<PyDict>() {
345 let dict: &Bound<PyDict> = self.0.downcast()?;
346 if let Some(inner) = dict.get_item(name)? {
347 if let Ok(inner) = inner.downcast() {
348 return visitor.visit_map(MapDeserializer::new(inner));
349 }
350 }
351 }
352 // Default to `any` case
353 self.deserialize_any(visitor)
354 }
355
356 fn deserialize_newtype_struct<V: de::Visitor<'de>>(
357 self,
358 _name: &'static str,
359 visitor: V,
360 ) -> Result<V::Value> {
361 visitor.visit_seq(SeqDeserializer {
362 seq_reversed: vec![self.0],
363 })
364 }
365
366 fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
367 if self.0.is_none() {
368 visitor.visit_none()
369 } else {
370 visitor.visit_some(self)
371 }
372 }
373
374 fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
375 if self.0.is(&PyTuple::empty(self.0.py())) {
376 visitor.visit_unit()
377 } else {
378 self.deserialize_any(visitor)
379 }
380 }
381
382 fn deserialize_unit_struct<V: de::Visitor<'de>>(
383 self,
384 _name: &'static str,
385 visitor: V,
386 ) -> Result<V::Value> {
387 if self.0.is(&PyTuple::empty(self.0.py())) {
388 visitor.visit_unit()
389 } else {
390 self.deserialize_any(visitor)
391 }
392 }
393
394 fn deserialize_enum<V: de::Visitor<'de>>(
395 self,
396 _name: &'static str,
397 _variants: &'static [&'static str],
398 visitor: V,
399 ) -> Result<V::Value> {
400 if self.0.is_instance_of::<PyString>() {
401 let variant: String = self.0.extract()?;
402 let py = self.0.py();
403 let none = py.None().into_bound(py);
404 return visitor.visit_enum(EnumDeserializer {
405 variant: &variant,
406 inner: none,
407 });
408 }
409 if self.0.is_instance_of::<PyDict>() {
410 let dict: &Bound<PyDict> = self.0.downcast()?;
411 if dict.len() == 1 {
412 let key = dict.keys().get_item(0).unwrap();
413 let value = dict.values().get_item(0).unwrap();
414 if key.is_instance_of::<PyString>() {
415 let variant: String = key.extract()?;
416 return visitor.visit_enum(EnumDeserializer {
417 variant: &variant,
418 inner: value,
419 });
420 }
421 }
422 }
423 self.deserialize_any(visitor)
424 }
425
426 fn deserialize_tuple_struct<V: de::Visitor<'de>>(
427 self,
428 name: &'static str,
429 _len: usize,
430 visitor: V,
431 ) -> Result<V::Value> {
432 if self.0.is_instance_of::<PyDict>() {
433 let dict: &Bound<PyDict> = self.0.downcast()?;
434 if let Some(value) = dict.get_item(name)? {
435 if value.is_instance_of::<PyTuple>() {
436 let tuple: &Bound<PyTuple> = value.downcast()?;
437 return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
438 }
439 }
440 }
441 self.deserialize_any(visitor)
442 }
443
444 forward_to_deserialize_any! {
445 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
446 bytes byte_buf seq tuple
447 map identifier ignored_any
448 }
449}
450
451struct SeqDeserializer<'py> {
452 seq_reversed: Vec<Bound<'py, PyAny>>,
453}
454
455impl<'py> SeqDeserializer<'py> {
456 fn from_list(list: &Bound<'py, PyList>) -> Self {
457 let mut seq_reversed = Vec::new();
458 for item in list.iter().rev() {
459 seq_reversed.push(item);
460 }
461 Self { seq_reversed }
462 }
463
464 fn from_tuple(tuple: &Bound<'py, PyTuple>) -> Self {
465 let mut seq_reversed = Vec::new();
466 for item in tuple.iter().rev() {
467 seq_reversed.push(item);
468 }
469 Self { seq_reversed }
470 }
471}
472
473impl<'de> SeqAccess<'de> for SeqDeserializer<'_> {
474 type Error = Error;
475 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
476 where
477 T: de::DeserializeSeed<'de>,
478 {
479 self.seq_reversed.pop().map_or(Ok(None), |value| {
480 let value = seed.deserialize(PyAnyDeserializer(value))?;
481 Ok(Some(value))
482 })
483 }
484}
485
486struct MapDeserializer<'py> {
487 keys: Vec<Bound<'py, PyAny>>,
488 values: Vec<Bound<'py, PyAny>>,
489}
490
491impl<'py> MapDeserializer<'py> {
492 fn new(dict: &Bound<'py, PyDict>) -> Self {
493 let mut keys = Vec::new();
494 let mut values = Vec::new();
495 for (key, value) in dict.iter() {
496 keys.push(key);
497 values.push(value);
498 }
499 Self { keys, values }
500 }
501}
502
503impl<'de> MapAccess<'de> for MapDeserializer<'_> {
504 type Error = Error;
505
506 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
507 where
508 K: de::DeserializeSeed<'de>,
509 {
510 if let Some(key) = self.keys.pop() {
511 let key = seed.deserialize(PyAnyDeserializer(key))?;
512 Ok(Some(key))
513 } else {
514 Ok(None)
515 }
516 }
517
518 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
519 where
520 V: de::DeserializeSeed<'de>,
521 {
522 if let Some(value) = self.values.pop() {
523 let value = seed.deserialize(PyAnyDeserializer(value))?;
524 Ok(value)
525 } else {
526 unreachable!()
527 }
528 }
529}
530
531// this lifetime is technically no longer 'py
532struct EnumDeserializer<'py> {
533 variant: &'py str,
534 inner: Bound<'py, PyAny>,
535}
536
537impl<'de> de::EnumAccess<'de> for EnumDeserializer<'_> {
538 type Error = Error;
539 type Variant = Self;
540
541 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
542 where
543 V: de::DeserializeSeed<'de>,
544 {
545 Ok((
546 seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
547 self,
548 ))
549 }
550}
551
552impl<'de> de::VariantAccess<'de> for EnumDeserializer<'_> {
553 type Error = Error;
554
555 fn unit_variant(self) -> Result<()> {
556 Ok(())
557 }
558
559 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
560 where
561 T: de::DeserializeSeed<'de>,
562 {
563 seed.deserialize(PyAnyDeserializer(self.inner))
564 }
565
566 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
567 where
568 V: Visitor<'de>,
569 {
570 PyAnyDeserializer(self.inner).deserialize_seq(visitor)
571 }
572
573 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
574 where
575 V: Visitor<'de>,
576 {
577 PyAnyDeserializer(self.inner).deserialize_map(visitor)
578 }
579}