// typedb_driver/answer/json.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20use std::{
21    borrow::Cow,
22    collections::HashMap,
23    fmt::{self, Write},
24    iter,
25};
26
27use itertools::Itertools;
28use serde::{
29    ser::{SerializeMap, SerializeSeq},
30    Deserialize, Serialize,
31};
32
/// A dynamically-typed JSON value.
///
/// Strings (including object keys) are stored as `Cow<'static, str>`, so static literals can be
/// used without allocating while text produced at runtime is owned. Every number is held as an
/// `f64`; integers deserialized into this type are converted with an `as f64` cast, so integers
/// beyond 2^53 lose precision.
#[derive(Clone, Debug, PartialEq)]
pub enum JSON {
    /// A JSON object. Backed by a `HashMap`, so member order is not preserved and a key can
    /// appear at most once.
    Object(HashMap<Cow<'static, str>, JSON>),
    /// A JSON array; element order is preserved.
    Array(Vec<JSON>),
    /// A JSON string, stored unescaped.
    String(Cow<'static, str>),
    /// A JSON number, always stored as `f64`.
    Number(f64),
    /// A JSON boolean.
    Boolean(bool),
    /// The JSON `null` literal.
    Null,
}

43impl fmt::Display for JSON {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        match self {
46            JSON::Object(object) => {
47                f.write_char('{')?;
48                for (i, (k, v)) in object.iter().enumerate() {
49                    if i > 0 {
50                        f.write_str(", ")?;
51                    }
52                    write!(f, r#""{}": {}"#, k, v)?;
53                }
54                f.write_char('}')?;
55            }
56            JSON::Array(list) => {
57                f.write_char('[')?;
58                for (i, v) in list.iter().enumerate() {
59                    if i > 0 {
60                        f.write_str(", ")?;
61                    }
62                    write!(f, "{}", v)?;
63                }
64                f.write_char(']')?;
65            }
66            JSON::String(string) => write_escaped_string(string, f)?,
67            JSON::Number(number) => write!(f, "{number}")?,
68            JSON::Boolean(boolean) => write!(f, "{boolean}")?,
69            JSON::Null => write!(f, "null")?,
70        }
71        Ok(())
72    }
73}
74
/// Writes `string` to `f` as a double-quoted JSON string literal.
///
/// Escaping rules:
/// - `"` and `\` are backslash-escaped;
/// - U+0008, U+0009, U+000A, U+000C and U+000D use their short forms `\b`, `\t`, `\n`, `\f`, `\r`;
/// - every other character below U+0020 is written as `\u00XX` with lowercase hex digits;
/// - all remaining characters, including non-ASCII ones, are copied through unchanged.
///
/// Runs of characters that need no escaping are forwarded to the formatter as `&str` slices of the
/// input, so no intermediate buffer and no `unsafe` UTF-8 conversion is needed.
fn write_escaped_string(string: &str, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    f.write_char('"')?;
    // Byte offset at which the current run of not-yet-written, unescaped input begins.
    let mut run_start = 0;
    for (index, byte) in string.bytes().enumerate() {
        let short_escape = match byte {
            b'"' => Some("\\\""),
            b'\\' => Some("\\\\"),
            0x08 => Some("\\b"),
            b'\t' => Some("\\t"),
            b'\n' => Some("\\n"),
            0x0C => Some("\\f"),
            b'\r' => Some("\\r"),
            // Remaining ASCII control characters get the generic `\u00XX` form.
            0x00..=0x1F => None,
            _ => continue,
        };
        // Slicing at `index` cannot panic: every byte matched above is ASCII, and ASCII bytes never
        // occur inside a multi-byte UTF-8 sequence, so `index` is always a char boundary.
        f.write_str(&string[run_start..index])?;
        match short_escape {
            Some(escape) => f.write_str(escape)?,
            None => write!(f, "\\u{:04x}", byte)?,
        }
        run_start = index + 1;
    }
    f.write_str(&string[run_start..])?;
    f.write_char('"')
}

122impl Serialize for JSON {
123    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
124    where
125        S: serde::Serializer,
126    {
127        match self {
128            Self::Object(object) => {
129                let mut map = serializer.serialize_map(Some(object.len()))?;
130                for (key, value) in object {
131                    map.serialize_entry(key, value)?;
132                }
133                map.end()
134            }
135            Self::Array(array) => {
136                let mut seq = serializer.serialize_seq(Some(array.len()))?;
137                for item in array {
138                    seq.serialize_element(item)?;
139                }
140                seq.end()
141            }
142            Self::String(string) => serializer.serialize_str(string),
143            &Self::Number(number) => serializer.serialize_f64(number),
144            &Self::Boolean(boolean) => serializer.serialize_bool(boolean),
145            Self::Null => serializer.serialize_unit(),
146        }
147    }
148}
149
150impl<'de> Deserialize<'de> for JSON {
151    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
152    where
153        D: serde::Deserializer<'de>,
154    {
155        struct Visitor;
156
157        impl<'de> serde::de::Visitor<'de> for Visitor {
158            type Value = JSON;
159
160            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
161                formatter.write_str("a valid JSON value")
162            }
163
164            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
165            where
166                E: serde::de::Error,
167            {
168                Ok(JSON::Boolean(value))
169            }
170
171            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
172            where
173                E: serde::de::Error,
174            {
175                Ok(JSON::Number(value as f64))
176            }
177
178            fn visit_i128<E>(self, value: i128) -> Result<Self::Value, E>
179            where
180                E: serde::de::Error,
181            {
182                Ok(JSON::Number(value as f64))
183            }
184
185            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
186            where
187                E: serde::de::Error,
188            {
189                Ok(JSON::Number(value as f64))
190            }
191
192            fn visit_u128<E>(self, value: u128) -> Result<Self::Value, E>
193            where
194                E: serde::de::Error,
195            {
196                Ok(JSON::Number(value as f64))
197            }
198
199            fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
200            where
201                E: serde::de::Error,
202            {
203                Ok(JSON::Number(value))
204            }
205
206            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
207            where
208                E: serde::de::Error,
209            {
210                Ok(JSON::String(Cow::Owned(value.to_owned())))
211            }
212
213            fn visit_string<E>(self, value: String) -> Result<Self::Value, E>
214            where
215                E: serde::de::Error,
216            {
217                Ok(JSON::String(Cow::Owned(value)))
218            }
219
220            fn visit_none<E>(self) -> Result<Self::Value, E>
221            where
222                E: serde::de::Error,
223            {
224                Ok(JSON::Null)
225            }
226
227            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
228            where
229                D: serde::Deserializer<'de>,
230            {
231                JSON::deserialize(deserializer)
232            }
233
234            fn visit_unit<E>(self) -> Result<Self::Value, E>
235            where
236                E: serde::de::Error,
237            {
238                Ok(JSON::Null)
239            }
240
241            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
242            where
243                A: serde::de::SeqAccess<'de>,
244            {
245                Ok(JSON::Array(iter::from_fn(|| seq.next_element().transpose()).try_collect()?))
246            }
247
248            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
249            where
250                A: serde::de::MapAccess<'de>,
251            {
252                Ok(JSON::Object(iter::from_fn(|| map.next_entry().transpose()).try_collect()?))
253            }
254        }
255
256        deserializer.deserialize_any(Visitor)
257    }
258}
259
#[cfg(test)]
mod test {
    use std::{borrow::Cow, collections::HashMap, iter};

    use rand::{
        distributions::{DistString, Distribution, Standard, WeightedIndex},
        thread_rng, Rng,
    };
    use serde_json::json;

    use super::JSON;

    /// String escaping must agree with `serde_json` for every ASCII character below DEL (all
    /// control characters, `"` and `\` included) and for non-Latin text.
    #[test]
    fn test_against_serde() {
        let string: String =
            (0u8..0x7fu8).map(|byte| byte as char).chain("lorem ипсум どぉる سيتامعت".chars()).collect();
        let serde_json_value = serde_json::value::Value::String(string.clone());
        let json_string = JSON::String(Cow::Owned(string));
        assert_eq!(serde_json::to_string(&serde_json_value).unwrap(), json_string.to_string());
    }

    /// Fixture equivalent to the JSON text used by the `serialize` and `deserialize` tests.
    fn sample_json() -> JSON {
        JSON::Object(HashMap::from([
            ("array".into(), JSON::Array(vec![JSON::Boolean(true), JSON::String("string".into())])),
            ("number".into(), JSON::Number(123.4)),
        ]))
    }

    #[test]
    fn serialize() {
        let ser = serde_json::to_value(sample_json()).unwrap();
        let value = json!( { "array": [true, "string"], "number": 123.4 });
        assert_eq!(ser, value);
    }

    #[test]
    fn deserialize() {
        let deser: JSON = serde_json::from_str(r#"{ "array": [true, "string"], "number": 123.4 }"#).unwrap();
        let json = sample_json();
        assert_eq!(deser, json);
    }

    /// Random string of 0 to 63 `char`s drawn from rand's `Standard` distribution.
    fn random_string(rng: &mut impl Rng) -> String {
        let len = rng.gen_range(0..64);
        Standard.sample_string(rng, len)
    }

    /// Random JSON tree.
    ///
    /// Variant weights are object 1, array 1, string 3, number 3, boolean 3, null 3, and containers
    /// hold 0 to 11 children. The expected number of children per node is therefore 11/14 < 1,
    /// so generated trees are finite with probability 1 and small on average.
    fn random_json<R: Rng>(rng: &mut R) -> JSON {
        let weights = [1, 1, 3, 3, 3, 3];
        let generators: &[fn(&mut R) -> JSON] = &[
            |rng| {
                let len = rng.gen_range(0..12);
                JSON::Object(
                    iter::from_fn(|| Some((Cow::Owned(random_string(rng)), random_json(rng)))).take(len).collect(),
                )
            },
            |rng| {
                let len = rng.gen_range(0..12);
                JSON::Array(iter::from_fn(|| Some(random_json(rng))).take(len).collect())
            },
            |rng| JSON::String(Cow::Owned(random_string(rng))),
            |rng| JSON::Number(rng.gen()),
            |rng| JSON::Boolean(rng.gen()),
            |_| JSON::Null,
        ];
        let dist = WeightedIndex::new(&weights).unwrap();
        generators[dist.sample(rng)](rng)
    }

    /// Serializing to a `serde_json::Value` and deserializing back must reproduce the original tree.
    #[test]
    fn serde_roundtrip() {
        let mut rng = thread_rng();
        for _ in 0..1000 {
            let json = random_json(&mut rng);
            let deser = serde_json::from_value(serde_json::to_value(&json).unwrap()).unwrap();
            assert_eq!(json, deser);
        }
    }
}