// stefans_utils/php_safe_hashmap.rs

use serde::{
    Deserialize,
    de::{self, Deserializer, MapAccess, Visitor},
};
use std::{
    collections::{HashMap, hash_map::Entry},
    fmt::Debug,
    hash::Hash,
    marker::PhantomData,
};

/// Deserialization helper for a `HashMap<K, V>` that may arrive either as a map or as a
/// sequence of `(key, value)` pairs (e.g. `[["foo", 1], ["bar", 2]]`).
///
/// Use it on a field via `#[serde(deserialize_with = "PhpSafeHashMap::deserialize")]`.
/// The type itself carries no data; it only acts as the serde [`Visitor`].
// `fn() -> (K, V)` instead of `(K, V)`: the marker only records "produces K and V". The
// visitor therefore does not claim ownership of K/V for drop-check, and it is `Send + Sync`
// whatever K/V are. This is the marker shape serde's own custom-map example uses.
pub struct PhpSafeHashMap<K, V>(PhantomData<fn() -> (K, V)>);

9impl<'de, K: Hash + Eq + Debug + Deserialize<'de>, V: Deserialize<'de>> PhpSafeHashMap<K, V> {
10    pub fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<HashMap<K, V>, D::Error> {
11        deserializer.deserialize_any(PhpSafeHashMap::default())
12    }
13}
14
15impl<K, V> Default for PhpSafeHashMap<K, V> {
16    fn default() -> Self {
17        Self(PhantomData)
18    }
19}
20
21impl<'de, K: Hash + Eq + Debug, V> Visitor<'de> for PhpSafeHashMap<K, V>
22where
23    K: Deserialize<'de>,
24    V: Deserialize<'de>,
25{
26    type Value = HashMap<K, V>;
27
28    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
29        formatter.write_str("A HashMap<K, V> or a sequence of key-value tuples like Vec<(K, V)> where each key must be unique")
30    }
31
32    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
33    where
34        M: MapAccess<'de>,
35    {
36        let mut map = HashMap::with_capacity(access.size_hint().unwrap_or(0));
37
38        while let Some((key, value)) = access.next_entry()? {
39            map.insert(key, value);
40        }
41
42        Ok(map)
43    }
44
45    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
46    where
47        A: serde::de::SeqAccess<'de>,
48    {
49        let mut map = HashMap::with_capacity(seq.size_hint().unwrap_or(0));
50
51        while let Some((k, v)) = seq.next_element::<(K, V)>()? {
52            if map.contains_key(&k) {
53                return Err(de::Error::custom(format_args!("duplicate field `{k:?}`")));
54            }
55
56            map.insert(k, v);
57        }
58
59        Ok(map)
60    }
61}
62
#[cfg(test)]
mod tests {
    use std::fmt::Display;

    use super::*;
    use serde::Serialize;
    use serde_json::json;

    /// Value type that accepts either a string or a number.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    #[serde(untagged)]
    enum StringOrUsize {
        String(String),
        Usize(usize),
    }

    impl From<usize> for StringOrUsize {
        fn from(value: usize) -> Self {
            Self::Usize(value)
        }
    }

    impl From<&'static str> for StringOrUsize {
        fn from(value: &'static str) -> Self {
            Self::String(value.to_string())
        }
    }

    /// Newtype that exercises `PhpSafeHashMap::deserialize` through `deserialize_with`.
    #[derive(Debug, Serialize, Deserialize, PartialEq)]
    struct MyPhpSafeMap(
        #[serde(deserialize_with = "PhpSafeHashMap::deserialize")] HashMap<String, StringOrUsize>,
    );

    impl MyPhpSafeMap {
        /// Builds an expected value from `(key, value)` pairs, stringifying each key.
        fn from(key_values: impl IntoIterator<Item = (impl Display, StringOrUsize)>) -> Self {
            let mut map = HashMap::new();

            for (key, value) in key_values {
                map.insert(key.to_string(), value);
            }

            Self(map)
        }
    }

    #[test]
    fn should_work_with_an_actual_map() {
        let json = json!(
            {
                "foo": 123,
                "bar": "baz"
            }
        );

        let map: MyPhpSafeMap = serde_json::from_value(json).unwrap();

        assert_eq!(
            map,
            MyPhpSafeMap::from([("foo", 123.into()), ("bar", "baz".into())])
        )
    }

    #[test]
    fn should_work_with_an_associated_array() {
        let json = json!([["foo", 123], ["bar", "baz"]]);
        let map: MyPhpSafeMap = serde_json::from_value(json).unwrap();

        assert_eq!(
            map,
            MyPhpSafeMap::from([("foo", 123.into()), ("bar", "baz".into())])
        )
    }

    #[test]
    fn should_error_on_duplicate_fields() {
        let json = json!([["foo", 123], ["foo", "baz"]]);
        let res: Result<MyPhpSafeMap, _> = serde_json::from_value(json);

        assert!(res.is_err())
    }

    #[test]
    fn should_name_the_duplicate_key_in_the_error() {
        let json = json!([["foo", 123], ["foo", "baz"]]);
        let err = serde_json::from_value::<MyPhpSafeMap>(json)
            .unwrap_err()
            .to_string();

        assert!(err.contains("duplicate field"), "unexpected error: {err}");
        assert!(err.contains("foo"), "unexpected error: {err}");
    }

    #[test]
    fn should_accept_empty_inputs() {
        let empty = MyPhpSafeMap(HashMap::new());

        let from_array: MyPhpSafeMap = serde_json::from_value(json!([])).unwrap();
        let from_object: MyPhpSafeMap = serde_json::from_value(json!({})).unwrap();

        assert_eq!(from_array, empty);
        assert_eq!(from_object, empty);
    }

    #[test]
    fn should_reject_values_that_are_neither_map_nor_sequence() {
        let res: Result<MyPhpSafeMap, _> = serde_json::from_value(json!("foo"));

        assert!(res.is_err());
    }
}