versionneer/
simd_json.rs

1use std::io::Write;
2
3use simd_json::{
4    Error as JsonError, ErrorType, Node, Tape,
5    derived::{TypedScalarValue, ValueObjectAccessAsScalar},
6};
7use simd_json_derive::Serialize;
8
9use crate::{Decode, Encode};
10
11/// Bincode decoder using `std::io::Read`
12pub struct Decoder<'data> {
13    tape: Option<Tape<'data>>,
14}
15
16impl<'data> Decoder<'data> {
17    /// Create a new decoder from a reader
18    /// # Errors
19    /// if the json is invalid
20    pub fn new(data: &'data mut [u8]) -> Result<Self, DecodeError> {
21        let tape = simd_json::to_tape(data)?;
22        Ok(Self { tape: Some(tape) })
23    }
24}
25
26impl crate::Decoder for Decoder<'_> {
27    type Error = DecodeError;
28    fn decode_version(&mut self) -> Result<u32, Self::Error> {
29        self.tape
30            .as_ref()
31            .ok_or(DecodeError::AlreadyConsumed)?
32            .as_value()
33            .get_u32("v")
34            .ok_or(DecodeError::InvalidVersionField)
35    }
36}
37
38#[derive(Debug)]
39/// Errors for decoding versioned data with simd-json
40pub enum DecodeError {
41    /// A simd-json related error
42    Json(JsonError),
43    /// Derive error
44    JsonDerive(simd_json_derive::de::Error),
45    /// The version field is missing or not an integer
46    InvalidVersionField,
47    /// Invalid number of fields
48    InvalidNumberOfFields(usize),
49    /// Invalid format
50    InvalidFormat,
51    /// The deserializer was already consumed
52    AlreadyConsumed,
53}
54
55impl std::error::Error for DecodeError {}
56
57impl std::fmt::Display for DecodeError {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59        match self {
60            DecodeError::Json(error) => error.fmt(f),
61            DecodeError::JsonDerive(error) => error.fmt(f),
62            DecodeError::InvalidVersionField => {
63                write!(f, "`version` field missing or not an intege")
64            }
65            DecodeError::InvalidNumberOfFields(n) => write!(
66                f,
67                "the versioned struct needs to have exactly two elements but has {n}"
68            ),
69            DecodeError::InvalidFormat => write!(
70                f,
71                "the format is invalid, needs to be an object with version and data key"
72            ),
73            DecodeError::AlreadyConsumed => write!(f, "the deserializer was already consumed"),
74        }
75    }
76}
77impl From<JsonError> for DecodeError {
78    fn from(value: JsonError) -> Self {
79        Self::Json(value)
80    }
81}
82
83impl From<simd_json_derive::de::Error> for DecodeError {
84    fn from(value: simd_json_derive::de::Error) -> Self {
85        Self::JsonDerive(value)
86    }
87}
88
89impl<'data, T> Decode<Decoder<'data>> for T
90where
91    T: simd_json_derive::Deserialize<'data> + 'data,
92{
93    fn decode_data(decoder: &mut Decoder<'data>) -> Result<Self, DecodeError> {
94        let mut tape = decoder
95            .tape
96            .take()
97            .ok_or(DecodeError::AlreadyConsumed)?
98            .0
99            .into_iter()
100            .peekable();
101
102        let Some(Node::Object { len, .. }) = tape.next() else {
103            return Err(JsonError::generic(ErrorType::ExpectedMap).into());
104        };
105        if len != 2 {
106            return Err(DecodeError::InvalidNumberOfFields(len));
107        }
108
109        // we checked the first to b
110        loop {
111            let Some(Node::String(key)) = tape.next() else {
112                return Err(JsonError::generic(ErrorType::Eof).into());
113            };
114            if key == "v" {
115                match tape.next() {
116                    None => return Err(JsonError::generic(ErrorType::Eof).into()),
117                    Some(Node::Static(s)) if !s.is_u32() => {
118                        return Err(DecodeError::InvalidFormat);
119                    }
120                    _ => continue,
121                }
122            }
123            if key != "d" {
124                return Err(DecodeError::InvalidFormat);
125            }
126
127            return Ok(T::from_tape(&mut tape)?);
128        }
129    }
130}
131
/// Errors that can occur while encoding versioned data with simd-json.
#[derive(Debug)]
pub enum EncodeError {
    /// Writing to the underlying `std::io::Write` sink failed.
    Io(std::io::Error),
    /// A version was set more than once on the same encoder.
    VersionAlreadyDefined,
    /// Data was encoded before any version was set.
    VersionNotDefined,
}

impl std::error::Error for EncodeError {}

impl From<std::io::Error> for EncodeError {
    /// Wraps an I/O error so it can be propagated with `?`.
    fn from(value: std::io::Error) -> Self {
        Self::Io(value)
    }
}

impl std::fmt::Display for EncodeError {
    /// Human-readable description; I/O errors are forwarded to their own `Display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Explicit `Display::fmt` so the call does not depend on which formatting trait
            // happens to be in scope.
            EncodeError::Io(error) => std::fmt::Display::fmt(error, f),
            EncodeError::VersionAlreadyDefined => write!(f, "version already defined"),
            EncodeError::VersionNotDefined => write!(f, "version not defined"),
        }
    }
}
/// simd-json encoder for versioned data, writing to any `std::io::Write` sink.
///
/// The version must be set (once) via `encode_version` before data is encoded; the output has
/// the shape `{"v":<version>,"d":<data>}`.
pub struct Encoder<W: Write> {
    /// Destination for the encoded JSON.
    writer: W,
    /// Version to emit; `None` until `encode_version` has been called.
    version: Option<u32>,
}

impl<W: Write> Encoder<W> {
    /// Create a new encoder that writes to `writer`, with no version set yet.
    #[must_use]
    pub fn new(writer: W) -> Self {
        Self {
            writer,
            version: None,
        }
    }
}
174
175impl<W: Write> crate::Encoder for Encoder<W> {
176    type Error = EncodeError;
177    fn encode_version(&mut self, version: u32) -> Result<(), Self::Error> {
178        if self.version.replace(version).is_some() {
179            Err(EncodeError::VersionAlreadyDefined)
180        } else {
181            Ok(())
182        }
183    }
184}
185
186impl<W, T> Encode<Encoder<W>> for T
187where
188    T: simd_json_derive::Serialize,
189    W: Write,
190{
191    fn encode_data(&self, encoder: &mut Encoder<W>) -> Result<(), EncodeError> {
192        let Some(version) = encoder.version else {
193            return Err(EncodeError::VersionNotDefined);
194        };
195        encoder.writer.write_all(br#"{"v":"#)?;
196        version.json_write(&mut encoder.writer)?;
197        encoder.writer.write_all(br#","d":"#)?;
198        self.json_write(&mut encoder.writer)?;
199        encoder.writer.write_all(b"}")?;
200        Ok(())
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::{DecodeError, Decoder, EncodeError, Encoder};
    use crate::{Decodable, Encodable, Upgrade, versioned};

    /// Error type used by the tests; wraps the backend errors and version mismatches.
    #[derive(Debug, thiserror::Error)]
    enum Error {
        #[error("Invalid version: {0}")]
        InvalidVersion(u32),
        #[error(transparent)]
        Decode(#[from] DecodeError),
        #[error(transparent)]
        Encoder(#[from] EncodeError),
    }

    impl crate::Error for Error {
        fn invalid_version(version: u32) -> Self {
            Error::InvalidVersion(version)
        }
    }

    #[derive(Debug, PartialEq, Eq, simd_json_derive::Deserialize, simd_json_derive::Serialize)]
    struct TestV0 {
        data: u8,
    }
    versioned!(TestV0, 0);

    #[derive(Debug, PartialEq, Eq, simd_json_derive::Deserialize, simd_json_derive::Serialize)]
    struct TestV1 {
        data: u16,
    }
    versioned!(TestV1, 1);

    #[derive(Debug, PartialEq, Eq, simd_json_derive::Deserialize, simd_json_derive::Serialize)]
    struct TestV2 {
        data: u32,
    }
    versioned!(TestV2, 2);

    impl TryFrom<TestV0> for TestV1 {
        type Error = Error;
        fn try_from(old: TestV0) -> Result<Self, Self::Error> {
            Ok(TestV1 {
                data: old.data.into(),
            })
        }
    }

    impl TryFrom<TestV1> for TestV2 {
        type Error = Error;
        fn try_from(old: TestV1) -> Result<Self, Self::Error> {
            Ok(TestV2 {
                data: old.data.into(),
            })
        }
    }

    /// A `TestV0` survives an encode/decode round trip.
    #[test]
    fn test_v0() -> Result<(), Error> {
        let mut buf = Vec::new();
        let mut encoder = Encoder::new(&mut buf);
        <TestV0 as Encodable<_, Error>>::encode(&TestV0 { data: 42 }, &mut encoder)?;
        let mut decoder = Decoder::new(buf.as_mut_slice())?;
        let decoded = <TestV0 as Decodable<_, Error>>::decode(&mut decoder)?;
        assert_eq!(decoded.data, 42);
        Ok(())
    }

    /// A `TestV1` survives an encode/decode round trip.
    #[test]
    fn test_v1() -> Result<(), Error> {
        let mut buf = Vec::new();
        let mut encoder = Encoder::new(&mut buf);
        <TestV1 as Encodable<_, Error>>::encode(&TestV1 { data: 42 }, &mut encoder)?;
        let mut decoder = Decoder::new(buf.as_mut_slice())?;
        let decoded = <TestV1 as Decodable<_, Error>>::decode(&mut decoder)?;
        assert_eq!(decoded.data, 42);
        Ok(())
    }

    /// Encoded V0 data is upgraded to V1 on decode.
    #[test]
    fn test_upgrade_v1() -> Result<(), Error> {
        type Latest = Upgrade<TestV1, TestV0, Error>;
        let mut buf = Vec::new();
        let mut encoder = Encoder::new(&mut buf);
        <TestV0 as Encodable<_, Error>>::encode(&TestV0 { data: 42 }, &mut encoder)?;
        let mut decoder = Decoder::new(buf.as_mut_slice())?;
        let upgraded = Latest::decode(&mut decoder)?;
        assert_eq!(upgraded.data, 42);
        Ok(())
    }

    /// Encoded V1 data is upgraded to V2 on decode.
    #[test]
    fn test_upgrade_v2() -> Result<(), Error> {
        type Latest = Upgrade<TestV2, TestV1, Error>;
        let mut buf = Vec::new();
        let mut encoder = Encoder::new(&mut buf);
        <TestV1 as Encodable<_, Error>>::encode(&TestV1 { data: 42 }, &mut encoder)?;
        let mut decoder = Decoder::new(buf.as_mut_slice())?;
        let upgraded = Latest::decode(&mut decoder)?;
        assert_eq!(upgraded.data, 42);
        Ok(())
    }

    /// Encoded V0 data is upgraded through V1 all the way to V2 on decode.
    #[test]
    fn test_upgrade_all() -> Result<(), Error> {
        type Latest = Upgrade<TestV2, Upgrade<TestV1, TestV0, Error>, Error>;
        let mut buf = Vec::new();
        let mut encoder = Encoder::new(&mut buf);
        <TestV0 as Encodable<_, Error>>::encode(&TestV0 { data: 42 }, &mut encoder)?;
        let mut decoder = Decoder::new(buf.as_mut_slice())?;
        let upgraded = Latest::decode(&mut decoder)?;
        assert_eq!(upgraded.data, 42);
        Ok(())
    }
}