use crate::error::{Error, Result};
use pyo3::types::*;
use serde::{
de::{self, value::StrDeserializer, MapAccess, SeqAccess, Visitor},
forward_to_deserialize_any, Deserialize, Deserializer,
};
/// Deserializes a value of type `T` from a Python object.
///
/// The object is wrapped in the file's private `PyAnyDeserializer` and handed
/// to `T`'s `Deserialize` implementation.
///
/// # Errors
///
/// Returns whatever error `T::deserialize` produces while walking the object.
pub fn from_pyobject<'py, 'de, T: Deserialize<'de>>(any: &'py PyAny) -> Result<T> {
    let deserializer = PyAnyDeserializer(any);
    T::deserialize(deserializer)
}
/// Newtype around a borrowed Python object; it is the type that implements
/// `serde::Deserializer` in this file.
struct PyAnyDeserializer<'py>(&'py PyAny);
impl<'de, 'py> de::Deserializer<'de> for PyAnyDeserializer<'py> {
type Error = Error;
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
where
V: Visitor<'de>,
{
if self.0.is_instance_of::<PyDict>() {
return visitor.visit_map(MapDeserializer::new(self.0.extract()?));
}
if self.0.is_instance_of::<PyList>() {
return visitor.visit_seq(SeqDeserializer::from_list(self.0.extract()?));
}
if self.0.is_instance_of::<PyTuple>() {
return visitor.visit_seq(SeqDeserializer::from_tuple(self.0.extract()?));
}
if self.0.is_instance_of::<PyString>() {
return visitor.visit_str(self.0.extract()?);
}
if self.0.is_instance_of::<PyBool>() {
return visitor.visit_bool(self.0.extract()?);
}
if self.0.is_instance_of::<PyLong>() {
return visitor.visit_i64(self.0.extract()?);
}
if self.0.is_instance_of::<PyFloat>() {
return visitor.visit_f64(self.0.extract()?);
}
if self.0.is_none() {
return visitor.visit_none();
}
unreachable!("Unsupported type: {}", self.0.get_type());
}
fn deserialize_struct<V: de::Visitor<'de>>(
self,
name: &'static str,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &PyDict = self.0.extract()?;
if let Some(inner) = dict.get_item(name)? {
if let Ok(inner) = inner.extract() {
return visitor.visit_map(MapDeserializer::new(inner));
}
}
}
self.deserialize_any(visitor)
}
fn deserialize_newtype_struct<V: de::Visitor<'de>>(
self,
name: &'static str,
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &PyDict = self.0.extract()?;
if let Some(inner) = dict.get_item(name)? {
return visitor.visit_seq(SeqDeserializer {
seq_reversed: vec![inner],
});
}
}
self.deserialize_any(visitor)
}
fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
if self.0.is_none() {
visitor.visit_none()
} else {
visitor.visit_some(self)
}
}
fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
if self.0.is(PyTuple::empty(self.0.py())) {
visitor.visit_unit()
} else {
self.deserialize_any(visitor)
}
}
fn deserialize_unit_struct<V: de::Visitor<'de>>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value> {
if self.0.is(PyTuple::empty(self.0.py())) {
visitor.visit_unit()
} else {
self.deserialize_any(visitor)
}
}
fn deserialize_enum<V: de::Visitor<'de>>(
self,
name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &PyDict = self.0.extract()?;
if let Some(value) = dict.get_item(name)? {
if value.is_instance_of::<PyTuple>() {
let tuple: &PyTuple = value.extract()?;
if tuple.len() == 2 {
return visitor.visit_enum(EnumDeserializer {
variant: tuple.get_item(0)?.extract()?,
inner: tuple.get_item(1)?,
});
}
}
if value.is_instance_of::<PyString>() {
let variant = value.extract()?;
let py = self.0.py();
return visitor.visit_enum(EnumDeserializer {
variant,
inner: py.None().into_ref(py),
});
}
}
}
self.deserialize_any(visitor)
}
fn deserialize_tuple_struct<V: de::Visitor<'de>>(
self,
name: &'static str,
_len: usize,
visitor: V,
) -> Result<V::Value> {
if self.0.is_instance_of::<PyDict>() {
let dict: &PyDict = self.0.extract()?;
if let Some(value) = dict.get_item(name)? {
if value.is_instance_of::<PyTuple>() {
let tuple: &PyTuple = value.extract()?;
return visitor.visit_seq(SeqDeserializer::from_tuple(tuple));
}
}
}
self.deserialize_any(visitor)
}
forward_to_deserialize_any! {
bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
bytes byte_buf seq tuple
map identifier ignored_any
}
}
/// `SeqAccess` state for a Python list or tuple.
struct SeqDeserializer<'py> {
// Remaining items in reverse order, so that `Vec::pop` hands them out
// front to back.
seq_reversed: Vec<&'py PyAny>,
}
impl<'py> SeqDeserializer<'py> {
    /// Builds the sequence state from a Python list, storing its items
    /// back to front.
    fn from_list(list: &'py PyList) -> Self {
        Self {
            seq_reversed: list.iter().rev().collect(),
        }
    }

    /// Builds the sequence state from a Python tuple, storing its items
    /// back to front.
    fn from_tuple(tuple: &'py PyTuple) -> Self {
        Self {
            seq_reversed: tuple.iter().rev().collect(),
        }
    }
}
impl<'de, 'py> SeqAccess<'de> for SeqDeserializer<'py> {
    type Error = Error;

    /// Pops the next stored item and deserializes it with `seed`.
    ///
    /// Returns `Ok(None)` once every item has been consumed.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        match self.seq_reversed.pop() {
            Some(item) => seed.deserialize(PyAnyDeserializer(item)).map(Some),
            None => Ok(None),
        }
    }
}
/// `MapAccess` state for a Python dict.
///
/// `keys` and `values` are parallel vectors: the entry at a given index of
/// `keys` belongs to the entry at the same index of `values`. Both are
/// consumed from the back with `Vec::pop`.
struct MapDeserializer<'py> {
keys: Vec<&'py PyAny>,
values: Vec<&'py PyAny>,
}
impl<'py> MapDeserializer<'py> {
    /// Snapshots the entries of `dict` into two parallel vectors.
    ///
    /// The vectors are consumed from the back with `Vec::pop`, so both are
    /// stored reversed; this makes the visitor see the entries in the dict's
    /// own iteration order rather than back to front.
    fn new(dict: &'py PyDict) -> Self {
        let (mut keys, mut values): (Vec<_>, Vec<_>) = dict.iter().unzip();
        keys.reverse();
        values.reverse();
        Self { keys, values }
    }
}
impl<'de, 'py> MapAccess<'de> for MapDeserializer<'py> {
    type Error = Error;

    /// Pops the next key and deserializes it with `seed`.
    ///
    /// Returns `Ok(None)` when no keys remain.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        self.keys
            .pop()
            .map(|key| seed.deserialize(PyAnyDeserializer(key)))
            .transpose()
    }

    /// Pops the next value and deserializes it with `seed`.
    ///
    /// # Panics
    ///
    /// Panics if no value is left to pop.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        match self.values.pop() {
            Some(value) => seed.deserialize(PyAnyDeserializer(value)),
            None => unreachable!(),
        }
    }
}
/// `EnumAccess` / `VariantAccess` state for one enum value.
struct EnumDeserializer<'py> {
// Name of the selected variant.
variant: &'py str,
// Python object holding the variant's payload.
inner: &'py PyAny,
}
impl<'de, 'py> de::EnumAccess<'de> for EnumDeserializer<'py> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: de::DeserializeSeed<'de>,
{
Ok((
seed.deserialize(StrDeserializer::<Error>::new(self.variant))?,
self,
))
}
}
impl<'de, 'py> de::VariantAccess<'de> for EnumDeserializer<'py> {
    type Error = Error;

    /// Accepts a unit variant; the payload object is not inspected.
    fn unit_variant(self) -> Result<()> {
        Ok(())
    }

    /// Deserializes the payload object as the newtype variant's single field.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: de::DeserializeSeed<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        seed.deserialize(payload)
    }

    /// Deserializes the payload object via `deserialize_seq`; the expected
    /// length is not checked here.
    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        de::Deserializer::deserialize_seq(payload, visitor)
    }

    /// Deserializes the payload object via `deserialize_map`; the field list
    /// is not consulted.
    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
    where
        V: Visitor<'de>,
    {
        let payload = PyAnyDeserializer(self.inner);
        de::Deserializer::deserialize_map(payload, visitor)
    }
}