vortex_array/arrow/
convert.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_array::array::{
5    Array as ArrowArray, ArrowPrimitiveType, BooleanArray as ArrowBooleanArray, GenericByteArray,
6    NullArray as ArrowNullArray, OffsetSizeTrait, PrimitiveArray as ArrowPrimitiveArray,
7    StructArray as ArrowStructArray,
8};
9use arrow_array::cast::{AsArray, as_null_array};
10use arrow_array::types::{
11    ByteArrayType, ByteViewType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
12    Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type,
13    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
14    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
15    TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
16};
17use arrow_array::{GenericByteViewArray, GenericListArray, RecordBatch, make_array};
18use arrow_buffer::buffer::{NullBuffer, OffsetBuffer};
19use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer as ArrowBuffer, ScalarBuffer};
20use arrow_schema::{DataType, TimeUnit as ArrowTimeUnit};
21use itertools::Itertools;
22use vortex_buffer::{Alignment, Buffer, ByteBuffer};
23use vortex_dtype::datetime::TimeUnit;
24use vortex_dtype::{DType, DecimalDType, NativePType, PType};
25use vortex_error::{VortexExpect as _, vortex_panic};
26use vortex_scalar::i256;
27
28use crate::arrays::{
29    BoolArray, DecimalArray, ListArray, NullArray, PrimitiveArray, StructArray, TemporalArray,
30    VarBinArray, VarBinViewArray,
31};
32use crate::arrow::FromArrowArray;
33use crate::validity::Validity;
34use crate::{ArrayRef, IntoArray};
35
36impl IntoArray for ArrowBuffer {
37    fn into_array(self) -> ArrayRef {
38        PrimitiveArray::from_byte_buffer(
39            ByteBuffer::from_arrow_buffer(self, Alignment::of::<u8>()),
40            PType::U8,
41            Validity::NonNullable,
42        )
43        .into_array()
44    }
45}
46
47impl IntoArray for BooleanBuffer {
48    fn into_array(self) -> ArrayRef {
49        BoolArray::new(self, Validity::NonNullable).into_array()
50    }
51}
52
53impl<T> IntoArray for ScalarBuffer<T>
54where
55    T: ArrowNativeType + NativePType,
56{
57    fn into_array(self) -> ArrayRef {
58        PrimitiveArray::new(
59            Buffer::<T>::from_arrow_scalar_buffer(self),
60            Validity::NonNullable,
61        )
62        .into_array()
63    }
64}
65
66impl<O> IntoArray for OffsetBuffer<O>
67where
68    O: NativePType + OffsetSizeTrait,
69{
70    fn into_array(self) -> ArrayRef {
71        let primitive = PrimitiveArray::new(
72            Buffer::from_arrow_scalar_buffer(self.into_inner()),
73            Validity::NonNullable,
74        );
75
76        primitive.into_array()
77    }
78}
79
80macro_rules! impl_from_arrow_primitive {
81    ($ty:path) => {
82        impl FromArrowArray<&ArrowPrimitiveArray<$ty>> for ArrayRef {
83            fn from_arrow(value: &ArrowPrimitiveArray<$ty>, nullable: bool) -> Self {
84                let buffer = Buffer::from_arrow_scalar_buffer(value.values().clone());
85                let validity = nulls(value.nulls(), nullable);
86                PrimitiveArray::new(buffer, validity).into_array()
87            }
88        }
89    };
90}
91
92impl_from_arrow_primitive!(Int8Type);
93impl_from_arrow_primitive!(Int16Type);
94impl_from_arrow_primitive!(Int32Type);
95impl_from_arrow_primitive!(Int64Type);
96impl_from_arrow_primitive!(UInt8Type);
97impl_from_arrow_primitive!(UInt16Type);
98impl_from_arrow_primitive!(UInt32Type);
99impl_from_arrow_primitive!(UInt64Type);
100impl_from_arrow_primitive!(Float16Type);
101impl_from_arrow_primitive!(Float32Type);
102impl_from_arrow_primitive!(Float64Type);
103
104impl FromArrowArray<&ArrowPrimitiveArray<Decimal128Type>> for ArrayRef {
105    fn from_arrow(array: &ArrowPrimitiveArray<Decimal128Type>, nullable: bool) -> Self {
106        let decimal_type = DecimalDType::new(array.precision(), array.scale());
107        let buffer = Buffer::from_arrow_scalar_buffer(array.values().clone());
108        let validity = nulls(array.nulls(), nullable);
109        DecimalArray::new(buffer, decimal_type, validity).into_array()
110    }
111}
112
113impl FromArrowArray<&ArrowPrimitiveArray<Decimal256Type>> for ArrayRef {
114    fn from_arrow(array: &ArrowPrimitiveArray<Decimal256Type>, nullable: bool) -> Self {
115        let decimal_type = DecimalDType::new(array.precision(), array.scale());
116        let buffer = Buffer::from_arrow_scalar_buffer(array.values().clone());
117        // SAFETY: Our i256 implementation has the same bit-pattern representation of the
118        //  arrow_buffer::i256 type. It is safe to treat values held inside the buffer as values
119        //  of either type.
120        let buffer =
121            unsafe { std::mem::transmute::<Buffer<arrow_buffer::i256>, Buffer<i256>>(buffer) };
122        let validity = nulls(array.nulls(), nullable);
123        DecimalArray::new(buffer, decimal_type, validity).into_array()
124    }
125}
126
127macro_rules! impl_from_arrow_temporal {
128    ($ty:path) => {
129        impl FromArrowArray<&ArrowPrimitiveArray<$ty>> for ArrayRef {
130            fn from_arrow(value: &ArrowPrimitiveArray<$ty>, nullable: bool) -> Self {
131                temporal_array(value, nullable)
132            }
133        }
134    };
135}
136
137// timestamp
138impl_from_arrow_temporal!(TimestampSecondType);
139impl_from_arrow_temporal!(TimestampMillisecondType);
140impl_from_arrow_temporal!(TimestampMicrosecondType);
141impl_from_arrow_temporal!(TimestampNanosecondType);
142
143// time
144impl_from_arrow_temporal!(Time32SecondType);
145impl_from_arrow_temporal!(Time32MillisecondType);
146impl_from_arrow_temporal!(Time64MicrosecondType);
147impl_from_arrow_temporal!(Time64NanosecondType);
148
149// date
150impl_from_arrow_temporal!(Date32Type);
151impl_from_arrow_temporal!(Date64Type);
152
153fn temporal_array<T: ArrowPrimitiveType>(value: &ArrowPrimitiveArray<T>, nullable: bool) -> ArrayRef
154where
155    T::Native: NativePType,
156{
157    let arr = PrimitiveArray::new(
158        Buffer::from_arrow_scalar_buffer(value.values().clone()),
159        nulls(value.nulls(), nullable),
160    )
161    .into_array();
162
163    match T::DATA_TYPE {
164        DataType::Timestamp(time_unit, tz) => {
165            let tz = tz.map(|s| s.to_string());
166            TemporalArray::new_timestamp(arr, time_unit.into(), tz).into()
167        }
168        DataType::Time32(time_unit) => TemporalArray::new_time(arr, time_unit.into()).into(),
169        DataType::Time64(time_unit) => TemporalArray::new_time(arr, time_unit.into()).into(),
170        DataType::Date32 => TemporalArray::new_date(arr, TimeUnit::D).into(),
171        DataType::Date64 => TemporalArray::new_date(arr, TimeUnit::Ms).into(),
172        DataType::Duration(_) => unimplemented!(),
173        DataType::Interval(_) => unimplemented!(),
174        _ => vortex_panic!("Invalid temporal type: {}", T::DATA_TYPE),
175    }
176}
177
178impl<T: ByteArrayType> FromArrowArray<&GenericByteArray<T>> for ArrayRef
179where
180    <T as ByteArrayType>::Offset: NativePType,
181{
182    fn from_arrow(value: &GenericByteArray<T>, nullable: bool) -> Self {
183        let dtype = match T::DATA_TYPE {
184            DataType::Binary | DataType::LargeBinary => DType::Binary(nullable.into()),
185            DataType::Utf8 | DataType::LargeUtf8 => DType::Utf8(nullable.into()),
186            _ => vortex_panic!("Invalid data type for ByteArray: {}", T::DATA_TYPE),
187        };
188        VarBinArray::try_new(
189            value.offsets().clone().into_array(),
190            ByteBuffer::from_arrow_buffer(value.values().clone(), Alignment::of::<u8>()),
191            dtype,
192            nulls(value.nulls(), nullable),
193        )
194        .vortex_expect("Failed to convert Arrow GenericByteArray to Vortex VarBinArray")
195        .into_array()
196    }
197}
198
199impl<T: ByteViewType> FromArrowArray<&GenericByteViewArray<T>> for ArrayRef {
200    fn from_arrow(value: &GenericByteViewArray<T>, nullable: bool) -> Self {
201        let dtype = match T::DATA_TYPE {
202            DataType::BinaryView => DType::Binary(nullable.into()),
203            DataType::Utf8View => DType::Utf8(nullable.into()),
204            _ => vortex_panic!("Invalid data type for ByteViewArray: {}", T::DATA_TYPE),
205        };
206
207        let views_buffer = Buffer::from_byte_buffer(
208            Buffer::from_arrow_scalar_buffer(value.views().clone()).into_byte_buffer(),
209        );
210
211        VarBinViewArray::try_new(
212            views_buffer,
213            value
214                .data_buffers()
215                .iter()
216                .map(|b| ByteBuffer::from_arrow_buffer(b.clone(), Alignment::of::<u8>()))
217                .collect::<Vec<_>>(),
218            dtype,
219            nulls(value.nulls(), nullable),
220        )
221        .vortex_expect("Failed to convert Arrow GenericByteViewArray to Vortex VarBinViewArray")
222        .into_array()
223    }
224}
225
226impl FromArrowArray<&ArrowBooleanArray> for ArrayRef {
227    fn from_arrow(value: &ArrowBooleanArray, nullable: bool) -> Self {
228        BoolArray::new(value.values().clone(), nulls(value.nulls(), nullable)).into_array()
229    }
230}
231
232/// Strip out the nulls from this array and return a new array without nulls.
233fn remove_nulls(data: arrow_data::ArrayData) -> arrow_data::ArrayData {
234    if data.null_count() == 0 {
235        // No nulls to remove, return the array as is
236        return data;
237    }
238
239    let children = match data.data_type() {
240        DataType::Struct(fields) => Some(
241            fields
242                .iter()
243                .zip(data.child_data().iter())
244                .map(|(field, child_data)| {
245                    if field.is_nullable() {
246                        child_data.clone()
247                    } else {
248                        remove_nulls(child_data.clone())
249                    }
250                })
251                .collect_vec(),
252        ),
253        DataType::List(f)
254        | DataType::LargeList(f)
255        | DataType::ListView(f)
256        | DataType::LargeListView(f)
257        | DataType::FixedSizeList(f, _)
258            if !f.is_nullable() =>
259        {
260            // All list types only have one child
261            assert_eq!(
262                data.child_data().len(),
263                1,
264                "List types should have one child"
265            );
266            Some(vec![remove_nulls(data.child_data()[0].clone())])
267        }
268        _ => None,
269    };
270
271    let mut builder = data.into_builder().nulls(None);
272    if let Some(children) = children {
273        builder = builder.child_data(children);
274    }
275    builder
276        .build()
277        .vortex_expect("reconstructing array without nulls")
278}
279
280impl FromArrowArray<&ArrowStructArray> for ArrayRef {
281    fn from_arrow(value: &ArrowStructArray, nullable: bool) -> Self {
282        StructArray::try_new(
283            value.column_names().iter().copied().collect(),
284            value
285                .columns()
286                .iter()
287                .zip(value.fields())
288                .map(|(c, field)| {
289                    // Arrow pushes down nulls, even into non-nullable fields. So we strip them
290                    // out here because Vortex is a little more strict.
291                    if c.null_count() > 0 && !field.is_nullable() {
292                        let stripped = make_array(remove_nulls(c.into_data()));
293                        Self::from_arrow(stripped.as_ref(), false)
294                    } else {
295                        Self::from_arrow(c.as_ref(), field.is_nullable())
296                    }
297                })
298                .collect(),
299            value.len(),
300            nulls(value.nulls(), nullable),
301        )
302        .vortex_expect("Failed to convert Arrow StructArray to Vortex StructArray")
303        .into_array()
304    }
305}
306
307impl<O: OffsetSizeTrait + NativePType> FromArrowArray<&GenericListArray<O>> for ArrayRef {
308    fn from_arrow(value: &GenericListArray<O>, nullable: bool) -> Self {
309        // Extract the validity of the underlying element array
310        let elem_nullable = match value.data_type() {
311            DataType::List(field) => field.is_nullable(),
312            DataType::LargeList(field) => field.is_nullable(),
313            dt => vortex_panic!("Invalid data type for ListArray: {dt}"),
314        };
315        ListArray::try_new(
316            Self::from_arrow(value.values().as_ref(), elem_nullable),
317            // offsets are always non-nullable
318            value.offsets().clone().into_array(),
319            nulls(value.nulls(), nullable),
320        )
321        .vortex_expect("Failed to convert Arrow StructArray to Vortex StructArray")
322        .into_array()
323    }
324}
325
326impl FromArrowArray<&ArrowNullArray> for ArrayRef {
327    fn from_arrow(value: &ArrowNullArray, nullable: bool) -> Self {
328        assert!(nullable);
329        NullArray::new(value.len()).into_array()
330    }
331}
332
333fn nulls(nulls: Option<&NullBuffer>, nullable: bool) -> Validity {
334    if nullable {
335        nulls
336            .map(|nulls| {
337                if nulls.null_count() == nulls.len() {
338                    Validity::AllInvalid
339                } else {
340                    Validity::from(nulls.inner().clone())
341                }
342            })
343            .unwrap_or_else(|| Validity::AllValid)
344    } else {
345        assert!(nulls.map(|x| x.null_count() == 0).unwrap_or(true));
346        Validity::NonNullable
347    }
348}
349
350impl FromArrowArray<&dyn ArrowArray> for ArrayRef {
351    fn from_arrow(array: &dyn ArrowArray, nullable: bool) -> Self {
352        match array.data_type() {
353            DataType::Boolean => Self::from_arrow(array.as_boolean(), nullable),
354            DataType::UInt8 => Self::from_arrow(array.as_primitive::<UInt8Type>(), nullable),
355            DataType::UInt16 => Self::from_arrow(array.as_primitive::<UInt16Type>(), nullable),
356            DataType::UInt32 => Self::from_arrow(array.as_primitive::<UInt32Type>(), nullable),
357            DataType::UInt64 => Self::from_arrow(array.as_primitive::<UInt64Type>(), nullable),
358            DataType::Int8 => Self::from_arrow(array.as_primitive::<Int8Type>(), nullable),
359            DataType::Int16 => Self::from_arrow(array.as_primitive::<Int16Type>(), nullable),
360            DataType::Int32 => Self::from_arrow(array.as_primitive::<Int32Type>(), nullable),
361            DataType::Int64 => Self::from_arrow(array.as_primitive::<Int64Type>(), nullable),
362            DataType::Float16 => Self::from_arrow(array.as_primitive::<Float16Type>(), nullable),
363            DataType::Float32 => Self::from_arrow(array.as_primitive::<Float32Type>(), nullable),
364            DataType::Float64 => Self::from_arrow(array.as_primitive::<Float64Type>(), nullable),
365            DataType::Utf8 => Self::from_arrow(array.as_string::<i32>(), nullable),
366            DataType::LargeUtf8 => Self::from_arrow(array.as_string::<i64>(), nullable),
367            DataType::Binary => Self::from_arrow(array.as_binary::<i32>(), nullable),
368            DataType::LargeBinary => Self::from_arrow(array.as_binary::<i64>(), nullable),
369            DataType::BinaryView => Self::from_arrow(array.as_binary_view(), nullable),
370            DataType::Utf8View => Self::from_arrow(array.as_string_view(), nullable),
371            DataType::Struct(_) => Self::from_arrow(array.as_struct(), nullable),
372            DataType::List(_) => Self::from_arrow(array.as_list::<i32>(), nullable),
373            DataType::LargeList(_) => Self::from_arrow(array.as_list::<i64>(), nullable),
374            DataType::Null => Self::from_arrow(as_null_array(array), nullable),
375            DataType::Timestamp(u, _) => match u {
376                ArrowTimeUnit::Second => {
377                    Self::from_arrow(array.as_primitive::<TimestampSecondType>(), nullable)
378                }
379                ArrowTimeUnit::Millisecond => {
380                    Self::from_arrow(array.as_primitive::<TimestampMillisecondType>(), nullable)
381                }
382                ArrowTimeUnit::Microsecond => {
383                    Self::from_arrow(array.as_primitive::<TimestampMicrosecondType>(), nullable)
384                }
385                ArrowTimeUnit::Nanosecond => {
386                    Self::from_arrow(array.as_primitive::<TimestampNanosecondType>(), nullable)
387                }
388            },
389            DataType::Date32 => Self::from_arrow(array.as_primitive::<Date32Type>(), nullable),
390            DataType::Date64 => Self::from_arrow(array.as_primitive::<Date64Type>(), nullable),
391            DataType::Time32(u) => match u {
392                ArrowTimeUnit::Second => {
393                    Self::from_arrow(array.as_primitive::<Time32SecondType>(), nullable)
394                }
395                ArrowTimeUnit::Millisecond => {
396                    Self::from_arrow(array.as_primitive::<Time32MillisecondType>(), nullable)
397                }
398                _ => unreachable!(),
399            },
400            DataType::Time64(u) => match u {
401                ArrowTimeUnit::Microsecond => {
402                    Self::from_arrow(array.as_primitive::<Time64MicrosecondType>(), nullable)
403                }
404                ArrowTimeUnit::Nanosecond => {
405                    Self::from_arrow(array.as_primitive::<Time64NanosecondType>(), nullable)
406                }
407                _ => unreachable!(),
408            },
409            DataType::Decimal128(..) => {
410                Self::from_arrow(array.as_primitive::<Decimal128Type>(), nullable)
411            }
412            DataType::Decimal256(..) => {
413                Self::from_arrow(array.as_primitive::<Decimal256Type>(), nullable)
414            }
415            _ => vortex_panic!(
416                "Array encoding not implemented for Arrow data type {}",
417                array.data_type().clone()
418            ),
419        }
420    }
421}
422
423impl FromArrowArray<RecordBatch> for ArrayRef {
424    fn from_arrow(array: RecordBatch, nullable: bool) -> Self {
425        ArrayRef::from_arrow(&arrow_array::StructArray::from(array), nullable)
426    }
427}
428
429impl FromArrowArray<&RecordBatch> for ArrayRef {
430    fn from_arrow(array: &RecordBatch, nullable: bool) -> Self {
431        Self::from_arrow(array.clone(), nullable)
432    }
433}
434
#[cfg(test)]
mod tests {
    use arrow_array::new_null_array;
    use arrow_schema::{DataType, Field, Fields};

    use crate::ArrayRef;
    use crate::arrow::FromArrowArray as _;

    /// Builds a struct data type with a single non-nullable field called `name`.
    fn struct_with_non_nullable_field(name: &str, inner: DataType) -> DataType {
        DataType::Struct(Fields::from(vec![Field::new(name, inner, false)]))
    }

    /// An all-null struct with a non-nullable primitive field converts without panicking.
    #[test]
    pub fn nullable_may_contain_non_nullable() {
        let dtype = struct_with_non_nullable_field("non_nullable_inner", DataType::Int32);
        let all_null_struct = new_null_array(&dtype, 1);
        ArrayRef::from_arrow(all_null_struct.as_ref(), true);
    }

    /// An all-null struct whose non-nullable field is itself a struct with a non-nullable
    /// field converts without panicking.
    #[test]
    pub fn nullable_may_contain_deeply_nested_non_nullable() {
        let inner = struct_with_non_nullable_field("non_nullable_deeper_inner", DataType::Int32);
        let dtype = struct_with_non_nullable_field("non_nullable_inner", inner);
        let all_null_struct = new_null_array(&dtype, 1);
        ArrayRef::from_arrow(all_null_struct.as_ref(), true);
    }

    /// An all-null struct with a non-nullable dictionary field is expected to panic.
    #[test]
    #[should_panic]
    pub fn cannot_handle_nullable_struct_containing_non_nullable_dictionary() {
        let dictionary = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
        let dtype = struct_with_non_nullable_field("non_nullable_deeper_inner", dictionary);
        let all_null_struct = new_null_array(&dtype, 1);

        ArrayRef::from_arrow(all_null_struct.as_ref(), true);
    }
}
487}