hgtime/
serde_impl.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::fmt;
9
10use serde::de;
11use serde::de::Error;
12use serde::de::IgnoredAny;
13use serde::de::Unexpected;
14use serde::ser::SerializeTuple;
15use serde::Deserialize;
16use serde::Deserializer;
17use serde::Serialize;
18use serde::Serializer;
19
20use crate::HgTime;
21
22// serialize as a tuple of 2 integers: (time, offset).
23impl Serialize for HgTime {
24    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
25        let mut t = serializer.serialize_tuple(2)?;
26        t.serialize_element(&self.unixtime)?;
27        t.serialize_element(&self.offset)?;
28        t.end()
29    }
30}
31
32// deserialize from either:
33// - (time: int, offset: int): serialize format
34// - 'time offset': str, parse and deserialize
35struct HgTimeVisitor;
36
37impl<'de> de::Visitor<'de> for HgTimeVisitor {
38    type Value = HgTime;
39
40    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
41        f.write_str("(time, offset) tuple or string")
42    }
43
44    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
45        // space separated int tuple
46        if let Some((unixtime_str, offset_str)) = v.split_once(' ') {
47            if let (Ok(unixtime), Ok(offset)) =
48                (unixtime_str.parse::<i64>(), offset_str.parse::<i32>())
49            {
50                return Ok(HgTime { unixtime, offset });
51            }
52        }
53        // date str
54        match HgTime::parse(v) {
55            Some(v) => Ok(v),
56            None => Err(E::invalid_value(Unexpected::Str(v), &"HgTime str")),
57        }
58    }
59
60    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
61    where
62        A: de::SeqAccess<'de>,
63    {
64        let unixtime: i64 = seq
65            .next_element()?
66            .ok_or_else(|| A::Error::missing_field("unixtime"))?;
67        let offset: i32 = seq
68            .next_element()?
69            .ok_or_else(|| A::Error::missing_field("offset"))?;
70        if let Some(remaining) = seq.size_hint() {
71            if remaining > 0 {
72                return Err(A::Error::invalid_length(2 + remaining, &"2"));
73            }
74        } else {
75            // No concrete size.
76            let next: Option<IgnoredAny> = seq.next_element()?;
77            if next.is_some() {
78                return Err(A::Error::invalid_length(3, &"2"));
79            }
80        }
81        Ok(HgTime { unixtime, offset })
82    }
83}
84
85impl<'de> Deserialize<'de> for HgTime {
86    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
87        deserializer.deserialize_any(HgTimeVisitor)
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `t` as CBOR and checks that the bytes decode both back into an
    /// equal `HgTime` and into a plain `(i64, i32)` tuple.
    fn check_round_trip(t: HgTime) {
        let encoded = serde_cbor::to_vec(&t).unwrap();

        let decoded: HgTime = serde_cbor::from_slice(&encoded).unwrap();
        assert_eq!(t, decoded);

        let as_tuple: (i64, i32) = serde_cbor::from_slice(&encoded).unwrap();
        assert_eq!((t.unixtime, t.offset), as_tuple);
    }

    #[test]
    fn test_basic_round_trip() {
        let times = [i64::MIN, -1, 0, 1, i64::MAX];
        let offsets = [i32::MIN, -1, 0, 1, i32::MAX];
        for unixtime in times {
            offsets
                .iter()
                .map(|&offset| HgTime { unixtime, offset })
                .for_each(check_round_trip);
        }
    }

    /// Serializes `v` as CBOR and then tries to decode it as an `HgTime`. A
    /// success is rendered as `"<unixtime> <offset>"` and a failure as
    /// `"Err(<message>)"`.
    fn deserialize_from(v: impl Serialize) -> String {
        let bytes = serde_cbor::to_vec(&v).unwrap();
        serde_cbor::from_slice::<HgTime>(&bytes)
            .map(|t| format!("{} {}", t.unixtime, t.offset))
            .unwrap_or_else(|e| format!("Err({})", e))
    }

    #[test]
    fn test_deserialize_from_custom_types() {
        // Sequence inputs: exactly two integers are accepted.
        assert_eq!(deserialize_from((12, 34)), "12 34");
        assert_eq!(deserialize_from([12, 34]), "12 34");
        assert_eq!(
            deserialize_from((12, 34, 56)),
            "Err(invalid length 3, expected 2)"
        );
        assert_eq!(deserialize_from((12,)), "Err(missing field `offset`)");

        // String inputs: an integer pair or a parseable date.
        assert_eq!(deserialize_from("-11 -22"), "-11 -22");
        assert_eq!(deserialize_from("2000-1-1 +0800"), "946656000 -28800");
        assert_eq!(
            deserialize_from("guess what"),
            "Err(invalid value: string \"guess what\", expected HgTime str)"
        );

        // Anything else is rejected with the visitor's `expecting` text.
        assert_eq!(
            deserialize_from(()),
            "Err(invalid type: null, expected (time, offset) tuple or string)"
        );
    }
}