Skip to main content

sqlite_vector_rs/
json.rs

1use std::fmt;
2
3use half::f16;
4use serde_json::Value;
5
6use crate::types::{VectorType, VectorTypeError};
7
/// Errors from JSON conversion.
///
/// Returned by both `json_to_blob` and `blob_to_json`.
#[derive(Debug)]
pub enum JsonError {
    /// A `serde_json` failure: `json_to_blob` produces it when the input text
    /// does not parse, and `blob_to_json` produces it when serializing the
    /// output array fails.
    Parse(serde_json::Error),
    /// The input parsed as valid JSON, but the top-level value was not an
    /// array.
    NotAnArray,
    /// The array element at the contained zero-based index could not be read
    /// as the numeric kind the target vector type needs (`as_f64` for the
    /// float types, `as_i64` for the integer types).
    NonNumericElement(usize),
    /// An error described by a `VectorTypeError`. Within this file it carries
    /// `VectorTypeError::NonFiniteValue`, raised when a float element is not
    /// finite after conversion to the target element type.
    Type(VectorTypeError),
}
16
17impl fmt::Display for JsonError {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        match self {
20            Self::Parse(e) => write!(f, "invalid JSON: {e}"),
21            Self::NotAnArray => write!(f, "expected a JSON array"),
22            Self::NonNumericElement(i) => write!(f, "element {i} is not a number"),
23            Self::Type(e) => write!(f, "{e}"),
24        }
25    }
26}
27
// Marker impl: every `std::error::Error` method keeps its default, so no
// `source()` is exposed. The wrapped `Parse` / `Type` causes are surfaced
// through the `Display` text instead.
impl std::error::Error for JsonError {}
29
30/// Parse a JSON array string into a vector blob of the given type.
31pub fn json_to_blob(json: &str, vtype: VectorType) -> Result<Vec<u8>, JsonError> {
32    let value: Value = serde_json::from_str(json).map_err(JsonError::Parse)?;
33    let arr = value.as_array().ok_or(JsonError::NotAnArray)?;
34
35    match vtype {
36        VectorType::Float2 => {
37            let mut values = Vec::with_capacity(arr.len());
38            for (i, v) in arr.iter().enumerate() {
39                let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))?;
40                let h = f16::from_f64(n);
41                if !h.is_finite() {
42                    return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
43                }
44                values.push(h);
45            }
46            Ok(vtype.slice_to_blob(&values))
47        }
48        VectorType::Float4 => {
49            let mut values = Vec::with_capacity(arr.len());
50            for (i, v) in arr.iter().enumerate() {
51                let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))? as f32;
52                if !n.is_finite() {
53                    return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
54                }
55                values.push(n);
56            }
57            Ok(vtype.slice_to_blob(&values))
58        }
59        VectorType::Float8 => {
60            let mut values = Vec::with_capacity(arr.len());
61            for (i, v) in arr.iter().enumerate() {
62                let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))?;
63                if !n.is_finite() {
64                    return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
65                }
66                values.push(n);
67            }
68            Ok(vtype.slice_to_blob(&values))
69        }
70        VectorType::Int1 => {
71            let mut values = Vec::with_capacity(arr.len());
72            for (i, v) in arr.iter().enumerate() {
73                let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i8;
74                values.push(n);
75            }
76            Ok(vtype.slice_to_blob(&values))
77        }
78        VectorType::Int2 => {
79            let mut values = Vec::with_capacity(arr.len());
80            for (i, v) in arr.iter().enumerate() {
81                let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i16;
82                values.push(n);
83            }
84            Ok(vtype.slice_to_blob(&values))
85        }
86        VectorType::Int4 => {
87            let mut values = Vec::with_capacity(arr.len());
88            for (i, v) in arr.iter().enumerate() {
89                let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i32;
90                values.push(n);
91            }
92            Ok(vtype.slice_to_blob(&values))
93        }
94    }
95}
96
97/// Convert a vector blob back to a JSON array string.
98pub fn blob_to_json(blob: &[u8], vtype: VectorType) -> Result<String, JsonError> {
99    let values: Vec<Value> = match vtype {
100        VectorType::Float2 => {
101            let s: &[f16] = vtype.blob_to_slice(blob);
102            s.iter().map(|v| Value::from(v.to_f64())).collect()
103        }
104        VectorType::Float4 => {
105            let s: &[f32] = vtype.blob_to_slice(blob);
106            s.iter().map(|v| Value::from(*v)).collect()
107        }
108        VectorType::Float8 => {
109            let s: &[f64] = vtype.blob_to_slice(blob);
110            s.iter().map(|v| Value::from(*v)).collect()
111        }
112        VectorType::Int1 => {
113            let s: &[i8] = vtype.blob_to_slice(blob);
114            s.iter().map(|v| Value::from(*v as i64)).collect()
115        }
116        VectorType::Int2 => {
117            let s: &[i16] = vtype.blob_to_slice(blob);
118            s.iter().map(|v| Value::from(*v as i64)).collect()
119        }
120        VectorType::Int4 => {
121            let s: &[i32] = vtype.blob_to_slice(blob);
122            s.iter().map(|v| Value::from(*v as i64)).collect()
123        }
124    };
125    serde_json::to_string(&values).map_err(JsonError::Parse)
126}
127
// Unit tests for `json_to_blob` / `blob_to_json`: round trips for every
// `VectorType` variant, rejection of malformed input, empty input, float
// precision, integer boundaries, and blob sizing.
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------ //
    // Helpers
    // ------------------------------------------------------------------ //

    /// Parse a JSON string produced by blob_to_json into a Vec<f64> for
    /// comparison, avoiding floating-point Display formatting issues.
    /// Panics if the text is not a JSON array or an element is not a number.
    fn parse_json_floats(s: &str) -> Vec<f64> {
        let v: Vec<serde_json::Value> = serde_json::from_str(s).unwrap();
        v.iter().map(|x| x.as_f64().unwrap()).collect()
    }

    /// Parse a JSON array string into a Vec<i64>.
    /// Panics if the text is not a JSON array or an element is not
    /// available as an i64.
    fn parse_json_ints(s: &str) -> Vec<i64> {
        let v: Vec<serde_json::Value> = serde_json::from_str(s).unwrap();
        v.iter().map(|x| x.as_i64().unwrap()).collect()
    }

    // ------------------------------------------------------------------ //
    // Round-trip tests for all 6 VectorType variants
    // ------------------------------------------------------------------ //

    #[test]
    fn round_trip_float2() {
        // f16 has ~3 significant decimal digits; use values exactly
        // representable in half precision.
        let json = "[1.0, -0.5, 0.25]";
        let blob = json_to_blob(json, VectorType::Float2).unwrap();
        // 3 elements * 2 bytes each
        assert_eq!(blob.len(), 6);
        let out = blob_to_json(&blob, VectorType::Float2).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 3);
        assert!((vals[0] - 1.0).abs() < 1e-3);
        assert!((vals[1] - (-0.5)).abs() < 1e-3);
        assert!((vals[2] - 0.25).abs() < 1e-3);
    }

    #[test]
    fn round_trip_float4() {
        let json = "[1.5, -2.25, 0.0, 100.0]";
        let blob = json_to_blob(json, VectorType::Float4).unwrap();
        assert_eq!(blob.len(), 16); // 4 elements * 4 bytes
        let out = blob_to_json(&blob, VectorType::Float4).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 4);
        assert!((vals[0] - 1.5).abs() < 1e-6);
        assert!((vals[1] - (-2.25)).abs() < 1e-6);
        assert!((vals[2] - 0.0).abs() < 1e-6);
        assert!((vals[3] - 100.0).abs() < 1e-3);
    }

    #[test]
    fn round_trip_float8() {
        let json = "[3.141592653589793, -2.718281828459045, 0.0]";
        let blob = json_to_blob(json, VectorType::Float8).unwrap();
        assert_eq!(blob.len(), 24); // 3 elements * 8 bytes
        let out = blob_to_json(&blob, VectorType::Float8).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 3);
        assert!((vals[0] - std::f64::consts::PI).abs() < 1e-15);
        assert!((vals[1] - (-std::f64::consts::E)).abs() < 1e-15);
        assert!((vals[2] - 0.0).abs() < 1e-15);
    }

    #[test]
    fn round_trip_int1() {
        // Covers both i8 boundaries (127 and -128).
        let json = "[0, 127, -128, -1, 42]";
        let blob = json_to_blob(json, VectorType::Int1).unwrap();
        assert_eq!(blob.len(), 5); // 5 elements * 1 byte
        let out = blob_to_json(&blob, VectorType::Int1).unwrap();
        assert_eq!(parse_json_ints(&out), vec![0, 127, -128, -1, 42]);
    }

    #[test]
    fn round_trip_int2() {
        // Covers both i16 boundaries (32767 and -32768).
        let json = "[0, 32767, -32768, -1, 1000]";
        let blob = json_to_blob(json, VectorType::Int2).unwrap();
        assert_eq!(blob.len(), 10); // 5 elements * 2 bytes
        let out = blob_to_json(&blob, VectorType::Int2).unwrap();
        assert_eq!(parse_json_ints(&out), vec![0, 32767, -32768, -1, 1000]);
    }

    #[test]
    fn round_trip_int4() {
        // Covers both i32 boundaries (2147483647 and -2147483648).
        let json = "[0, 2147483647, -2147483648, -1, 99999]";
        let blob = json_to_blob(json, VectorType::Int4).unwrap();
        assert_eq!(blob.len(), 20); // 5 elements * 4 bytes
        let out = blob_to_json(&blob, VectorType::Int4).unwrap();
        assert_eq!(
            parse_json_ints(&out),
            vec![0, 2147483647, -2147483648, -1, 99999]
        );
    }

    // ------------------------------------------------------------------ //
    // json_to_blob error cases
    // ------------------------------------------------------------------ //

    #[test]
    fn json_to_blob_rejects_object() {
        let err = json_to_blob("{\"x\": 1}", VectorType::Float4).unwrap_err();
        assert!(
            matches!(err, JsonError::NotAnArray),
            "expected NotAnArray, got {err}"
        );
    }

    #[test]
    fn json_to_blob_rejects_bare_number() {
        let err = json_to_blob("42", VectorType::Int4).unwrap_err();
        assert!(matches!(err, JsonError::NotAnArray));
    }

    #[test]
    fn json_to_blob_rejects_bare_string() {
        let err = json_to_blob("\"hello\"", VectorType::Float8).unwrap_err();
        assert!(matches!(err, JsonError::NotAnArray));
    }

    #[test]
    fn json_to_blob_rejects_malformed_json() {
        // Truncated array: fails in the JSON parser before any element is
        // inspected.
        let err = json_to_blob("[1, 2,", VectorType::Float4).unwrap_err();
        assert!(matches!(err, JsonError::Parse(_)));
    }

    #[test]
    fn json_to_blob_rejects_string_element_float4() {
        // The reported index is the zero-based position of the bad element.
        let err = json_to_blob("[1.0, \"two\", 3.0]", VectorType::Float4).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(1)));
    }

    #[test]
    fn json_to_blob_rejects_string_element_int2() {
        let err = json_to_blob("[\"bad\", 2]", VectorType::Int2).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(0)));
    }

    #[test]
    fn json_to_blob_rejects_null_element_float2() {
        // null is not a number; expect NonNumericElement
        let err = json_to_blob("[1.0, null]", VectorType::Float2).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(1)));
    }

    // ------------------------------------------------------------------ //
    // Empty array / empty blob
    // ------------------------------------------------------------------ //

    #[test]
    fn json_to_blob_empty_array_all_types() {
        for vtype in [
            VectorType::Float2,
            VectorType::Float4,
            VectorType::Float8,
            VectorType::Int1,
            VectorType::Int2,
            VectorType::Int4,
        ] {
            let blob = json_to_blob("[]", vtype)
                .unwrap_or_else(|e| panic!("empty array failed for {vtype:?}: {e}"));
            assert!(
                blob.is_empty(),
                "expected empty blob for {vtype:?}, got {} bytes",
                blob.len()
            );
        }
    }

    #[test]
    fn blob_to_json_empty_blob_all_types() {
        // Each empty blob starts out as a typed empty slice and is then
        // viewed as bytes, so its pointer is aligned for the element type it
        // will be read back as.
        // NOTE(review): presumably `blob_to_slice` performs an
        // alignment-checked cast that would reject a misaligned pointer even
        // at zero length — confirm in `crate::types`.
        let empty_f16: &[f16] = &[];
        let empty_f32: &[f32] = &[];
        let empty_f64: &[f64] = &[];
        let empty_i8: &[i8] = &[];
        let empty_i16: &[i16] = &[];
        let empty_i32: &[i32] = &[];

        let cases: &[(&[u8], VectorType)] = &[
            (bytemuck::cast_slice(empty_f16), VectorType::Float2),
            (bytemuck::cast_slice(empty_f32), VectorType::Float4),
            (bytemuck::cast_slice(empty_f64), VectorType::Float8),
            (bytemuck::cast_slice(empty_i8), VectorType::Int1),
            (bytemuck::cast_slice(empty_i16), VectorType::Int2),
            (bytemuck::cast_slice(empty_i32), VectorType::Int4),
        ];

        for (blob, vtype) in cases {
            let out = blob_to_json(blob, *vtype)
                .unwrap_or_else(|e| panic!("empty blob failed for {vtype:?}: {e}"));
            assert_eq!(out, "[]", "expected '[]' for {vtype:?}, got {out:?}");
        }
    }

    // ------------------------------------------------------------------ //
    // Float precision
    // ------------------------------------------------------------------ //

    #[test]
    fn float4_precision_survives_round_trip() {
        // These decimal literals are not exactly representable in binary
        // floating point; each is rounded to the nearest f32 when the Vec is
        // built, and that rounded f32 is the reference value compared below.
        let inputs: Vec<f32> = vec![0.1, 0.2, 0.3, -0.1, 1.0 / 3.0];
        let blob = VectorType::Float4.slice_to_blob(&inputs);
        let out = blob_to_json(&blob, VectorType::Float4).unwrap();
        let vals = parse_json_floats(&out);
        for (expected, actual) in inputs.iter().zip(vals.iter()) {
            // f32 -> f64 -> string -> f64 should recover the f32 value
            // within f32 epsilon.
            assert!(
                (actual - *expected as f64).abs() < f32::EPSILON as f64,
                "f32 precision lost: expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn float8_full_precision_survives_round_trip() {
        // Use values whose magnitude keeps relative error within f64::EPSILON.
        // serde_json uses ryu, which emits the shortest decimal that round-
        // trips back to the same f64 bits, so the recovered value must equal
        // the input exactly (zero absolute error for normal magnitudes).
        let inputs: Vec<f64> = vec![
            1.0 / 7.0,
            std::f64::consts::PI,
            -std::f64::consts::SQRT_2,
            1.234_567_890_123_456_8e10,
        ];
        let blob = VectorType::Float8.slice_to_blob(&inputs);
        let out = blob_to_json(&blob, VectorType::Float8).unwrap();
        let vals = parse_json_floats(&out);
        for (expected, actual) in inputs.iter().zip(vals.iter()) {
            // Ryu guarantees the shortest round-trip representation, so the
            // parsed value should be bit-identical to the original.
            assert_eq!(
                actual.to_bits(),
                expected.to_bits(),
                "f64 bit pattern changed: expected {expected}, got {actual}"
            );
        }
    }

    // ------------------------------------------------------------------ //
    // Integer edge cases: negatives and zero
    // ------------------------------------------------------------------ //
    // Unlike the round-trip tests, these read the blob back through
    // `blob_to_slice`, so they check the stored values directly rather than
    // the JSON rendering.

    #[test]
    fn int1_negative_and_zero() {
        let json = "[-128, -1, 0, 1, 127]";
        let blob = json_to_blob(json, VectorType::Int1).unwrap();
        let slice: &[i8] = VectorType::Int1.blob_to_slice(&blob);
        assert_eq!(slice, &[-128_i8, -1, 0, 1, 127]);
    }

    #[test]
    fn int2_negative_and_zero() {
        let json = "[-32768, -100, 0, 100, 32767]";
        let blob = json_to_blob(json, VectorType::Int2).unwrap();
        let slice: &[i16] = VectorType::Int2.blob_to_slice(&blob);
        assert_eq!(slice, &[-32768_i16, -100, 0, 100, 32767]);
    }

    #[test]
    fn int4_negative_and_zero() {
        let json = "[-2147483648, -1, 0, 1, 2147483647]";
        let blob = json_to_blob(json, VectorType::Int4).unwrap();
        let slice: &[i32] = VectorType::Int4.blob_to_slice(&blob);
        assert_eq!(slice, &[-2147483648_i32, -1, 0, 1, 2147483647]);
    }

    // ------------------------------------------------------------------ //
    // Blob size invariants
    // ------------------------------------------------------------------ //

    #[test]
    fn blob_size_matches_element_size_times_count() {
        // Each case is (JSON input, vector type, element count, bytes per
        // element).
        let cases: &[(&str, VectorType, usize, usize)] = &[
            ("[1.0]", VectorType::Float2, 1, 2),
            ("[1.0, 2.0]", VectorType::Float4, 2, 4),
            ("[1.0, 2.0, 3.0]", VectorType::Float8, 3, 8),
            ("[1]", VectorType::Int1, 1, 1),
            ("[1, 2]", VectorType::Int2, 2, 2),
            ("[1, 2, 3, 4]", VectorType::Int4, 4, 4),
        ];
        for (json, vtype, count, elem_bytes) in cases {
            let blob = json_to_blob(json, *vtype).unwrap();
            assert_eq!(
                blob.len(),
                count * elem_bytes,
                "{vtype:?}: expected {} bytes, got {}",
                count * elem_bytes,
                blob.len()
            );
        }
    }
}