Skip to main content

string_interner/
serde_impl.rs

1use crate::{backend::Backend, StringInterner, Symbol};
2use alloc::boxed::Box;
3use core::{default::Default, fmt, hash::BuildHasher, marker};
4use serde::{
5    de::{Deserialize, Deserializer, SeqAccess, Visitor},
6    ser::{Serialize, SerializeSeq, Serializer},
7};
8
9impl<B, H> Serialize for StringInterner<B, H>
10where
11    B: Backend,
12    <B as Backend>::Symbol: Symbol,
13    for<'a> &'a B: IntoIterator<Item = (<B as Backend>::Symbol, &'a str)>,
14    H: BuildHasher,
15{
16    fn serialize<T>(&self, serializer: T) -> Result<T::Ok, T::Error>
17    where
18        T: Serializer,
19    {
20        let mut seq = serializer.serialize_seq(Some(self.len()))?;
21        for (_symbol, string) in self {
22            seq.serialize_element(string)?
23        }
24        seq.end()
25    }
26}
27
28impl<'de, B, H> Deserialize<'de> for StringInterner<B, H>
29where
30    B: Backend,
31    <B as Backend>::Symbol: Symbol,
32    H: BuildHasher + Default,
33{
34    fn deserialize<D>(deserializer: D) -> Result<StringInterner<B, H>, D::Error>
35    where
36        D: Deserializer<'de>,
37    {
38        deserializer.deserialize_seq(StringInternerVisitor::default())
39    }
40}
41
42struct StringInternerVisitor<B, H>
43where
44    B: Backend,
45    <B as Backend>::Symbol: Symbol,
46    H: BuildHasher,
47{
48    mark: marker::PhantomData<(<B as Backend>::Symbol, B, H)>,
49}
50
51impl<B, H> Default for StringInternerVisitor<B, H>
52where
53    B: Backend,
54    <B as Backend>::Symbol: Symbol,
55    H: BuildHasher,
56{
57    fn default() -> Self {
58        StringInternerVisitor {
59            mark: marker::PhantomData,
60        }
61    }
62}
63
64impl<'de, B, H> Visitor<'de> for StringInternerVisitor<B, H>
65where
66    B: Backend,
67    <B as Backend>::Symbol: Symbol,
68    H: BuildHasher + Default,
69{
70    type Value = StringInterner<B, H>;
71
72    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
73        formatter.write_str("Expected a contiguous sequence of strings.")
74    }
75
76    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
77    where
78        A: SeqAccess<'de>,
79    {
80        let mut interner: StringInterner<B, H> =
81            StringInterner::with_capacity_and_hasher(seq.size_hint().unwrap_or(0), H::default());
82        while let Some(s) = seq.next_element::<Box<str>>()? {
83            interner.get_or_intern(s);
84        }
85        Ok(interner)
86    }
87}
88
89macro_rules! impl_serde_for_symbol {
90    ($name:ident, $ty:ty) => {
91        impl ::serde::Serialize for $crate::symbol::$name {
92            fn serialize<T: ::serde::Serializer>(
93                &self,
94                serializer: T,
95            ) -> ::core::result::Result<T::Ok, T::Error> {
96                self.to_usize().serialize(serializer)
97            }
98        }
99
100        impl<'de> ::serde::Deserialize<'de> for $crate::symbol::$name {
101            fn deserialize<D: ::serde::Deserializer<'de>>(
102                deserializer: D,
103            ) -> ::core::result::Result<Self, D::Error> {
104                let index = <$ty as ::serde::Deserialize<'de>>::deserialize(deserializer)?;
105                let ::core::option::Option::Some(symbol) = Self::new(index) else {
106                    return ::core::result::Result::Err(<D::Error as ::serde::de::Error>::custom(
107                        ::core::concat!(
108                            "invalid index value for `",
109                            ::core::stringify!($name),
110                            "`"
111                        ),
112                    ));
113                };
114                ::core::result::Result::Ok(symbol)
115            }
116        }
117    };
118}
119impl_serde_for_symbol!(SymbolU16, u16);
120impl_serde_for_symbol!(SymbolU32, u32);
121impl_serde_for_symbol!(SymbolUsize, usize);
122
#[cfg(test)]
mod tests {
    use crate::{
        symbol::{SymbolU16, SymbolU32, SymbolUsize},
        Symbol,
    };

    /// Serializes `symbol` to JSON, deserializes it back, and asserts that the
    /// result equals the input. On failure, both values and the intermediate
    /// JSON are reported.
    fn assert_round_trips<S>(symbol: S)
    where
        S: Symbol + std::fmt::Debug + serde::Serialize + serde::de::DeserializeOwned + PartialEq,
    {
        let serialized = serde_json::to_string(&symbol).expect("serialization should succeed");
        let deserialized: S =
            serde_json::from_str(&serialized).expect("deserialization should succeed");
        assert_eq!(symbol, deserialized, "round trip through JSON `{serialized}`");
    }

    #[test]
    fn symbol_u16_round_trips() {
        // The largest index exercised is one below the integer maximum.
        for index in [0, 42, usize::from(u16::MAX) - 1] {
            assert_round_trips(SymbolU16::try_from_usize(index).expect("valid SymbolU16 index"));
        }
    }

    #[test]
    fn symbol_u32_round_trips() {
        for index in [0, 42, u32::MAX as usize - 1] {
            assert_round_trips(SymbolU32::try_from_usize(index).expect("valid SymbolU32 index"));
        }
    }

    #[test]
    fn symbol_usize_round_trips() {
        for index in [0, 42, usize::MAX - 1] {
            assert_round_trips(
                SymbolUsize::try_from_usize(index).expect("valid SymbolUsize index"),
            );
        }
    }

    #[test]
    fn raw_usize_round_trips() {
        for index in [0, 42, usize::MAX] {
            assert_round_trips(usize::try_from_usize(index).expect("valid usize index"));
        }
    }
}