vortex_dtype/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
5//!
6//! Apache Arrow's type system includes physical information, which could lead to ambiguities as
7//! Vortex treats encodings as separate from logical types.
8//!
9//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm,
10//! where every logical type is encoded in its simplest corresponding Arrow type. This reflects the
11//! reality that most compute engines don't make use of the entire type range arrow-rs supports.
12//!
13//! For this reason, it's recommended to do as much computation as possible within Vortex, and then
14//! materialize an Arrow ArrayRef at the very end of the processing chain.
15
16use std::sync::Arc;
17
18use arrow_schema::{
19    DECIMAL128_MAX_PRECISION, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef,
20};
21use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
22
23use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
24use crate::datetime::is_temporal_ext_type;
25use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
26
/// Infallible conversion from an Arrow value of type `T` into a Vortex type.
///
/// The signature leaves no room for reporting failure, so an implementation that meets an
/// input it cannot represent has no option but to panic.
pub trait FromArrowType<T>
where
    Self: Sized,
{
    /// Build the Vortex value that corresponds to the given Arrow `value`.
    fn from_arrow(value: T) -> Self;
}
32
33/// Trait for converting Vortex types to Arrow types.
34pub trait TryFromArrowType<T>: Sized {
35    /// Convert the Arrow type to a Vortex type.
36    fn try_from_arrow(value: T) -> VortexResult<Self>;
37}
38
39impl TryFromArrowType<&DataType> for PType {
40    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
41        match value {
42            DataType::Int8 => Ok(Self::I8),
43            DataType::Int16 => Ok(Self::I16),
44            DataType::Int32 => Ok(Self::I32),
45            DataType::Int64 => Ok(Self::I64),
46            DataType::UInt8 => Ok(Self::U8),
47            DataType::UInt16 => Ok(Self::U16),
48            DataType::UInt32 => Ok(Self::U32),
49            DataType::UInt64 => Ok(Self::U64),
50            DataType::Float16 => Ok(Self::F16),
51            DataType::Float32 => Ok(Self::F32),
52            DataType::Float64 => Ok(Self::F64),
53            _ => Err(vortex_err!(
54                "Arrow datatype {:?} cannot be converted to ptype",
55                value
56            )),
57        }
58    }
59}
60
61impl TryFromArrowType<&DataType> for DecimalDType {
62    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
63        match value {
64            DataType::Decimal128(precision, scale) => Self::try_new(*precision, *scale),
65            DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
66            _ => Err(vortex_err!(
67                "Arrow datatype {:?} cannot be converted to DecimalDType",
68                value
69            )),
70        }
71    }
72}
73
74impl FromArrowType<SchemaRef> for DType {
75    fn from_arrow(value: SchemaRef) -> Self {
76        Self::from_arrow(value.as_ref())
77    }
78}
79
80impl FromArrowType<&Schema> for DType {
81    fn from_arrow(value: &Schema) -> Self {
82        Self::Struct(
83            StructFields::from_arrow(value.fields()),
84            Nullability::NonNullable, // Must match From<RecordBatch> for Array
85        )
86    }
87}
88
89impl FromArrowType<&Fields> for StructFields {
90    fn from_arrow(value: &Fields) -> Self {
91        StructFields::from_iter(value.into_iter().map(|f| {
92            (
93                FieldName::from(f.name().as_str()),
94                DType::from_arrow(f.as_ref()),
95            )
96        }))
97    }
98}
99
100impl FromArrowType<(&DataType, Nullability)> for DType {
101    fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
102        if data_type.is_integer() || data_type.is_floating() {
103            return DType::Primitive(
104                PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
105                nullability,
106            );
107        }
108
109        match data_type {
110            DataType::Null => DType::Null,
111            DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
112                DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
113            }
114            DataType::Boolean => DType::Bool(nullability),
115            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
116            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
117                DType::Binary(nullability)
118            }
119            DataType::Date32
120            | DataType::Date64
121            | DataType::Time32(_)
122            | DataType::Time64(_)
123            | DataType::Timestamp(..) => DType::Extension(Arc::new(
124                make_temporal_ext_dtype(data_type).with_nullability(nullability),
125            )),
126            DataType::List(e) | DataType::LargeList(e) => {
127                DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
128            }
129            DataType::FixedSizeList(e, size) => DType::FixedSizeList(
130                Arc::new(Self::from_arrow(e.as_ref())),
131                *size as u32,
132                nullability,
133            ),
134            DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
135            DataType::Dictionary(_, value_type) => {
136                Self::from_arrow((value_type.as_ref(), nullability))
137            }
138            _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
139        }
140    }
141}
142
143impl FromArrowType<&Field> for DType {
144    fn from_arrow(field: &Field) -> Self {
145        Self::from_arrow((field.data_type(), field.is_nullable().into()))
146    }
147}
148
149impl DType {
150    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
151    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
152        let DType::Struct(struct_dtype, nullable) = self else {
153            vortex_bail!("only DType::Struct can be converted to arrow schema");
154        };
155
156        if *nullable != Nullability::NonNullable {
157            vortex_bail!("top-level struct in Schema must be NonNullable");
158        }
159
160        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
161        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
162            builder.push(FieldRef::from(Field::new(
163                field_name.to_string(),
164                field_dtype.to_arrow_dtype()?,
165                field_dtype.is_nullable(),
166            )));
167        }
168
169        Ok(builder.finish())
170    }
171
172    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
173    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
174        Ok(match self {
175            DType::Null => DataType::Null,
176            DType::Bool(_) => DataType::Boolean,
177            DType::Primitive(ptype, _) => match ptype {
178                PType::U8 => DataType::UInt8,
179                PType::U16 => DataType::UInt16,
180                PType::U32 => DataType::UInt32,
181                PType::U64 => DataType::UInt64,
182                PType::I8 => DataType::Int8,
183                PType::I16 => DataType::Int16,
184                PType::I32 => DataType::Int32,
185                PType::I64 => DataType::Int64,
186                PType::F16 => DataType::Float16,
187                PType::F32 => DataType::Float32,
188                PType::F64 => DataType::Float64,
189            },
190            DType::Decimal(dt, _) => {
191                if dt.precision() > DECIMAL128_MAX_PRECISION {
192                    DataType::Decimal256(dt.precision(), dt.scale())
193                } else {
194                    DataType::Decimal128(dt.precision(), dt.scale())
195                }
196            }
197            DType::Utf8(_) => DataType::Utf8View,
198            DType::Binary(_) => DataType::BinaryView,
199            DType::Struct(struct_dtype, _) => {
200                let mut fields = Vec::with_capacity(struct_dtype.names().len());
201                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
202                {
203                    fields.push(FieldRef::from(Field::new(
204                        field_name.to_string(),
205                        field_dt.to_arrow_dtype()?,
206                        field_dt.is_nullable(),
207                    )));
208                }
209
210                DataType::Struct(Fields::from(fields))
211            }
212            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
213            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
214            // Arrow dtype because we do not how large our offsets are.
215            DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
216                elem_dtype.to_arrow_dtype()?,
217                elem_dtype.nullability().into(),
218            ))),
219            DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
220                FieldRef::new(Field::new_list_field(
221                    elem_dtype.to_arrow_dtype()?,
222                    elem_dtype.nullability().into(),
223                )),
224                *size as i32,
225            ),
226            DType::Extension(ext_dtype) => {
227                // Try and match against the known extension DTypes.
228                if is_temporal_ext_type(ext_dtype.id()) {
229                    make_arrow_temporal_dtype(ext_dtype)
230                } else {
231                    vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
232                }
233            }
234        })
235    }
236}
237
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    /// Scalar dtypes and a simple struct convert to the expected Arrow data types.
    #[test]
    fn test_dtype_conversion_success() {
        let scalar_cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (
                DType::Binary(Nullability::NonNullable),
                DataType::BinaryView,
            ),
        ];
        for (dtype, expected) in scalar_cases {
            assert_eq!(dtype.to_arrow_dtype().unwrap(), expected);
        }

        let struct_dtype = DType::struct_(
            [
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ],
            Nullability::NonNullable,
        );
        let expected_struct = DataType::Struct(Fields::from(vec![
            FieldRef::new(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::new(Field::new("field_b", DataType::Utf8View, true)),
        ]));
        assert_eq!(struct_dtype.to_arrow_dtype().unwrap(), expected_struct);
    }

    /// The element nullability of a Vortex list is carried over to the Arrow list field.
    #[test]
    fn infer_nullable_list_element() {
        // Converts a nullable list of i64 whose elements have the given nullability.
        let arrow_list_of_i64 = |element_nullability: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let arrow_list_non_nullable = arrow_list_of_i64(Nullability::NonNullable);
        let arrow_list_nullable = arrow_list_of_i64(Nullability::Nullable);

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    /// Unwrapping the conversion of an unknown extension type panics.
    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        let storage = Arc::new(DType::Utf8(Nullability::NonNullable));
        let fake_ext = ExtDType::new(ExtID::from("my-fake-ext-dtype"), storage, None);

        let _ = DType::Extension(Arc::new(fake_ext))
            .to_arrow_dtype()
            .unwrap();
    }

    /// Three-field struct shared by the schema conversion tests.
    #[fixture]
    fn the_struct() -> StructFields {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        StructFields::new(names, dtypes)
    }

    /// A non-nullable top-level struct converts to an Arrow schema field by field.
    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(schema_nonnull.to_arrow_schema().unwrap(), expected);
    }

    /// Unwrapping the schema conversion of a nullable top-level struct panics.
    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        let _ = schema_null.to_arrow_schema().unwrap();
    }
}