1use std::io::Cursor;
2use std::sync::Arc;
3
4use arrow_array::*;
5use arrow_array::{ArrayRef, FixedSizeListArray, RecordBatch};
6use arrow_ipc::reader::StreamReader;
7use arrow_ipc::writer::StreamWriter;
8use arrow_schema::{DataType, Field, Schema};
9
10use crate::types::VectorType;
11
/// Error type returned by the Arrow IPC conversion helpers in this module.
///
/// It carries a single human-readable message; the public tuple field gives callers direct
/// access to that text.
#[derive(Debug)]
pub struct ArrowError(pub String);

impl std::fmt::Display for ArrowError {
    /// Renders the wrapped message with an `arrow error: ` prefix.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        out.write_str("arrow error: ")?;
        out.write_str(message)
    }
}

// No `source()` override: only the rendered message text is stored.
impl std::error::Error for ArrowError {}
22
23impl From<arrow_schema::ArrowError> for ArrowError {
24 fn from(e: arrow_schema::ArrowError) -> Self {
25 Self(e.to_string())
26 }
27}
28
29pub fn vectors_to_arrow_ipc(
31 blobs: &[Vec<u8>],
32 vtype: VectorType,
33 dim: usize,
34) -> Result<Vec<u8>, ArrowError> {
35 let (inner_dt, values_array) = build_values_array(blobs, vtype, dim)?;
36
37 let field = Arc::new(Field::new("item", inner_dt, true));
38 let list_array = FixedSizeListArray::new(field, dim as i32, values_array, None);
39
40 let schema = Schema::new(vec![Field::new(
41 "vector",
42 list_array.data_type().clone(),
43 false,
44 )]);
45 let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(list_array)])
46 .map_err(|e| ArrowError(e.to_string()))?;
47
48 let mut buf = Vec::new();
49 let mut writer =
50 StreamWriter::try_new(&mut buf, &schema).map_err(|e| ArrowError(e.to_string()))?;
51 writer
52 .write(&batch)
53 .map_err(|e| ArrowError(e.to_string()))?;
54 writer.finish().map_err(|e| ArrowError(e.to_string()))?;
55 drop(writer);
56
57 Ok(buf)
58}
59
60pub fn arrow_ipc_to_vectors(
65 ipc_bytes: &[u8],
66 vtype: VectorType,
67 dim: usize,
68) -> Result<Vec<Vec<u8>>, ArrowError> {
69 let reader = StreamReader::try_new(Cursor::new(ipc_bytes), None)
70 .map_err(|e| ArrowError(e.to_string()))?;
71
72 let mut all_blobs = Vec::new();
73 for batch_result in reader {
74 let batch = batch_result.map_err(|e| ArrowError(e.to_string()))?;
75 let list_col = batch
76 .column(0)
77 .as_any()
78 .downcast_ref::<FixedSizeListArray>()
79 .ok_or_else(|| ArrowError("expected FixedSizeListArray".into()))?;
80
81 let effective_dim = if dim == 0 {
83 list_col.value_length() as usize
84 } else {
85 dim
86 };
87
88 for i in 0..list_col.len() {
89 let sub = list_col.value(i);
90 let blob = extract_blob_from_array(&sub, vtype, effective_dim)?;
91 all_blobs.push(blob);
92 }
93 }
94
95 Ok(all_blobs)
96}
97
98fn build_values_array(
99 blobs: &[Vec<u8>],
100 vtype: VectorType,
101 dim: usize,
102) -> Result<(DataType, ArrayRef), ArrowError> {
103 let total_elements = blobs.len() * dim;
104 match vtype {
105 VectorType::Float4 => {
106 let mut flat = Vec::with_capacity(total_elements);
107 for blob in blobs {
108 let v: &[f32] = vtype.blob_to_slice(blob);
109 flat.extend_from_slice(v);
110 }
111 Ok((DataType::Float32, Arc::new(Float32Array::from(flat))))
112 }
113 VectorType::Float8 => {
114 let mut flat = Vec::with_capacity(total_elements);
115 for blob in blobs {
116 let v: &[f64] = vtype.blob_to_slice(blob);
117 flat.extend_from_slice(v);
118 }
119 Ok((DataType::Float64, Arc::new(Float64Array::from(flat))))
120 }
121 VectorType::Float2 => {
122 let mut flat = Vec::with_capacity(total_elements);
123 for blob in blobs {
124 let v: &[half::f16] = vtype.blob_to_slice(blob);
125 flat.extend(v.iter().copied());
126 }
127 Ok((DataType::Float16, Arc::new(Float16Array::from(flat))))
128 }
129 VectorType::Int1 => {
130 let mut flat = Vec::with_capacity(total_elements);
131 for blob in blobs {
132 let v: &[i8] = vtype.blob_to_slice(blob);
133 flat.extend_from_slice(v);
134 }
135 Ok((DataType::Int8, Arc::new(Int8Array::from(flat))))
136 }
137 VectorType::Int2 => {
138 let mut flat = Vec::with_capacity(total_elements);
139 for blob in blobs {
140 let v: &[i16] = vtype.blob_to_slice(blob);
141 flat.extend_from_slice(v);
142 }
143 Ok((DataType::Int16, Arc::new(Int16Array::from(flat))))
144 }
145 VectorType::Int4 => {
146 let mut flat = Vec::with_capacity(total_elements);
147 for blob in blobs {
148 let v: &[i32] = vtype.blob_to_slice(blob);
149 flat.extend_from_slice(v);
150 }
151 Ok((DataType::Int32, Arc::new(Int32Array::from(flat))))
152 }
153 }
154}
155
156fn extract_blob_from_array(
157 array: &ArrayRef,
158 vtype: VectorType,
159 dim: usize,
160) -> Result<Vec<u8>, ArrowError> {
161 match vtype {
162 VectorType::Float4 => {
163 let a = array
164 .as_any()
165 .downcast_ref::<Float32Array>()
166 .ok_or_else(|| ArrowError("expected Float32Array".into()))?;
167 let values: Vec<f32> = (0..dim).map(|i| a.value(i)).collect();
168 Ok(vtype.slice_to_blob(&values))
169 }
170 VectorType::Float8 => {
171 let a = array
172 .as_any()
173 .downcast_ref::<Float64Array>()
174 .ok_or_else(|| ArrowError("expected Float64Array".into()))?;
175 let values: Vec<f64> = (0..dim).map(|i| a.value(i)).collect();
176 Ok(vtype.slice_to_blob(&values))
177 }
178 VectorType::Float2 => {
179 let a = array
180 .as_any()
181 .downcast_ref::<Float16Array>()
182 .ok_or_else(|| ArrowError("expected Float16Array".into()))?;
183 let values: Vec<half::f16> = (0..dim).map(|i| a.value(i)).collect();
184 Ok(vtype.slice_to_blob(&values))
185 }
186 VectorType::Int1 => {
187 let a = array
188 .as_any()
189 .downcast_ref::<Int8Array>()
190 .ok_or_else(|| ArrowError("expected Int8Array".into()))?;
191 let values: Vec<i8> = (0..dim).map(|i| a.value(i)).collect();
192 Ok(vtype.slice_to_blob(&values))
193 }
194 VectorType::Int2 => {
195 let a = array
196 .as_any()
197 .downcast_ref::<Int16Array>()
198 .ok_or_else(|| ArrowError("expected Int16Array".into()))?;
199 let values: Vec<i16> = (0..dim).map(|i| a.value(i)).collect();
200 Ok(vtype.slice_to_blob(&values))
201 }
202 VectorType::Int4 => {
203 let a = array
204 .as_any()
205 .downcast_ref::<Int32Array>()
206 .ok_or_else(|| ArrowError("expected Int32Array".into()))?;
207 let values: Vec<i32> = (0..dim).map(|i| a.value(i)).collect();
208 Ok(vtype.slice_to_blob(&values))
209 }
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::VectorType;
    use half::f16;

    /// Encodes `blobs` with `encode_dim`, decodes the stream with `decode_dim`, and returns
    /// whatever came back out.
    fn round_trip(
        blobs: &[Vec<u8>],
        vtype: VectorType,
        encode_dim: usize,
        decode_dim: usize,
    ) -> Vec<Vec<u8>> {
        let ipc = vectors_to_arrow_ipc(blobs, vtype, encode_dim).unwrap();
        arrow_ipc_to_vectors(&ipc, vtype, decode_dim).unwrap()
    }

    // Blob builders, one per element type.

    fn f32_blob(values: &[f32]) -> Vec<u8> {
        VectorType::Float4.slice_to_blob(values)
    }

    fn f64_blob(values: &[f64]) -> Vec<u8> {
        VectorType::Float8.slice_to_blob(values)
    }

    fn i8_blob(values: &[i8]) -> Vec<u8> {
        VectorType::Int1.slice_to_blob(values)
    }

    fn i16_blob(values: &[i16]) -> Vec<u8> {
        VectorType::Int2.slice_to_blob(values)
    }

    fn i32_blob(values: &[i32]) -> Vec<u8> {
        VectorType::Int4.slice_to_blob(values)
    }

    fn f16_blob(values: &[f16]) -> Vec<u8> {
        VectorType::Float2.slice_to_blob(values)
    }

    // Round trips of a single vector, one test per element type.

    #[test]
    fn round_trip_float4() {
        let blobs = vec![f32_blob(&[1.0_f32, 2.0, 3.0])];
        let decoded = round_trip(&blobs, VectorType::Float4, 3, 3);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn round_trip_float8() {
        let blobs = vec![f64_blob(&[1.0_f64, -2.5, 3.125])];
        let decoded = round_trip(&blobs, VectorType::Float8, 3, 3);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn round_trip_int1() {
        let blobs = vec![i8_blob(&[i8::MIN, 0, i8::MAX])];
        let decoded = round_trip(&blobs, VectorType::Int1, 3, 3);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn round_trip_int2() {
        let blobs = vec![i16_blob(&[i16::MIN, 0, i16::MAX])];
        let decoded = round_trip(&blobs, VectorType::Int2, 3, 3);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn round_trip_int4() {
        let blobs = vec![i32_blob(&[i32::MIN, 0, i32::MAX])];
        let decoded = round_trip(&blobs, VectorType::Int4, 3, 3);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn round_trip_float2() {
        let values: Vec<f16> = [1.0_f32, -0.5, 0.25]
            .iter()
            .map(|&x| f16::from_f32(x))
            .collect();
        let blobs = vec![f16_blob(&values)];
        let decoded = round_trip(&blobs, VectorType::Float2, 3, 3);
        assert_eq!(decoded, blobs);
    }

    // An empty input still has to produce a readable stream.

    #[test]
    fn empty_blobs_round_trip() {
        let blobs: Vec<Vec<u8>> = Vec::new();
        let ipc = vectors_to_arrow_ipc(&blobs, VectorType::Float4, 4).unwrap();
        assert!(
            !ipc.is_empty(),
            "IPC buffer must contain at least the schema header"
        );
        let decoded = arrow_ipc_to_vectors(&ipc, VectorType::Float4, 4).unwrap();
        assert!(decoded.is_empty());
    }

    // Decoding with `dim == 0` reads the dimension from the stream.

    #[test]
    fn dim_auto_detection_float4() {
        let blobs = vec![f32_blob(&[10.0_f32, 20.0, 30.0])];
        let decoded = round_trip(&blobs, VectorType::Float4, 3, 0);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn dim_auto_detection_int2() {
        let blobs = vec![i16_blob(&[1_i16, 2, 3])];
        let decoded = round_trip(&blobs, VectorType::Int2, 3, 0);
        assert_eq!(decoded, blobs);
    }

    // Several vectors in one batch.

    #[test]
    fn multiple_vectors_float4() {
        let blobs: Vec<Vec<u8>> = (0..5_u32)
            .map(|i| i as f32)
            .map(|base| f32_blob(&[base, base + 1.0, base + 2.0, base + 3.0]))
            .collect();
        let decoded = round_trip(&blobs, VectorType::Float4, 4, 4);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn multiple_vectors_int4() {
        let mut blobs: Vec<Vec<u8>> = Vec::new();
        for i in 0..5_i32 {
            blobs.push(i32_blob(&[i * 10, i * 10 + 1, i * 10 + 2]));
        }
        let decoded = round_trip(&blobs, VectorType::Int4, 3, 3);
        assert_eq!(decoded.len(), 5);
        assert_eq!(decoded, blobs);
    }

    // Exactly one vector, with an explicit length check on the result.

    #[test]
    fn single_vector_float8() {
        let blobs = vec![f64_blob(&[std::f64::consts::PI, std::f64::consts::E])];
        let decoded = round_trip(&blobs, VectorType::Float8, 2, 2);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded, blobs);
    }

    #[test]
    fn single_vector_int1() {
        let blobs = vec![i8_blob(&[-1_i8, 0, 127])];
        let decoded = round_trip(&blobs, VectorType::Int1, 3, 3);
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded, blobs);
    }
}