use serde::de::value::MapAccessDeserializer;
use serde::de::{DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor};
use serde::Deserializer;
use std::collections::hash_map::Iter;
use std::collections::HashMap;
use std::borrow::Cow;
use std::error::Error;
use std::fmt;
use std::fmt::Display;
use super::{Number, Value};
/// Serde deserializer over a borrowed `Value` tree.
///
/// The wrapped reference is the node currently being deserialized; nested
/// maps and arrays are visited by reference rather than copied up front.
pub struct ValueDeserializer<'v>(pub &'v Value);
/// Lets a bare `Number` act as a serde deserializer.
///
/// Every type hint is forwarded to `deserialize_any`, which hands the stored
/// variant to the visitor through the matching primitive callback; the
/// visitor decides whether the value fits the requested type (for example an
/// `Integer` read into a `u8`).
impl<'de> Deserializer<'de> for Number {
    type Error = DeserializationError;

    fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self {
            Number::Integer(value) => v.visit_i64(value),
            Number::Float(value) => v.visit_f64(value),
            Number::UInteger(value) => v.visit_u64(value),
        }
    }

    // `i128`/`u128` are listed explicitly: without them serde's default
    // `deserialize_i128`/`deserialize_u128` reject every input with
    // "i128 is not supported", even though 64-bit payloads always fit.
    serde::forward_to_deserialize_any! {
        bool u8 u16 u32 u64 u128 i8 i16 i32 i64 i128 f32 f64 char str string seq enum
        bytes byte_buf map struct unit newtype_struct
        ignored_any unit_struct tuple_struct tuple option identifier
    }
}
/// Deserializes a borrowed `Value` tree.
///
/// Scalars go straight to the visitor; maps and arrays are walked entry by
/// entry, and errors raised inside them are annotated with the key/index of
/// the failing element.
impl<'de> Deserializer<'de> for ValueDeserializer<'de> {
    type Error = DeserializationError;

    fn deserialize_any<V>(self, v: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self.0 {
            // The string lives for `'de`, so offer it as a borrowed `&'de str`:
            // borrowing targets (e.g. `&'de str`) then work without allocating,
            // while owning visitors fall back to `visit_str` as before.
            Value::String(s) => v.visit_borrowed_str(s.as_str()),
            Value::Bool(b) => v.visit_bool(*b),
            Value::Number(n) => n.deserialize_any(v),
            Value::None => v.visit_none(),
            Value::Map(map) => v.visit_map(MapDeserializer::new(map)),
            Value::Array(seq) => v.visit_seq(SequenceDeserializer::new(seq)),
        }
    }

    // `i128`/`u128` are included so 128-bit targets accept the 64-bit numbers
    // a `Value` can hold instead of hitting serde's "not supported" default.
    serde::forward_to_deserialize_any! {
        bool u8 u16 u32 u64 u128 i8 i16 i32 i64 i128 f32 f64 char str
        string seq bytes byte_buf map struct
        ignored_any tuple_struct tuple identifier
    }

    /// `Value::None` maps to `None`; any other value is `Some(value)`.
    fn deserialize_option<V>(self, v: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self.0 {
            Value::None => v.visit_none(),
            _ => v.visit_some(self),
        }
    }

    /// `Value::None` is accepted as `()`; anything else goes through
    /// `deserialize_any` so the visitor can report the mismatch.
    fn deserialize_unit<V>(self, v: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self.0 {
            Value::None => v.visit_unit(),
            _ => self.deserialize_any(v),
        }
    }

    /// Unit structs use exactly the same representation as `()`.
    fn deserialize_unit_struct<V>(self, _name: &'static str, v: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        self.deserialize_unit(v)
    }

    /// Newtype structs are transparent: the wrapped value is read in place.
    fn deserialize_newtype_struct<V>(
        self,
        _name: &'static str,
        v: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        v.visit_newtype_struct(self)
    }

    /// Enums are accepted either as a bare string naming the variant, or as a
    /// map whose key names the variant and whose value is its payload.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        v: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        match self.0 {
            Value::String(s) => v.visit_enum(s.as_str().into_deserializer()),
            Value::Map(map) => {
                let map_access = MapDeserializer::new(map);
                v.visit_enum(MapAccessDeserializer::new(map_access))
            }
            // Let the visitor produce the "invalid type" error.
            _ => self.deserialize_any(v),
        }
    }

    /// Reports a non-human-readable format.
    // NOTE(review): serde types that branch on this flag (e.g. `std::net`
    // addresses) then expect their compact form rather than a string —
    // confirm this matches how `Value`s are produced.
    fn is_human_readable(&self) -> bool {
        false
    }
}
/// `MapAccess` adapter that walks the entries of a borrowed map.
///
/// The entry produced by the most recent key lookup is parked in
/// `last_kv_pair` until its value is requested.
struct MapDeserializer<'a> {
    iter: Iter<'a, String, Value>,
    last_kv_pair: Option<(&'a String, &'a Value)>,
}

impl<'a> MapDeserializer<'a> {
    /// Positions the adapter before the first entry of `map`, with no entry
    /// pending.
    fn new(map: &'a HashMap<String, Value>) -> Self {
        let iter = map.iter();
        Self {
            iter,
            last_kv_pair: None,
        }
    }
}
impl<'de> MapAccess<'de> for MapDeserializer<'de> {
    type Error = DeserializationError;

    /// Advances to the next entry and deserializes its key.
    ///
    /// The whole entry is remembered so `next_value_seed` can read its value.
    /// A key error is prefixed with the offending key.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, DeserializationError>
    where
        K: DeserializeSeed<'de>,
    {
        if let Some((k, v)) = self.iter.next() {
            let result = seed.deserialize(k.as_str().into_deserializer()).map(Some);
            self.last_kv_pair = Some((k, v));
            result.map_err(|err: DeserializationError| err.with_prefix(k))
        } else {
            Ok(None)
        }
    }

    /// Deserializes the value of the entry returned by the preceding key call,
    /// prefixing any error with that entry's key.
    ///
    /// # Panics
    ///
    /// Panics if no entry is pending, i.e. when called before `next_key_seed`
    /// or twice for the same entry — a contract violation by the visitor.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, DeserializationError>
    where
        V: DeserializeSeed<'de>,
    {
        let (key, value) = self
            .last_kv_pair
            .take()
            .expect("MapAccess::next_value called before next_key");
        seed.deserialize(ValueDeserializer(value))
            .map_err(|err: DeserializationError| err.with_prefix(key))
    }

    /// Exact number of entries not yet visited, so target collections can
    /// pre-allocate (mirrors the sequence adapter's `size_hint`).
    fn size_hint(&self) -> Option<usize> {
        Some(self.iter.len())
    }
}
/// `SeqAccess` adapter over a borrowed slice of values.
///
/// Elements are yielded together with their index (used to label errors),
/// and `len` counts how many elements are still to come.
struct SequenceDeserializer<'a> {
    iter: std::iter::Enumerate<std::slice::Iter<'a, Value>>,
    len: usize,
}

impl<'a> SequenceDeserializer<'a> {
    /// Prepares to walk `vec` from its first element.
    fn new(vec: &'a [Value]) -> Self {
        let len = vec.len();
        Self {
            iter: vec.iter().enumerate(),
            len,
        }
    }
}
impl<'de> SeqAccess<'de> for SequenceDeserializer<'de> {
    type Error = DeserializationError;

    /// Deserializes the next element, or yields `Ok(None)` once the slice is
    /// exhausted. An element's error is prefixed with its index.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        match self.iter.next() {
            None => Ok(None),
            Some((index, element)) => {
                self.len -= 1;
                let outcome = seed.deserialize(ValueDeserializer(element));
                outcome
                    .map(Some)
                    .map_err(|e: DeserializationError| e.with_prefix(&index.to_string()))
            }
        }
    }

    /// Number of elements still to be yielded.
    fn size_hint(&self) -> Option<usize> {
        let remaining = self.len;
        Some(remaining)
    }
}
/// Error produced while deserializing a `Value` tree.
///
/// `key` stores the path to the failing node innermost-segment-first: each
/// enclosing map or array appends its own key as the error propagates
/// outwards.
#[derive(Debug)]
pub struct DeserializationError {
    kind: ErrorKind,
    key: Vec<String>,
}

impl DeserializationError {
    /// Appends `prefix` as the next outer segment of the key path and returns
    /// the updated error.
    pub fn with_prefix(self, prefix: &str) -> Self {
        let DeserializationError { kind, mut key } = self;
        key.push(prefix.to_owned());
        DeserializationError { kind, key }
    }
}
/// Formats the error kind, followed — when the error carries a key path — by
/// `, key = ` and the dot-joined path in backticks, outermost segment first.
/// Errors raised at the root value have no path and print only the kind.
impl Display for DeserializationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.kind)?;
        if self.key.is_empty() {
            // An empty path would otherwise render as a meaningless "key = ``".
            return Ok(());
        }
        f.write_str(", key = `")?;
        // Segments are pushed innermost-first, so walk them in reverse and
        // write them directly instead of collecting into a temporary Vec.
        for (i, segment) in self.key.iter().rev().enumerate() {
            if i > 0 {
                f.write_str(".")?;
            }
            f.write_str(segment)?;
        }
        f.write_str("`")
    }
}

impl Error for DeserializationError {}
/// What went wrong, independent of where in the value tree it happened.
///
/// One variant per `serde::de::Error` constructor, holding owned copies of
/// the (possibly borrowed) arguments that constructor receives.
#[derive(Debug)]
enum ErrorKind {
    /// Free-form message from `serde::de::Error::custom`.
    Message(String),
    /// The value found, and a description of the expected type.
    InvalidType(UnexpectedOwned, String),
    /// The value found, and a description of the expected value.
    InvalidValue(UnexpectedOwned, String),
    /// The length found, and a description of the expected length.
    InvalidLength(usize, String),
    /// The variant name found, and the variants the target accepts.
    UnknownVariant(String, &'static [&'static str]),
    /// The field name found, and the fields the target accepts.
    UnknownField(String, &'static [&'static str]),
    /// Name of a required field that was absent.
    MissingField(Cow<'static, str>),
    /// Name of a field that appeared more than once.
    DuplicateField(&'static str),
}

impl Display for ErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ErrorKind::Message(msg) => f.write_str(msg),
            ErrorKind::InvalidType(v, exp) => {
                write!(f, "Invalid type `{}`, expected `{}`", v, exp)
            }
            ErrorKind::InvalidValue(v, exp) => {
                write!(f, "Invalid value `{}`, expected `{}`", v, exp)
            }
            ErrorKind::InvalidLength(v, exp) => {
                write!(f, "Invalid length `{}`, expected `{}`", v, exp)
            }
            // `OneOfDisplayWrapper` already quotes every name in backticks, so
            // its placeholder is left bare to avoid doubled quoting.
            ErrorKind::UnknownVariant(v, exp) => {
                write!(
                    f,
                    "Unknown variant `{}`, expected {}",
                    v,
                    OneOfDisplayWrapper(exp)
                )
            }
            ErrorKind::UnknownField(v, exp) => {
                write!(
                    f,
                    "Unknown field `{}`, expected {}",
                    v,
                    OneOfDisplayWrapper(exp)
                )
            }
            ErrorKind::MissingField(v) => {
                write!(f, "Missing field `{}`", v)
            }
            ErrorKind::DuplicateField(v) => {
                write!(f, "Duplicate field `{}`", v)
            }
        }
    }
}
impl serde::de::Error for DeserializationError {
fn custom<T>(msg: T) -> Self
where
T: Display,
{
DeserializationError {
kind: ErrorKind::Message(msg.to_string()),
key: Vec::new(),
}
}
fn invalid_type(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
DeserializationError {
kind: ErrorKind::InvalidType(unexp.into(), exp.to_string()),
key: Vec::new(),
}
}
fn invalid_value(unexp: serde::de::Unexpected, exp: &dyn serde::de::Expected) -> Self {
DeserializationError {
kind: ErrorKind::InvalidValue(unexp.into(), exp.to_string()),
key: Vec::new(),
}
}
fn invalid_length(len: usize, exp: &dyn serde::de::Expected) -> Self {
DeserializationError {
kind: ErrorKind::InvalidLength(len, exp.to_string()),
key: Vec::new(),
}
}
fn unknown_variant(variant: &str, expected: &'static [&'static str]) -> Self {
DeserializationError {
kind: ErrorKind::UnknownVariant(variant.into(), expected),
key: Vec::new(),
}
}
fn unknown_field(field: &str, expected: &'static [&'static str]) -> Self {
DeserializationError {
kind: ErrorKind::UnknownField(field.into(), expected),
key: Vec::new(),
}
}
fn missing_field(field: &'static str) -> Self {
DeserializationError {
kind: ErrorKind::MissingField(field.into()),
key: Vec::new(),
}
}
fn duplicate_field(field: &'static str) -> Self {
DeserializationError {
kind: ErrorKind::DuplicateField(field),
key: Vec::new(),
}
}
}
/// Owned counterpart of `serde::de::Unexpected`, so the offending value can
/// be stored inside an error without borrowing from the input.
#[derive(Clone, Debug, PartialEq)]
pub enum UnexpectedOwned {
    Bool(bool),
    Unsigned(u128),
    Signed(i128),
    Float(f64),
    Char(char),
    Str(String),
    Bytes(Vec<u8>),
    Unit,
    Option,
    NewtypeStruct,
    Seq,
    Map,
    Enum,
    UnitVariant,
    NewtypeVariant,
    TupleVariant,
    StructVariant,
    Other(String),
}

/// Short description of the unexpected value, e.g. `unsigned int 42`.
///
/// Values are written without backticks, like the `bool`/`char`/`string`
/// forms: error messages already wrap this whole text in backticks, so inner
/// quoting would produce doubled backticks.
impl fmt::Display for UnexpectedOwned {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            UnexpectedOwned::Bool(v) => write!(f, "bool {}", v),
            UnexpectedOwned::Unsigned(v) => write!(f, "unsigned int {}", v),
            UnexpectedOwned::Signed(v) => write!(f, "signed int {}", v),
            UnexpectedOwned::Float(v) => write!(f, "float {}", v),
            UnexpectedOwned::Char(v) => write!(f, "char {:?}", v),
            UnexpectedOwned::Str(v) => write!(f, "string {:?}", v),
            UnexpectedOwned::Bytes(v) => write!(f, "bytes {:?}", v),
            UnexpectedOwned::Unit => write!(f, "unit"),
            UnexpectedOwned::Option => write!(f, "option"),
            UnexpectedOwned::NewtypeStruct => write!(f, "new-type struct"),
            UnexpectedOwned::Seq => write!(f, "sequence"),
            UnexpectedOwned::Map => write!(f, "map"),
            UnexpectedOwned::Enum => write!(f, "enum"),
            UnexpectedOwned::UnitVariant => write!(f, "unit variant"),
            UnexpectedOwned::NewtypeVariant => write!(f, "new-type variant"),
            UnexpectedOwned::TupleVariant => write!(f, "tuple variant"),
            UnexpectedOwned::StructVariant => write!(f, "struct variant"),
            UnexpectedOwned::Other(v) => fmt::Display::fmt(v, f),
        }
    }
}
impl From<serde::de::Unexpected<'_>> for UnexpectedOwned {
fn from(value: serde::de::Unexpected<'_>) -> UnexpectedOwned {
match value {
serde::de::Unexpected::Bool(v) => UnexpectedOwned::Bool(v),
serde::de::Unexpected::Unsigned(v) => UnexpectedOwned::Unsigned(v as u128),
serde::de::Unexpected::Signed(v) => UnexpectedOwned::Signed(v as i128),
serde::de::Unexpected::Float(v) => UnexpectedOwned::Float(v),
serde::de::Unexpected::Char(v) => UnexpectedOwned::Char(v),
serde::de::Unexpected::Str(v) => UnexpectedOwned::Str(v.into()),
serde::de::Unexpected::Bytes(v) => UnexpectedOwned::Bytes(v.into()),
serde::de::Unexpected::Unit => UnexpectedOwned::Unit,
serde::de::Unexpected::Option => UnexpectedOwned::Option,
serde::de::Unexpected::NewtypeStruct => UnexpectedOwned::NewtypeStruct,
serde::de::Unexpected::Seq => UnexpectedOwned::Seq,
serde::de::Unexpected::Map => UnexpectedOwned::Map,
serde::de::Unexpected::Enum => UnexpectedOwned::Enum,
serde::de::Unexpected::UnitVariant => UnexpectedOwned::UnitVariant,
serde::de::Unexpected::NewtypeVariant => UnexpectedOwned::NewtypeVariant,
serde::de::Unexpected::TupleVariant => UnexpectedOwned::TupleVariant,
serde::de::Unexpected::StructVariant => UnexpectedOwned::StructVariant,
serde::de::Unexpected::Other(v) => UnexpectedOwned::Other(v.into()),
}
}
}
/// Formats a list of allowed names for error messages, quoting each name in
/// backticks: `none`, a single name, `a or b`, or `one of a, b, c`.
struct OneOfDisplayWrapper(pub &'static [&'static str]);

impl Display for OneOfDisplayWrapper {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let names = self.0;
        match names {
            [] => f.write_str("none"),
            [only] => write!(f, "`{}`", only),
            [first, second] => write!(f, "`{}` or `{}`", first, second),
            [head, rest @ ..] => {
                write!(f, "one of `{}`", head)?;
                rest.iter().try_for_each(|name| write!(f, ", `{}`", name))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::ValueDeserializer;
    use crate::value::{Number, Value};
    use serde::Deserialize;
    use std::collections::HashMap;

    #[test]
    fn test_deserialize_none() {
        // Binding to the `()` pattern checks the result type without the
        // always-true unit comparison rejected by clippy's `unit_cmp`.
        let () = <()>::deserialize(ValueDeserializer(&Value::None)).unwrap();
        let res = <Option<i32>>::deserialize(ValueDeserializer(&Value::None)).unwrap();
        assert_eq!(res, None);
    }

    #[test]
    fn test_deserialize_int() {
        let value = Value::Number(42_i64.into());
        assert_eq!(i64::deserialize(ValueDeserializer(&value)).unwrap(), 42);
        assert_eq!(u64::deserialize(ValueDeserializer(&value)).unwrap(), 42);
    }

    #[test]
    fn test_deserialize_uint() {
        let value = Value::Number(42_u64.into());
        assert_eq!(u64::deserialize(ValueDeserializer(&value)).unwrap(), 42);
        assert_eq!(i64::deserialize(ValueDeserializer(&value)).unwrap(), 42);
    }

    #[test]
    fn test_deserialize_float() {
        let value = Value::Number(42.1_f64.into());
        assert_eq!(f64::deserialize(ValueDeserializer(&value)).unwrap(), 42.1);
    }

    #[test]
    fn test_deserialize_string() {
        let value = Value::String("hello world".to_string());
        let res = String::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, "hello world");
    }

    #[test]
    fn test_deserialize_bool() {
        assert!(bool::deserialize(ValueDeserializer(&Value::Bool(true))).unwrap());
        assert!(!bool::deserialize(ValueDeserializer(&Value::Bool(false))).unwrap());
    }

    #[test]
    fn test_deserialize_map() {
        let value = Value::Map(HashMap::from([
            ("hello".to_string(), Value::String("world".to_string())),
            ("world".to_string(), Value::String("hello".to_string())),
        ]));
        let res = HashMap::<String, String>::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res.len(), 2);
        assert_eq!(res["hello"], "world");
        assert_eq!(res["world"], "hello");
    }

    #[test]
    fn test_deserialize_array() {
        let value = Value::Array(vec![
            Value::String("hello".to_string()),
            Value::String("world".to_string()),
        ]);
        let res = Vec::<String>::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, ["hello", "world"]);
    }

    #[test]
    fn test_deserialize_struct() {
        #[derive(Deserialize)]
        struct TestStruct {
            pub string: String,
            pub int: i64,
            pub optional: Option<i32>,
            pub optional_missing: Option<i32>,
            pub optional_present: Option<i32>,
            pub unit: (),
        }
        let value = Value::Map(HashMap::from([
            (
                "string".to_string(),
                Value::String("Hello World".to_string()),
            ),
            ("int".to_string(), Value::Number(Number::UInteger(42))),
            ("optional".to_string(), Value::None),
            ("optional_present".to_string(), Value::Number(42.into())),
            ("unit".to_string(), Value::None),
        ]));
        let res = TestStruct::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res.string, "Hello World");
        assert_eq!(res.int, 42);
        assert_eq!(res.optional, None);
        assert_eq!(res.optional_missing, None);
        assert_eq!(res.optional_present, Some(42));
        let () = res.unit;
    }

    #[test]
    fn test_deserialize_unit_struct() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        struct TestStruct;
        let value = Value::None;
        let res = TestStruct::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, TestStruct);
    }

    #[test]
    fn test_deserialize_newtype_struct() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        struct TestStruct(String);
        let value = Value::String("Hello World".to_string());
        let res = TestStruct::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res.0, "Hello World");
    }

    #[test]
    fn test_deserialize_enum() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        enum TestEnum {
            Unit,
            NewType(String),
            Complex { value: String, id: i32 },
        }

        let value = Value::Map(HashMap::from([("Unit".to_string(), Value::None)]));
        let res = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, TestEnum::Unit);

        let value = Value::String("Unit".to_string());
        let res = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, TestEnum::Unit);

        let value = Value::Map(HashMap::from([(
            "NewType".to_string(),
            Value::String("Hello World".to_string()),
        )]));
        let res = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(res, TestEnum::NewType("Hello World".to_string()));

        let value = Value::Map(HashMap::from([(
            "Complex".to_string(),
            Value::Map(HashMap::from([
                (
                    "value".to_string(),
                    Value::String("Hello World".to_string()),
                ),
                ("id".to_string(), Value::Number(Number::UInteger(42))),
            ])),
        )]));
        let res = TestEnum::deserialize(ValueDeserializer(&value)).unwrap();
        assert_eq!(
            res,
            TestEnum::Complex {
                value: "Hello World".to_string(),
                id: 42,
            }
        );
    }

    #[test]
    fn test_deserialize_error_invalid_type() {
        #[derive(Deserialize, Debug)]
        struct TestStruct {
            #[serde(rename = "string")]
            pub _string: i32,
        }
        let value = Value::Map(HashMap::from([(
            "string".to_string(),
            Value::String("Hello World".to_string()),
        )]));
        let res = TestStruct::deserialize(ValueDeserializer(&value)).unwrap_err();
        assert_eq!(
            res.to_string(),
            "Invalid type `string \"Hello World\"`, expected `i32`, key = `string`"
        );
    }

    #[test]
    fn test_deserialize_error_nested_key_path() {
        // An error inside an array inside a map reports the full path,
        // outermost key first.
        #[derive(Deserialize, Debug)]
        struct TestStruct {
            #[serde(rename = "list")]
            pub _list: Vec<String>,
        }
        let value = Value::Map(HashMap::from([(
            "list".to_string(),
            Value::Array(vec![Value::String("ok".to_string()), Value::Bool(true)]),
        )]));
        let res = TestStruct::deserialize(ValueDeserializer(&value)).unwrap_err();
        assert_eq!(
            res.to_string(),
            "Invalid type `bool true`, expected `a string`, key = `list.1`"
        );
    }
}