vortex_dtype/
arrow.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
//!
//! Apache Arrow's type system includes physical information, which can lead to ambiguities
//! because Vortex treats encodings as separate from logical types.
//!
//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm
//! in which every logical type is encoded as its simplest corresponding Arrow type. This reflects
//! the reality that most compute engines don't make use of the entire type range arrow-rs supports.
//!
//! For this reason, it's recommended to do as much computation as possible within Vortex, and then
//! materialize an Arrow `ArrayRef` at the very end of the processing chain.
15
16use std::sync::Arc;
17
18use arrow_schema::DataType;
19use arrow_schema::Field;
20use arrow_schema::FieldRef;
21use arrow_schema::Fields;
22use arrow_schema::Schema;
23use arrow_schema::SchemaBuilder;
24use arrow_schema::SchemaRef;
25use vortex_error::VortexExpect;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29
30use crate::DType;
31use crate::DecimalDType;
32use crate::FieldName;
33use crate::Nullability;
34use crate::PType;
35use crate::StructFields;
36use crate::datetime::arrow::make_arrow_temporal_dtype;
37use crate::datetime::arrow::make_temporal_ext_dtype;
38use crate::datetime::is_temporal_ext_type;
39
/// Infallible conversion from an Arrow type description into its Vortex counterpart.
///
/// The conversion returns `Self` directly and so has no way to report failure; use
/// [`TryFromArrowType`] when some Arrow inputs have no Vortex equivalent.
pub trait FromArrowType<T>: Sized {
    /// Build `Self` from the given Arrow `value`.
    fn from_arrow(value: T) -> Self;
}
45
46/// Trait for converting Vortex types to Arrow types.
47pub trait TryFromArrowType<T>: Sized {
48    /// Convert the Arrow type to a Vortex type.
49    fn try_from_arrow(value: T) -> VortexResult<Self>;
50}
51
52impl TryFromArrowType<&DataType> for PType {
53    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
54        match value {
55            DataType::Int8 => Ok(Self::I8),
56            DataType::Int16 => Ok(Self::I16),
57            DataType::Int32 => Ok(Self::I32),
58            DataType::Int64 => Ok(Self::I64),
59            DataType::UInt8 => Ok(Self::U8),
60            DataType::UInt16 => Ok(Self::U16),
61            DataType::UInt32 => Ok(Self::U32),
62            DataType::UInt64 => Ok(Self::U64),
63            DataType::Float16 => Ok(Self::F16),
64            DataType::Float32 => Ok(Self::F32),
65            DataType::Float64 => Ok(Self::F64),
66            _ => Err(vortex_err!(
67                "Arrow datatype {:?} cannot be converted to ptype",
68                value
69            )),
70        }
71    }
72}
73
74impl TryFromArrowType<&DataType> for DecimalDType {
75    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
76        match value {
77            DataType::Decimal32(precision, scale)
78            | DataType::Decimal64(precision, scale)
79            | DataType::Decimal128(precision, scale)
80            | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
81
82            _ => Err(vortex_err!(
83                "Arrow datatype {:?} cannot be converted to DecimalDType",
84                value
85            )),
86        }
87    }
88}
89
90impl FromArrowType<SchemaRef> for DType {
91    fn from_arrow(value: SchemaRef) -> Self {
92        Self::from_arrow(value.as_ref())
93    }
94}
95
96impl FromArrowType<&Schema> for DType {
97    fn from_arrow(value: &Schema) -> Self {
98        Self::Struct(
99            StructFields::from_arrow(value.fields()),
100            Nullability::NonNullable, // Must match From<RecordBatch> for Array
101        )
102    }
103}
104
105impl FromArrowType<&Fields> for StructFields {
106    fn from_arrow(value: &Fields) -> Self {
107        StructFields::from_iter(value.into_iter().map(|f| {
108            (
109                FieldName::from(f.name().as_str()),
110                DType::from_arrow(f.as_ref()),
111            )
112        }))
113    }
114}
115
116impl FromArrowType<(&DataType, Nullability)> for DType {
117    fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
118        if data_type.is_integer() || data_type.is_floating() {
119            return DType::Primitive(
120                PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
121                nullability,
122            );
123        }
124
125        match data_type {
126            DataType::Null => DType::Null,
127            DataType::Decimal32(precision, scale)
128            | DataType::Decimal64(precision, scale)
129            | DataType::Decimal128(precision, scale)
130            | DataType::Decimal256(precision, scale) => {
131                DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
132            }
133            DataType::Boolean => DType::Bool(nullability),
134            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
135            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
136                DType::Binary(nullability)
137            }
138            DataType::Date32
139            | DataType::Date64
140            | DataType::Time32(_)
141            | DataType::Time64(_)
142            | DataType::Timestamp(..) => DType::Extension(Arc::new(
143                make_temporal_ext_dtype(data_type).with_nullability(nullability),
144            )),
145            DataType::List(e)
146            | DataType::LargeList(e)
147            | DataType::ListView(e)
148            | DataType::LargeListView(e) => {
149                DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
150            }
151            DataType::FixedSizeList(e, size) => DType::FixedSizeList(
152                Arc::new(Self::from_arrow(e.as_ref())),
153                *size as u32,
154                nullability,
155            ),
156            DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
157            DataType::Dictionary(_, value_type) => {
158                Self::from_arrow((value_type.as_ref(), nullability))
159            }
160            _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
161        }
162    }
163}
164
165impl FromArrowType<&Field> for DType {
166    fn from_arrow(field: &Field) -> Self {
167        Self::from_arrow((field.data_type(), field.is_nullable().into()))
168    }
169}
170
171impl DType {
172    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
173    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
174        let DType::Struct(struct_dtype, nullable) = self else {
175            vortex_bail!("only DType::Struct can be converted to arrow schema");
176        };
177
178        if *nullable != Nullability::NonNullable {
179            vortex_bail!("top-level struct in Schema must be NonNullable");
180        }
181
182        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
183        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
184            builder.push(FieldRef::from(Field::new(
185                field_name.to_string(),
186                field_dtype.to_arrow_dtype()?,
187                field_dtype.is_nullable(),
188            )));
189        }
190
191        Ok(builder.finish())
192    }
193
194    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
195    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
196        Ok(match self {
197            DType::Null => DataType::Null,
198            DType::Bool(_) => DataType::Boolean,
199            DType::Primitive(ptype, _) => match ptype {
200                PType::U8 => DataType::UInt8,
201                PType::U16 => DataType::UInt16,
202                PType::U32 => DataType::UInt32,
203                PType::U64 => DataType::UInt64,
204                PType::I8 => DataType::Int8,
205                PType::I16 => DataType::Int16,
206                PType::I32 => DataType::Int32,
207                PType::I64 => DataType::Int64,
208                PType::F16 => DataType::Float16,
209                PType::F32 => DataType::Float32,
210                PType::F64 => DataType::Float64,
211            },
212            DType::Decimal(dt, _) => {
213                let precision = dt.precision();
214                let scale = dt.scale();
215
216                match precision {
217                    // This code is commented out until DataFusion improves its support for smaller decimals.
218                    // // DECIMAL32_MAX_PRECISION
219                    // 0..=9 => DataType::Decimal32(precision, scale),
220                    // // DECIMAL64_MAX_PRECISION
221                    // 10..=18 => DataType::Decimal64(precision, scale),
222                    // DECIMAL128_MAX_PRECISION
223                    0..=38 => DataType::Decimal128(precision, scale),
224                    // DECIMAL256_MAX_PRECISION
225                    39.. => DataType::Decimal256(precision, scale),
226                }
227            }
228            DType::Utf8(_) => DataType::Utf8View,
229            DType::Binary(_) => DataType::BinaryView,
230            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
231            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
232            // Arrow dtype because we do not how large our offsets are.
233            DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
234                elem_dtype.to_arrow_dtype()?,
235                elem_dtype.nullability().into(),
236            ))),
237            DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
238                FieldRef::new(Field::new_list_field(
239                    elem_dtype.to_arrow_dtype()?,
240                    elem_dtype.nullability().into(),
241                )),
242                *size as i32,
243            ),
244            DType::Struct(struct_dtype, _) => {
245                let mut fields = Vec::with_capacity(struct_dtype.names().len());
246                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
247                {
248                    fields.push(FieldRef::from(Field::new(
249                        field_name.to_string(),
250                        field_dt.to_arrow_dtype()?,
251                        field_dt.is_nullable(),
252                    )));
253                }
254
255                DataType::Struct(Fields::from(fields))
256            }
257            DType::Extension(ext_dtype) => {
258                // Try and match against the known extension DTypes.
259                if is_temporal_ext_type(ext_dtype.id()) {
260                    make_arrow_temporal_dtype(ext_dtype)
261                } else {
262                    vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
263                }
264            }
265        })
266    }
267}
268
#[cfg(test)]
mod test {
    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::FieldRef;
    use arrow_schema::Fields;
    use arrow_schema::Schema;
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::DType;
    use crate::ExtDType;
    use crate::ExtID;
    use crate::FieldName;
    use crate::FieldNames;
    use crate::Nullability;
    use crate::PType;
    use crate::StructFields;

    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        assert_eq!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Boolean
        );

        assert_eq!(
            DType::Primitive(PType::U64, Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::UInt64
        );

        assert_eq!(
            DType::Utf8(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Utf8View
        );

        assert_eq!(
            DType::Binary(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::BinaryView
        );

        assert_eq!(
            DType::struct_(
                [
                    ("field_a", DType::Bool(false.into())),
                    ("field_b", DType::Utf8(true.into()))
                ],
                Nullability::NonNullable,
            )
            .to_arrow_dtype()
            .unwrap(),
            DataType::Struct(Fields::from(vec![
                FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
            ]))
        );
    }

    #[test]
    fn infer_nullable_list_element() {
        let list_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        let arrow_list_non_nullable = list_non_nullable.to_arrow_dtype().unwrap();

        let list_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let arrow_list_nullable = list_nullable.to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
        );
    }

    /// An unknown extension type must be reported as an error, not via an arbitrary panic.
    #[test]
    fn test_dtype_conversion_unsupported_extension_fails() {
        let ext = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )));
        assert!(ext.to_arrow_dtype().is_err());
    }

    #[fixture]
    fn the_struct() -> StructFields {
        StructFields::new(
            FieldNames::from([
                FieldName::from("field_a"),
                FieldName::from("field_b"),
                FieldName::from("field_c"),
            ]),
            vec![
                DType::Bool(Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ],
        )
    }

    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        assert_eq!(
            schema_nonnull.to_arrow_schema().unwrap(),
            Schema::new(Fields::from(vec![
                Field::new("field_a", DataType::Boolean, false),
                Field::new("field_b", DataType::Utf8View, false),
                Field::new("field_c", DataType::Int32, true),
            ]))
        );
    }

    /// A nullable top-level struct, or a non-struct dtype, cannot become an Arrow schema.
    #[rstest]
    fn test_schema_conversion_fails(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        assert!(schema_null.to_arrow_schema().is_err());
        assert!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_schema()
                .is_err()
        );
    }

    /// Converting to an Arrow schema and back yields the original dtype.
    #[rstest]
    fn test_schema_roundtrip(the_struct: StructFields) {
        let dtype = DType::Struct(the_struct, Nullability::NonNullable);
        let schema = dtype.to_arrow_schema().unwrap();
        assert_eq!(DType::from_arrow(&schema), dtype);
    }

    #[test]
    fn test_from_arrow_primitive_and_dictionary() {
        assert_eq!(
            DType::from_arrow((&DataType::Int32, Nullability::Nullable)),
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
        // Dictionary encoding is physical: the logical type is the value type.
        assert_eq!(
            DType::from_arrow((
                &DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::LargeUtf8)),
                Nullability::NonNullable,
            )),
            DType::Utf8(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_from_arrow_schema_is_non_nullable_struct() {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64, true),
            Field::new("b", DataType::Utf8View, false),
        ]);
        assert_eq!(
            DType::from_arrow(&schema),
            DType::struct_(
                [
                    ("a", DType::Primitive(PType::I64, Nullability::Nullable)),
                    ("b", DType::Utf8(Nullability::NonNullable)),
                ],
                Nullability::NonNullable,
            )
        );
    }

    #[test]
    fn test_try_from_arrow_rejects_mismatched_types() {
        assert_eq!(
            PType::try_from_arrow(&DataType::Float16).unwrap(),
            PType::F16
        );
        assert!(PType::try_from_arrow(&DataType::Utf8).is_err());
        assert!(DecimalDType::try_from_arrow(&DataType::Int32).is_err());
    }
}
412}