Skip to main content

sqlite_vector_rs/
arrow_io.rs

1use std::io::Cursor;
2use std::sync::Arc;
3
4use arrow_array::*;
5use arrow_array::{ArrayRef, FixedSizeListArray, RecordBatch};
6use arrow_ipc::reader::StreamReader;
7use arrow_ipc::writer::StreamWriter;
8use arrow_schema::{DataType, Field, Schema};
9
10use crate::types::VectorType;
11
/// Error produced while converting vectors to or from Arrow IPC.
///
/// Carries a human-readable message. Errors raised by the Arrow crates are
/// converted into this type using their rendered text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrowError(pub String);

impl std::fmt::Display for ArrowError {
    /// Formats the error as `arrow error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "arrow error: {}", self.0)
    }
}

impl std::error::Error for ArrowError {}
22
23impl From<arrow_schema::ArrowError> for ArrowError {
24    fn from(e: arrow_schema::ArrowError) -> Self {
25        Self(e.to_string())
26    }
27}
28
29/// Convert a list of raw vector blobs into an Arrow IPC byte buffer.
30pub fn vectors_to_arrow_ipc(
31    blobs: &[Vec<u8>],
32    vtype: VectorType,
33    dim: usize,
34) -> Result<Vec<u8>, ArrowError> {
35    let (inner_dt, values_array) = build_values_array(blobs, vtype, dim)?;
36
37    let field = Arc::new(Field::new("item", inner_dt, true));
38    let list_array = FixedSizeListArray::new(field, dim as i32, values_array, None);
39
40    let schema = Schema::new(vec![Field::new(
41        "vector",
42        list_array.data_type().clone(),
43        false,
44    )]);
45    let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(list_array)])
46        .map_err(|e| ArrowError(e.to_string()))?;
47
48    let mut buf = Vec::new();
49    let mut writer =
50        StreamWriter::try_new(&mut buf, &schema).map_err(|e| ArrowError(e.to_string()))?;
51    writer
52        .write(&batch)
53        .map_err(|e| ArrowError(e.to_string()))?;
54    writer.finish().map_err(|e| ArrowError(e.to_string()))?;
55    drop(writer);
56
57    Ok(buf)
58}
59
60/// Parse an Arrow IPC byte buffer back into raw vector blobs.
61///
62/// If `dim` is 0 the dimension is inferred from the FixedSizeListArray's
63/// `value_length()`, which is encoded in the Arrow schema.
64pub fn arrow_ipc_to_vectors(
65    ipc_bytes: &[u8],
66    vtype: VectorType,
67    dim: usize,
68) -> Result<Vec<Vec<u8>>, ArrowError> {
69    let reader = StreamReader::try_new(Cursor::new(ipc_bytes), None)
70        .map_err(|e| ArrowError(e.to_string()))?;
71
72    let mut all_blobs = Vec::new();
73    for batch_result in reader {
74        let batch = batch_result.map_err(|e| ArrowError(e.to_string()))?;
75        let list_col = batch
76            .column(0)
77            .as_any()
78            .downcast_ref::<FixedSizeListArray>()
79            .ok_or_else(|| ArrowError("expected FixedSizeListArray".into()))?;
80
81        // Auto-detect dimension from the Arrow schema when caller passes 0.
82        let effective_dim = if dim == 0 {
83            list_col.value_length() as usize
84        } else {
85            dim
86        };
87
88        for i in 0..list_col.len() {
89            let sub = list_col.value(i);
90            let blob = extract_blob_from_array(&sub, vtype, effective_dim)?;
91            all_blobs.push(blob);
92        }
93    }
94
95    Ok(all_blobs)
96}
97
98fn build_values_array(
99    blobs: &[Vec<u8>],
100    vtype: VectorType,
101    dim: usize,
102) -> Result<(DataType, ArrayRef), ArrowError> {
103    let total_elements = blobs.len() * dim;
104    match vtype {
105        VectorType::Float4 => {
106            let mut flat = Vec::with_capacity(total_elements);
107            for blob in blobs {
108                let v: &[f32] = vtype.blob_to_slice(blob);
109                flat.extend_from_slice(v);
110            }
111            Ok((DataType::Float32, Arc::new(Float32Array::from(flat))))
112        }
113        VectorType::Float8 => {
114            let mut flat = Vec::with_capacity(total_elements);
115            for blob in blobs {
116                let v: &[f64] = vtype.blob_to_slice(blob);
117                flat.extend_from_slice(v);
118            }
119            Ok((DataType::Float64, Arc::new(Float64Array::from(flat))))
120        }
121        VectorType::Float2 => {
122            let mut flat = Vec::with_capacity(total_elements);
123            for blob in blobs {
124                let v: &[half::f16] = vtype.blob_to_slice(blob);
125                flat.extend(v.iter().copied());
126            }
127            Ok((DataType::Float16, Arc::new(Float16Array::from(flat))))
128        }
129        VectorType::Int1 => {
130            let mut flat = Vec::with_capacity(total_elements);
131            for blob in blobs {
132                let v: &[i8] = vtype.blob_to_slice(blob);
133                flat.extend_from_slice(v);
134            }
135            Ok((DataType::Int8, Arc::new(Int8Array::from(flat))))
136        }
137        VectorType::Int2 => {
138            let mut flat = Vec::with_capacity(total_elements);
139            for blob in blobs {
140                let v: &[i16] = vtype.blob_to_slice(blob);
141                flat.extend_from_slice(v);
142            }
143            Ok((DataType::Int16, Arc::new(Int16Array::from(flat))))
144        }
145        VectorType::Int4 => {
146            let mut flat = Vec::with_capacity(total_elements);
147            for blob in blobs {
148                let v: &[i32] = vtype.blob_to_slice(blob);
149                flat.extend_from_slice(v);
150            }
151            Ok((DataType::Int32, Arc::new(Int32Array::from(flat))))
152        }
153    }
154}
155
156fn extract_blob_from_array(
157    array: &ArrayRef,
158    vtype: VectorType,
159    dim: usize,
160) -> Result<Vec<u8>, ArrowError> {
161    match vtype {
162        VectorType::Float4 => {
163            let a = array
164                .as_any()
165                .downcast_ref::<Float32Array>()
166                .ok_or_else(|| ArrowError("expected Float32Array".into()))?;
167            let values: Vec<f32> = (0..dim).map(|i| a.value(i)).collect();
168            Ok(vtype.slice_to_blob(&values))
169        }
170        VectorType::Float8 => {
171            let a = array
172                .as_any()
173                .downcast_ref::<Float64Array>()
174                .ok_or_else(|| ArrowError("expected Float64Array".into()))?;
175            let values: Vec<f64> = (0..dim).map(|i| a.value(i)).collect();
176            Ok(vtype.slice_to_blob(&values))
177        }
178        VectorType::Float2 => {
179            let a = array
180                .as_any()
181                .downcast_ref::<Float16Array>()
182                .ok_or_else(|| ArrowError("expected Float16Array".into()))?;
183            let values: Vec<half::f16> = (0..dim).map(|i| a.value(i)).collect();
184            Ok(vtype.slice_to_blob(&values))
185        }
186        VectorType::Int1 => {
187            let a = array
188                .as_any()
189                .downcast_ref::<Int8Array>()
190                .ok_or_else(|| ArrowError("expected Int8Array".into()))?;
191            let values: Vec<i8> = (0..dim).map(|i| a.value(i)).collect();
192            Ok(vtype.slice_to_blob(&values))
193        }
194        VectorType::Int2 => {
195            let a = array
196                .as_any()
197                .downcast_ref::<Int16Array>()
198                .ok_or_else(|| ArrowError("expected Int16Array".into()))?;
199            let values: Vec<i16> = (0..dim).map(|i| a.value(i)).collect();
200            Ok(vtype.slice_to_blob(&values))
201        }
202        VectorType::Int4 => {
203            let a = array
204                .as_any()
205                .downcast_ref::<Int32Array>()
206                .ok_or_else(|| ArrowError("expected Int32Array".into()))?;
207            let values: Vec<i32> = (0..dim).map(|i| a.value(i)).collect();
208            Ok(vtype.slice_to_blob(&values))
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::VectorType;
    use half::f16;

    // ---- blob builders -------------------------------------------------

    /// Encode f32 values as a Float4 blob.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        VectorType::Float4.slice_to_blob(values)
    }

    /// Encode f64 values as a Float8 blob.
    fn f64_blob(values: &[f64]) -> Vec<u8> {
        VectorType::Float8.slice_to_blob(values)
    }

    /// Encode i8 values as an Int1 blob.
    fn i8_blob(values: &[i8]) -> Vec<u8> {
        VectorType::Int1.slice_to_blob(values)
    }

    /// Encode i16 values as an Int2 blob.
    fn i16_blob(values: &[i16]) -> Vec<u8> {
        VectorType::Int2.slice_to_blob(values)
    }

    /// Encode i32 values as an Int4 blob.
    fn i32_blob(values: &[i32]) -> Vec<u8> {
        VectorType::Int4.slice_to_blob(values)
    }

    /// Encode f16 values as a Float2 blob.
    fn f16_blob(values: &[f16]) -> Vec<u8> {
        VectorType::Float2.slice_to_blob(values)
    }

    /// Serialize `blobs` with `encode_dim`, then parse them back with `decode_dim`.
    fn encode_then_decode(
        blobs: &[Vec<u8>],
        vtype: VectorType,
        encode_dim: usize,
        decode_dim: usize,
    ) -> Vec<Vec<u8>> {
        let ipc = vectors_to_arrow_ipc(blobs, vtype, encode_dim).unwrap();
        arrow_ipc_to_vectors(&ipc, vtype, decode_dim).unwrap()
    }

    // ---- per-type round trips ------------------------------------------

    #[test]
    fn round_trip_float4() {
        let blobs = vec![f32_blob(&[1.0_f32, 2.0, 3.0])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Float4, 3, 3), blobs);
    }

    #[test]
    fn round_trip_float8() {
        let blobs = vec![f64_blob(&[1.0_f64, -2.5, 3.125])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Float8, 3, 3), blobs);
    }

    #[test]
    fn round_trip_int1() {
        let blobs = vec![i8_blob(&[i8::MIN, 0, i8::MAX])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Int1, 3, 3), blobs);
    }

    #[test]
    fn round_trip_int2() {
        let blobs = vec![i16_blob(&[i16::MIN, 0, i16::MAX])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Int2, 3, 3), blobs);
    }

    #[test]
    fn round_trip_int4() {
        let blobs = vec![i32_blob(&[i32::MIN, 0, i32::MAX])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Int4, 3, 3), blobs);
    }

    #[test]
    fn round_trip_float2() {
        let halves = [f16::from_f32(1.0), f16::from_f32(-0.5), f16::from_f32(0.25)];
        let blobs = vec![f16_blob(&halves)];
        assert_eq!(encode_then_decode(&blobs, VectorType::Float2, 3, 3), blobs);
    }

    // ---- empty input ----------------------------------------------------

    #[test]
    fn empty_blobs_round_trip() {
        // No rows still yields a schema header. A non-zero dimension (4) is
        // used because the list size is part of the Arrow schema.
        let blobs: Vec<Vec<u8>> = Vec::new();
        let ipc = vectors_to_arrow_ipc(&blobs, VectorType::Float4, 4).unwrap();
        assert!(
            !ipc.is_empty(),
            "IPC buffer must contain at least the schema header"
        );
        let decoded = arrow_ipc_to_vectors(&ipc, VectorType::Float4, 4).unwrap();
        assert!(decoded.is_empty());
    }

    // ---- dimension auto-detection (decode with dim = 0) -----------------

    #[test]
    fn dim_auto_detection_float4() {
        let blobs = vec![f32_blob(&[10.0_f32, 20.0, 30.0])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Float4, 3, 0), blobs);
    }

    #[test]
    fn dim_auto_detection_int2() {
        let blobs = vec![i16_blob(&[1_i16, 2, 3])];
        assert_eq!(encode_then_decode(&blobs, VectorType::Int2, 3, 0), blobs);
    }

    // ---- several vectors per batch --------------------------------------

    #[test]
    fn multiple_vectors_float4() {
        let blobs: Vec<Vec<u8>> = (0..5_u32)
            .map(|i| i as f32)
            .map(|base| f32_blob(&[base, base + 1.0, base + 2.0, base + 3.0]))
            .collect();
        let decoded = encode_then_decode(&blobs, VectorType::Float4, 4, 4);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn multiple_vectors_int4() {
        let blobs: Vec<Vec<u8>> = (0..5_i32)
            .map(|i| i * 10)
            .map(|base| i32_blob(&[base, base + 1, base + 2]))
            .collect();
        let decoded = encode_then_decode(&blobs, VectorType::Int4, 3, 3);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded, blobs);
    }

    // ---- a single vector ------------------------------------------------

    #[test]
    fn single_vector_float8() {
        let blobs = vec![f64_blob(&[std::f64::consts::PI, std::f64::consts::E])];
        let decoded = encode_then_decode(&blobs, VectorType::Float8, 2, 2);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn single_vector_int1() {
        let blobs = vec![i8_blob(&[-1_i8, 0, 127])];
        let decoded = encode_then_decode(&blobs, VectorType::Int1, 3, 3);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded, blobs);
    }
}