vortex_dtype/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
5//!
6//! Apache Arrow's type system includes physical information, which could lead to ambiguities as
7//! Vortex treats encodings as separate from logical types.
8//!
9//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm,
10//! where every logical type is encoded in its simplest corresponding Arrow type. This reflects the
11//! reality that most compute engines don't make use of the entire type range arrow-rs supports.
12//!
13//! For this reason, it's recommended to do as much computation as possible within Vortex, and then
14//! materialize an Arrow ArrayRef at the very end of the processing chain.
15
16use std::sync::Arc;
17
18use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef};
19use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
20
21use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
22use crate::datetime::is_temporal_ext_type;
23use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
24
/// Infallible conversion from an Arrow type description into its Vortex counterpart.
///
/// Unlike [`TryFromArrowType`], the signature has no way to report failure, so implementations
/// that meet unsupported input must panic instead of returning an error.
pub trait FromArrowType<T>: Sized {
    /// Build `Self` from the given Arrow value.
    fn from_arrow(value: T) -> Self;
}
30
31/// Trait for converting Vortex types to Arrow types.
32pub trait TryFromArrowType<T>: Sized {
33    /// Convert the Arrow type to a Vortex type.
34    fn try_from_arrow(value: T) -> VortexResult<Self>;
35}
36
37impl TryFromArrowType<&DataType> for PType {
38    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
39        match value {
40            DataType::Int8 => Ok(Self::I8),
41            DataType::Int16 => Ok(Self::I16),
42            DataType::Int32 => Ok(Self::I32),
43            DataType::Int64 => Ok(Self::I64),
44            DataType::UInt8 => Ok(Self::U8),
45            DataType::UInt16 => Ok(Self::U16),
46            DataType::UInt32 => Ok(Self::U32),
47            DataType::UInt64 => Ok(Self::U64),
48            DataType::Float16 => Ok(Self::F16),
49            DataType::Float32 => Ok(Self::F32),
50            DataType::Float64 => Ok(Self::F64),
51            _ => Err(vortex_err!(
52                "Arrow datatype {:?} cannot be converted to ptype",
53                value
54            )),
55        }
56    }
57}
58
59impl TryFromArrowType<&DataType> for DecimalDType {
60    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
61        match value {
62            DataType::Decimal32(precision, scale)
63            | DataType::Decimal64(precision, scale)
64            | DataType::Decimal128(precision, scale)
65            | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
66
67            _ => Err(vortex_err!(
68                "Arrow datatype {:?} cannot be converted to DecimalDType",
69                value
70            )),
71        }
72    }
73}
74
75impl FromArrowType<SchemaRef> for DType {
76    fn from_arrow(value: SchemaRef) -> Self {
77        Self::from_arrow(value.as_ref())
78    }
79}
80
81impl FromArrowType<&Schema> for DType {
82    fn from_arrow(value: &Schema) -> Self {
83        Self::Struct(
84            StructFields::from_arrow(value.fields()),
85            Nullability::NonNullable, // Must match From<RecordBatch> for Array
86        )
87    }
88}
89
90impl FromArrowType<&Fields> for StructFields {
91    fn from_arrow(value: &Fields) -> Self {
92        StructFields::from_iter(value.into_iter().map(|f| {
93            (
94                FieldName::from(f.name().as_str()),
95                DType::from_arrow(f.as_ref()),
96            )
97        }))
98    }
99}
100
101impl FromArrowType<(&DataType, Nullability)> for DType {
102    fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
103        if data_type.is_integer() || data_type.is_floating() {
104            return DType::Primitive(
105                PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
106                nullability,
107            );
108        }
109
110        match data_type {
111            DataType::Null => DType::Null,
112            DataType::Decimal32(precision, scale)
113            | DataType::Decimal64(precision, scale)
114            | DataType::Decimal128(precision, scale)
115            | DataType::Decimal256(precision, scale) => {
116                DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
117            }
118            DataType::Boolean => DType::Bool(nullability),
119            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
120            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
121                DType::Binary(nullability)
122            }
123            DataType::Date32
124            | DataType::Date64
125            | DataType::Time32(_)
126            | DataType::Time64(_)
127            | DataType::Timestamp(..) => DType::Extension(Arc::new(
128                make_temporal_ext_dtype(data_type).with_nullability(nullability),
129            )),
130            DataType::List(e)
131            | DataType::LargeList(e)
132            | DataType::ListView(e)
133            | DataType::LargeListView(e) => {
134                DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
135            }
136            DataType::FixedSizeList(e, size) => DType::FixedSizeList(
137                Arc::new(Self::from_arrow(e.as_ref())),
138                *size as u32,
139                nullability,
140            ),
141            DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
142            DataType::Dictionary(_, value_type) => {
143                Self::from_arrow((value_type.as_ref(), nullability))
144            }
145            _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
146        }
147    }
148}
149
150impl FromArrowType<&Field> for DType {
151    fn from_arrow(field: &Field) -> Self {
152        Self::from_arrow((field.data_type(), field.is_nullable().into()))
153    }
154}
155
156impl DType {
157    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
158    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
159        let DType::Struct(struct_dtype, nullable) = self else {
160            vortex_bail!("only DType::Struct can be converted to arrow schema");
161        };
162
163        if *nullable != Nullability::NonNullable {
164            vortex_bail!("top-level struct in Schema must be NonNullable");
165        }
166
167        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
168        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
169            builder.push(FieldRef::from(Field::new(
170                field_name.to_string(),
171                field_dtype.to_arrow_dtype()?,
172                field_dtype.is_nullable(),
173            )));
174        }
175
176        Ok(builder.finish())
177    }
178
179    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
180    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
181        Ok(match self {
182            DType::Null => DataType::Null,
183            DType::Bool(_) => DataType::Boolean,
184            DType::Primitive(ptype, _) => match ptype {
185                PType::U8 => DataType::UInt8,
186                PType::U16 => DataType::UInt16,
187                PType::U32 => DataType::UInt32,
188                PType::U64 => DataType::UInt64,
189                PType::I8 => DataType::Int8,
190                PType::I16 => DataType::Int16,
191                PType::I32 => DataType::Int32,
192                PType::I64 => DataType::Int64,
193                PType::F16 => DataType::Float16,
194                PType::F32 => DataType::Float32,
195                PType::F64 => DataType::Float64,
196            },
197            DType::Decimal(dt, _) => {
198                let precision = dt.precision();
199                let scale = dt.scale();
200
201                match precision {
202                    // This code is commented out until DataFusion improves its support for smaller decimals.
203                    // // DECIMAL32_MAX_PRECISION
204                    // 0..=9 => DataType::Decimal32(precision, scale),
205                    // // DECIMAL64_MAX_PRECISION
206                    // 10..=18 => DataType::Decimal64(precision, scale),
207                    // DECIMAL128_MAX_PRECISION
208                    0..=38 => DataType::Decimal128(precision, scale),
209                    // DECIMAL256_MAX_PRECISION
210                    39.. => DataType::Decimal256(precision, scale),
211                }
212            }
213            DType::Utf8(_) => DataType::Utf8View,
214            DType::Binary(_) => DataType::BinaryView,
215            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
216            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
217            // Arrow dtype because we do not how large our offsets are.
218            DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
219                elem_dtype.to_arrow_dtype()?,
220                elem_dtype.nullability().into(),
221            ))),
222            DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
223                FieldRef::new(Field::new_list_field(
224                    elem_dtype.to_arrow_dtype()?,
225                    elem_dtype.nullability().into(),
226                )),
227                *size as i32,
228            ),
229            DType::Struct(struct_dtype, _) => {
230                let mut fields = Vec::with_capacity(struct_dtype.names().len());
231                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
232                {
233                    fields.push(FieldRef::from(Field::new(
234                        field_name.to_string(),
235                        field_dt.to_arrow_dtype()?,
236                        field_dt.is_nullable(),
237                    )));
238                }
239
240                DataType::Struct(Fields::from(fields))
241            }
242            DType::Extension(ext_dtype) => {
243                // Try and match against the known extension DTypes.
244                if is_temporal_ext_type(ext_dtype.id()) {
245                    make_arrow_temporal_dtype(ext_dtype)
246                } else {
247                    vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
248                }
249            }
250        })
251    }
252}
253
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    /// Simple Vortex dtypes map onto their simplest Arrow counterparts, and structs keep
    /// per-field nullability.
    #[test]
    fn test_dtype_conversion_success() {
        let cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (DType::Binary(Nullability::NonNullable), DataType::BinaryView),
            (
                DType::struct_(
                    [
                        ("field_a", DType::Bool(false.into())),
                        ("field_b", DType::Utf8(true.into())),
                    ],
                    Nullability::NonNullable,
                ),
                DataType::Struct(Fields::from(vec![
                    FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                    FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
                ])),
            ),
        ];

        for (vortex_dtype, expected) in cases {
            assert_eq!(vortex_dtype.to_arrow_dtype().unwrap(), expected);
        }
    }

    /// The element nullability of a Vortex list is reflected in the Arrow list's child field.
    #[test]
    fn infer_nullable_list_element() {
        let list_of = |element: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let with_non_nullable_elems = list_of(Nullability::NonNullable);
        let with_nullable_elems = list_of(Nullability::Nullable);

        assert_ne!(with_non_nullable_elems, with_nullable_elems);
        assert_eq!(
            with_nullable_elems,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        );
        assert_eq!(
            with_non_nullable_elems,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
        );
    }

    /// Unknown extension types have no Arrow equivalent, so conversion must fail.
    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        let fake_ext = ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        );
        DType::Extension(Arc::new(fake_ext))
            .to_arrow_dtype()
            .unwrap();
    }

    /// Three-field struct shared by the schema tests: two non-nullable fields and one nullable.
    #[fixture]
    fn the_struct() -> StructFields {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];
        StructFields::new(names, dtypes)
    }

    /// A non-nullable struct converts to a schema whose fields mirror the struct's fields.
    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));
        let actual = DType::Struct(the_struct, Nullability::NonNullable)
            .to_arrow_schema()
            .unwrap();

        assert_eq!(actual, expected);
    }

    /// A nullable top-level struct cannot become an Arrow schema.
    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        DType::Struct(the_struct, Nullability::Nullable)
            .to_arrow_schema()
            .unwrap();
    }
}