1use std::fmt;
2
3use half::f16;
4use serde_json::Value;
5
6use crate::types::{VectorType, VectorTypeError};
7
/// Errors produced while converting between JSON text and vector blobs.
#[derive(Debug)]
pub enum JsonError {
    /// `serde_json` reported an error; the underlying error is carried along.
    ///
    /// In this module it wraps failures of both `serde_json::from_str` and
    /// `serde_json::to_string`.
    Parse(serde_json::Error),
    /// The input parsed as JSON, but its top-level value is not an array.
    NotAnArray,
    /// The array element at this zero-based index could not be read as a
    /// number of the kind the requested vector type needs (for example a
    /// string or `null`).
    NonNumericElement(usize),
    /// A vector-type level failure, such as `VectorTypeError::NonFiniteValue`
    /// for a float element that is not finite.
    Type(VectorTypeError),
}
16
17impl fmt::Display for JsonError {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 match self {
20 Self::Parse(e) => write!(f, "invalid JSON: {e}"),
21 Self::NotAnArray => write!(f, "expected a JSON array"),
22 Self::NonNumericElement(i) => write!(f, "element {i} is not a number"),
23 Self::Type(e) => write!(f, "{e}"),
24 }
25 }
26}
27
// Marker impl: every `std::error::Error` method keeps its default, so wrapped
// errors are surfaced only through the `Display` output of `JsonError`.
impl std::error::Error for JsonError {}
29
30pub fn json_to_blob(json: &str, vtype: VectorType) -> Result<Vec<u8>, JsonError> {
32 let value: Value = serde_json::from_str(json).map_err(JsonError::Parse)?;
33 let arr = value.as_array().ok_or(JsonError::NotAnArray)?;
34
35 match vtype {
36 VectorType::Float2 => {
37 let mut values = Vec::with_capacity(arr.len());
38 for (i, v) in arr.iter().enumerate() {
39 let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))?;
40 let h = f16::from_f64(n);
41 if !h.is_finite() {
42 return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
43 }
44 values.push(h);
45 }
46 Ok(vtype.slice_to_blob(&values))
47 }
48 VectorType::Float4 => {
49 let mut values = Vec::with_capacity(arr.len());
50 for (i, v) in arr.iter().enumerate() {
51 let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))? as f32;
52 if !n.is_finite() {
53 return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
54 }
55 values.push(n);
56 }
57 Ok(vtype.slice_to_blob(&values))
58 }
59 VectorType::Float8 => {
60 let mut values = Vec::with_capacity(arr.len());
61 for (i, v) in arr.iter().enumerate() {
62 let n = v.as_f64().ok_or(JsonError::NonNumericElement(i))?;
63 if !n.is_finite() {
64 return Err(JsonError::Type(VectorTypeError::NonFiniteValue));
65 }
66 values.push(n);
67 }
68 Ok(vtype.slice_to_blob(&values))
69 }
70 VectorType::Int1 => {
71 let mut values = Vec::with_capacity(arr.len());
72 for (i, v) in arr.iter().enumerate() {
73 let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i8;
74 values.push(n);
75 }
76 Ok(vtype.slice_to_blob(&values))
77 }
78 VectorType::Int2 => {
79 let mut values = Vec::with_capacity(arr.len());
80 for (i, v) in arr.iter().enumerate() {
81 let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i16;
82 values.push(n);
83 }
84 Ok(vtype.slice_to_blob(&values))
85 }
86 VectorType::Int4 => {
87 let mut values = Vec::with_capacity(arr.len());
88 for (i, v) in arr.iter().enumerate() {
89 let n = v.as_i64().ok_or(JsonError::NonNumericElement(i))? as i32;
90 values.push(n);
91 }
92 Ok(vtype.slice_to_blob(&values))
93 }
94 }
95}
96
97pub fn blob_to_json(blob: &[u8], vtype: VectorType) -> Result<String, JsonError> {
99 let values: Vec<Value> = match vtype {
100 VectorType::Float2 => {
101 let s: &[f16] = vtype.blob_to_slice(blob);
102 s.iter().map(|v| Value::from(v.to_f64())).collect()
103 }
104 VectorType::Float4 => {
105 let s: &[f32] = vtype.blob_to_slice(blob);
106 s.iter().map(|v| Value::from(*v)).collect()
107 }
108 VectorType::Float8 => {
109 let s: &[f64] = vtype.blob_to_slice(blob);
110 s.iter().map(|v| Value::from(*v)).collect()
111 }
112 VectorType::Int1 => {
113 let s: &[i8] = vtype.blob_to_slice(blob);
114 s.iter().map(|v| Value::from(*v as i64)).collect()
115 }
116 VectorType::Int2 => {
117 let s: &[i16] = vtype.blob_to_slice(blob);
118 s.iter().map(|v| Value::from(*v as i64)).collect()
119 }
120 VectorType::Int4 => {
121 let s: &[i32] = vtype.blob_to_slice(blob);
122 s.iter().map(|v| Value::from(*v as i64)).collect()
123 }
124 };
125 serde_json::to_string(&values).map_err(JsonError::Parse)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a JSON array string into `f64` values; panics on malformed input.
    fn parse_json_floats(s: &str) -> Vec<f64> {
        let v: Vec<serde_json::Value> = serde_json::from_str(s).unwrap();
        v.iter().map(|x| x.as_f64().unwrap()).collect()
    }

    /// Parses a JSON array string into `i64` values; panics on malformed input.
    fn parse_json_ints(s: &str) -> Vec<i64> {
        let v: Vec<serde_json::Value> = serde_json::from_str(s).unwrap();
        v.iter().map(|x| x.as_i64().unwrap()).collect()
    }

    // ---- JSON -> blob -> JSON round trips, one per vector type ----

    #[test]
    fn round_trip_float2() {
        let json = "[1.0, -0.5, 0.25]";
        let blob = json_to_blob(json, VectorType::Float2).unwrap();
        assert_eq!(blob.len(), 6);
        let out = blob_to_json(&blob, VectorType::Float2).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 3);
        assert!((vals[0] - 1.0).abs() < 1e-3);
        assert!((vals[1] - (-0.5)).abs() < 1e-3);
        assert!((vals[2] - 0.25).abs() < 1e-3);
    }

    #[test]
    fn round_trip_float4() {
        let json = "[1.5, -2.25, 0.0, 100.0]";
        let blob = json_to_blob(json, VectorType::Float4).unwrap();
        assert_eq!(blob.len(), 16);
        let out = blob_to_json(&blob, VectorType::Float4).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 4);
        assert!((vals[0] - 1.5).abs() < 1e-6);
        assert!((vals[1] - (-2.25)).abs() < 1e-6);
        assert!((vals[2] - 0.0).abs() < 1e-6);
        assert!((vals[3] - 100.0).abs() < 1e-3);
    }

    #[test]
    fn round_trip_float8() {
        let json = "[3.141592653589793, -2.718281828459045, 0.0]";
        let blob = json_to_blob(json, VectorType::Float8).unwrap();
        assert_eq!(blob.len(), 24);
        let out = blob_to_json(&blob, VectorType::Float8).unwrap();
        let vals = parse_json_floats(&out);
        assert_eq!(vals.len(), 3);
        assert!((vals[0] - std::f64::consts::PI).abs() < 1e-15);
        assert!((vals[1] - (-std::f64::consts::E)).abs() < 1e-15);
        assert!((vals[2] - 0.0).abs() < 1e-15);
    }

    #[test]
    fn round_trip_int1() {
        let json = "[0, 127, -128, -1, 42]";
        let blob = json_to_blob(json, VectorType::Int1).unwrap();
        assert_eq!(blob.len(), 5);
        let out = blob_to_json(&blob, VectorType::Int1).unwrap();
        assert_eq!(parse_json_ints(&out), vec![0, 127, -128, -1, 42]);
    }

    #[test]
    fn round_trip_int2() {
        let json = "[0, 32767, -32768, -1, 1000]";
        let blob = json_to_blob(json, VectorType::Int2).unwrap();
        assert_eq!(blob.len(), 10);
        let out = blob_to_json(&blob, VectorType::Int2).unwrap();
        assert_eq!(parse_json_ints(&out), vec![0, 32767, -32768, -1, 1000]);
    }

    #[test]
    fn round_trip_int4() {
        let json = "[0, 2147483647, -2147483648, -1, 99999]";
        let blob = json_to_blob(json, VectorType::Int4).unwrap();
        assert_eq!(blob.len(), 20);
        let out = blob_to_json(&blob, VectorType::Int4).unwrap();
        assert_eq!(
            parse_json_ints(&out),
            vec![0, 2147483647, -2147483648, -1, 99999]
        );
    }

    // ---- Rejection of inputs that are not arrays of numbers ----

    #[test]
    fn json_to_blob_rejects_object() {
        let err = json_to_blob("{\"x\": 1}", VectorType::Float4).unwrap_err();
        assert!(
            matches!(err, JsonError::NotAnArray),
            "expected NotAnArray, got {err}"
        );
    }

    #[test]
    fn json_to_blob_rejects_bare_number() {
        let err = json_to_blob("42", VectorType::Int4).unwrap_err();
        assert!(matches!(err, JsonError::NotAnArray));
    }

    #[test]
    fn json_to_blob_rejects_bare_string() {
        let err = json_to_blob("\"hello\"", VectorType::Float8).unwrap_err();
        assert!(matches!(err, JsonError::NotAnArray));
    }

    #[test]
    fn json_to_blob_rejects_malformed_json() {
        let err = json_to_blob("[1, 2,", VectorType::Float4).unwrap_err();
        assert!(matches!(err, JsonError::Parse(_)));
    }

    #[test]
    fn json_to_blob_rejects_string_element_float4() {
        let err = json_to_blob("[1.0, \"two\", 3.0]", VectorType::Float4).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(1)));
    }

    #[test]
    fn json_to_blob_rejects_string_element_int2() {
        let err = json_to_blob("[\"bad\", 2]", VectorType::Int2).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(0)));
    }

    #[test]
    fn json_to_blob_rejects_null_element_float2() {
        let err = json_to_blob("[1.0, null]", VectorType::Float2).unwrap_err();
        assert!(matches!(err, JsonError::NonNumericElement(1)));
    }

    // ---- Empty inputs ----

    #[test]
    fn json_to_blob_empty_array_all_types() {
        for vtype in [
            VectorType::Float2,
            VectorType::Float4,
            VectorType::Float8,
            VectorType::Int1,
            VectorType::Int2,
            VectorType::Int4,
        ] {
            let blob = json_to_blob("[]", vtype)
                .unwrap_or_else(|e| panic!("empty array failed for {vtype:?}: {e}"));
            assert!(
                blob.is_empty(),
                "expected empty blob for {vtype:?}, got {} bytes",
                blob.len()
            );
        }
    }

    #[test]
    fn blob_to_json_empty_blob_all_types() {
        // Build each empty blob from a typed empty slice rather than a bare `&[u8]`.
        let empty_f16: &[f16] = &[];
        let empty_f32: &[f32] = &[];
        let empty_f64: &[f64] = &[];
        let empty_i8: &[i8] = &[];
        let empty_i16: &[i16] = &[];
        let empty_i32: &[i32] = &[];

        let cases: &[(&[u8], VectorType)] = &[
            (bytemuck::cast_slice(empty_f16), VectorType::Float2),
            (bytemuck::cast_slice(empty_f32), VectorType::Float4),
            (bytemuck::cast_slice(empty_f64), VectorType::Float8),
            (bytemuck::cast_slice(empty_i8), VectorType::Int1),
            (bytemuck::cast_slice(empty_i16), VectorType::Int2),
            (bytemuck::cast_slice(empty_i32), VectorType::Int4),
        ];

        for (blob, vtype) in cases {
            let out = blob_to_json(blob, *vtype)
                .unwrap_or_else(|e| panic!("empty blob failed for {vtype:?}: {e}"));
            assert_eq!(out, "[]", "expected '[]' for {vtype:?}, got {out:?}");
        }
    }

    // ---- Precision of blob -> JSON output ----

    #[test]
    fn float4_precision_survives_round_trip() {
        let inputs: Vec<f32> = vec![0.1, 0.2, 0.3, -0.1, 1.0 / 3.0];
        let blob = VectorType::Float4.slice_to_blob(&inputs);
        let out = blob_to_json(&blob, VectorType::Float4).unwrap();
        let vals = parse_json_floats(&out);
        // `zip` stops at the shorter side, so a truncated or empty output
        // would otherwise pass without comparing anything.
        assert_eq!(vals.len(), inputs.len(), "element count changed");
        for (expected, actual) in inputs.iter().zip(vals.iter()) {
            assert!(
                (actual - *expected as f64).abs() < f32::EPSILON as f64,
                "f32 precision lost: expected {expected}, got {actual}"
            );
        }
    }

    #[test]
    fn float8_full_precision_survives_round_trip() {
        let inputs: Vec<f64> = vec![
            1.0 / 7.0,
            std::f64::consts::PI,
            -std::f64::consts::SQRT_2,
            1.234_567_890_123_456_8e10,
        ];
        let blob = VectorType::Float8.slice_to_blob(&inputs);
        let out = blob_to_json(&blob, VectorType::Float8).unwrap();
        let vals = parse_json_floats(&out);
        // Guard against `zip` silently skipping missing elements.
        assert_eq!(vals.len(), inputs.len(), "element count changed");
        for (expected, actual) in inputs.iter().zip(vals.iter()) {
            assert_eq!(
                actual.to_bits(),
                expected.to_bits(),
                "f64 bit pattern changed: expected {expected}, got {actual}"
            );
        }
    }

    // ---- Integer boundaries, checked on the decoded slice ----

    #[test]
    fn int1_negative_and_zero() {
        let json = "[-128, -1, 0, 1, 127]";
        let blob = json_to_blob(json, VectorType::Int1).unwrap();
        let slice: &[i8] = VectorType::Int1.blob_to_slice(&blob);
        assert_eq!(slice, &[-128_i8, -1, 0, 1, 127]);
    }

    #[test]
    fn int2_negative_and_zero() {
        let json = "[-32768, -100, 0, 100, 32767]";
        let blob = json_to_blob(json, VectorType::Int2).unwrap();
        let slice: &[i16] = VectorType::Int2.blob_to_slice(&blob);
        assert_eq!(slice, &[-32768_i16, -100, 0, 100, 32767]);
    }

    #[test]
    fn int4_negative_and_zero() {
        let json = "[-2147483648, -1, 0, 1, 2147483647]";
        let blob = json_to_blob(json, VectorType::Int4).unwrap();
        let slice: &[i32] = VectorType::Int4.blob_to_slice(&blob);
        assert_eq!(slice, &[-2147483648_i32, -1, 0, 1, 2147483647]);
    }

    // ---- Blob size ----

    #[test]
    fn blob_size_matches_element_size_times_count() {
        // (json, vector type, element count, bytes per element)
        let cases: &[(&str, VectorType, usize, usize)] = &[
            ("[1.0]", VectorType::Float2, 1, 2),
            ("[1.0, 2.0]", VectorType::Float4, 2, 4),
            ("[1.0, 2.0, 3.0]", VectorType::Float8, 3, 8),
            ("[1]", VectorType::Int1, 1, 1),
            ("[1, 2]", VectorType::Int2, 2, 2),
            ("[1, 2, 3, 4]", VectorType::Int4, 4, 4),
        ];
        for (json, vtype, count, elem_bytes) in cases {
            let blob = json_to_blob(json, *vtype).unwrap();
            assert_eq!(
                blob.len(),
                count * elem_bytes,
                "{vtype:?}: expected {} bytes, got {}",
                count * elem_bytes,
                blob.len()
            );
        }
    }
}