Skip to main content

vortex_dtype/
arrow.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
5//!
6//! Apache Arrow's type system includes physical information, which could lead to ambiguities as
7//! Vortex treats encodings as separate from logical types.
8//!
9//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm,
10//! where every logical type is encoded in its simplest corresponding Arrow type. This reflects the
11//! reality that most compute engines don't make use of the entire type range arrow-rs supports.
12//!
13//! For this reason, it's recommended to do as much computation as possible within Vortex, and then
14//! materialize an Arrow ArrayRef at the very end of the processing chain.
15
16use std::sync::Arc;
17
18use arrow_schema::DataType;
19use arrow_schema::Field;
20use arrow_schema::FieldRef;
21use arrow_schema::Fields;
22use arrow_schema::Schema;
23use arrow_schema::SchemaBuilder;
24use arrow_schema::SchemaRef;
25use arrow_schema::TimeUnit as ArrowTimeUnit;
26use vortex_error::VortexError;
27use vortex_error::VortexExpect;
28use vortex_error::VortexResult;
29use vortex_error::vortex_bail;
30use vortex_error::vortex_err;
31use vortex_error::vortex_panic;
32
33use crate::DType;
34use crate::DecimalDType;
35use crate::FieldName;
36use crate::Nullability;
37use crate::PType;
38use crate::StructFields;
39use crate::datetime::AnyTemporal;
40use crate::datetime::Date;
41use crate::datetime::TemporalMetadata;
42use crate::datetime::Time;
43use crate::datetime::TimeUnit;
44use crate::datetime::Timestamp;
45
/// Infallible conversion from an Arrow type `T` into a Vortex type.
///
/// The conversion has no way to report failure through its return type; see
/// [`TryFromArrowType`] for the fallible counterpart.
pub trait FromArrowType<T>: Sized {
    /// Builds the Vortex value corresponding to the Arrow `value`.
    fn from_arrow(value: T) -> Self;
}
51
52/// Trait for converting Vortex types to Arrow types.
53pub trait TryFromArrowType<T>: Sized {
54    /// Convert the Arrow type to a Vortex type.
55    fn try_from_arrow(value: T) -> VortexResult<Self>;
56}
57
58impl TryFromArrowType<&DataType> for PType {
59    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
60        match value {
61            DataType::Int8 => Ok(Self::I8),
62            DataType::Int16 => Ok(Self::I16),
63            DataType::Int32 => Ok(Self::I32),
64            DataType::Int64 => Ok(Self::I64),
65            DataType::UInt8 => Ok(Self::U8),
66            DataType::UInt16 => Ok(Self::U16),
67            DataType::UInt32 => Ok(Self::U32),
68            DataType::UInt64 => Ok(Self::U64),
69            DataType::Float16 => Ok(Self::F16),
70            DataType::Float32 => Ok(Self::F32),
71            DataType::Float64 => Ok(Self::F64),
72            _ => Err(vortex_err!(
73                "Arrow datatype {:?} cannot be converted to ptype",
74                value
75            )),
76        }
77    }
78}
79
80impl TryFromArrowType<&DataType> for DecimalDType {
81    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
82        match value {
83            DataType::Decimal32(precision, scale)
84            | DataType::Decimal64(precision, scale)
85            | DataType::Decimal128(precision, scale)
86            | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
87
88            _ => Err(vortex_err!(
89                "Arrow datatype {:?} cannot be converted to DecimalDType",
90                value
91            )),
92        }
93    }
94}
95
96impl From<&ArrowTimeUnit> for TimeUnit {
97    fn from(value: &ArrowTimeUnit) -> Self {
98        (*value).into()
99    }
100}
101
102impl From<ArrowTimeUnit> for TimeUnit {
103    fn from(value: ArrowTimeUnit) -> Self {
104        match value {
105            ArrowTimeUnit::Second => Self::Seconds,
106            ArrowTimeUnit::Millisecond => Self::Milliseconds,
107            ArrowTimeUnit::Microsecond => Self::Microseconds,
108            ArrowTimeUnit::Nanosecond => Self::Nanoseconds,
109        }
110    }
111}
112
113impl TryFrom<TimeUnit> for ArrowTimeUnit {
114    type Error = VortexError;
115
116    fn try_from(value: TimeUnit) -> VortexResult<Self> {
117        Ok(match value {
118            TimeUnit::Seconds => Self::Second,
119            TimeUnit::Milliseconds => Self::Millisecond,
120            TimeUnit::Microseconds => Self::Microsecond,
121            TimeUnit::Nanoseconds => Self::Nanosecond,
122            _ => vortex_bail!("Cannot convert {value} to Arrow TimeUnit"),
123        })
124    }
125}
126
127impl FromArrowType<SchemaRef> for DType {
128    fn from_arrow(value: SchemaRef) -> Self {
129        Self::from_arrow(value.as_ref())
130    }
131}
132
133impl FromArrowType<&Schema> for DType {
134    fn from_arrow(value: &Schema) -> Self {
135        Self::Struct(
136            StructFields::from_arrow(value.fields()),
137            Nullability::NonNullable, // Must match From<RecordBatch> for Array
138        )
139    }
140}
141
142impl FromArrowType<&Fields> for StructFields {
143    fn from_arrow(value: &Fields) -> Self {
144        StructFields::from_iter(value.into_iter().map(|f| {
145            (
146                FieldName::from(f.name().as_str()),
147                DType::from_arrow(f.as_ref()),
148            )
149        }))
150    }
151}
152
153impl FromArrowType<(&DataType, Nullability)> for DType {
154    fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
155        if data_type.is_integer() || data_type.is_floating() {
156            return DType::Primitive(
157                PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
158                nullability,
159            );
160        }
161
162        match data_type {
163            DataType::Null => DType::Null,
164            DataType::Decimal32(precision, scale)
165            | DataType::Decimal64(precision, scale)
166            | DataType::Decimal128(precision, scale)
167            | DataType::Decimal256(precision, scale) => {
168                DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
169            }
170            DataType::Boolean => DType::Bool(nullability),
171            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
172            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
173                DType::Binary(nullability)
174            }
175            DataType::Date32 => DType::Extension(Date::new(TimeUnit::Days, nullability).erased()),
176            DataType::Date64 => {
177                DType::Extension(Date::new(TimeUnit::Milliseconds, nullability).erased())
178            }
179            DataType::Time32(unit) => {
180                DType::Extension(Time::new(unit.into(), nullability).erased())
181            }
182            DataType::Time64(unit) => {
183                DType::Extension(Time::new(unit.into(), nullability).erased())
184            }
185            DataType::Timestamp(unit, tz) => DType::Extension(
186                Timestamp::new_with_tz(unit.into(), tz.clone(), nullability).erased(),
187            ),
188            DataType::List(e)
189            | DataType::LargeList(e)
190            | DataType::ListView(e)
191            | DataType::LargeListView(e) => {
192                DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
193            }
194            DataType::FixedSizeList(e, size) => DType::FixedSizeList(
195                Arc::new(Self::from_arrow(e.as_ref())),
196                *size as u32,
197                nullability,
198            ),
199            DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
200            DataType::Dictionary(_, value_type) => {
201                Self::from_arrow((value_type.as_ref(), nullability))
202            }
203            _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
204        }
205    }
206}
207
208impl FromArrowType<&Field> for DType {
209    fn from_arrow(field: &Field) -> Self {
210        Self::from_arrow((field.data_type(), field.is_nullable().into()))
211    }
212}
213
214impl DType {
215    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
216    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
217        let DType::Struct(struct_dtype, nullable) = self else {
218            vortex_bail!("only DType::Struct can be converted to arrow schema");
219        };
220
221        if *nullable != Nullability::NonNullable {
222            vortex_bail!("top-level struct in Schema must be NonNullable");
223        }
224
225        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
226        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
227            builder.push(FieldRef::from(Field::new(
228                field_name.as_ref(),
229                field_dtype.to_arrow_dtype()?,
230                field_dtype.is_nullable(),
231            )));
232        }
233
234        Ok(builder.finish())
235    }
236
237    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
238    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
239        Ok(match self {
240            DType::Null => DataType::Null,
241            DType::Bool(_) => DataType::Boolean,
242            DType::Primitive(ptype, _) => match ptype {
243                PType::U8 => DataType::UInt8,
244                PType::U16 => DataType::UInt16,
245                PType::U32 => DataType::UInt32,
246                PType::U64 => DataType::UInt64,
247                PType::I8 => DataType::Int8,
248                PType::I16 => DataType::Int16,
249                PType::I32 => DataType::Int32,
250                PType::I64 => DataType::Int64,
251                PType::F16 => DataType::Float16,
252                PType::F32 => DataType::Float32,
253                PType::F64 => DataType::Float64,
254            },
255            DType::Decimal(dt, _) => {
256                let precision = dt.precision();
257                let scale = dt.scale();
258
259                match precision {
260                    // This code is commented out until DataFusion improves its support for smaller decimals.
261                    // // DECIMAL32_MAX_PRECISION
262                    // 0..=9 => DataType::Decimal32(precision, scale),
263                    // // DECIMAL64_MAX_PRECISION
264                    // 10..=18 => DataType::Decimal64(precision, scale),
265                    // DECIMAL128_MAX_PRECISION
266                    0..=38 => DataType::Decimal128(precision, scale),
267                    // DECIMAL256_MAX_PRECISION
268                    39.. => DataType::Decimal256(precision, scale),
269                }
270            }
271            DType::Utf8(_) => DataType::Utf8View,
272            DType::Binary(_) => DataType::BinaryView,
273            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
274            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
275            // Arrow dtype because we do not how large our offsets are.
276            DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
277                elem_dtype.to_arrow_dtype()?,
278                elem_dtype.nullability().into(),
279            ))),
280            DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
281                FieldRef::new(Field::new_list_field(
282                    elem_dtype.to_arrow_dtype()?,
283                    elem_dtype.nullability().into(),
284                )),
285                *size as i32,
286            ),
287            DType::Struct(struct_dtype, _) => {
288                let mut fields = Vec::with_capacity(struct_dtype.names().len());
289                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
290                {
291                    fields.push(FieldRef::from(Field::new(
292                        field_name.as_ref(),
293                        field_dt.to_arrow_dtype()?,
294                        field_dt.is_nullable(),
295                    )));
296                }
297
298                DataType::Struct(Fields::from(fields))
299            }
300            DType::Extension(ext_dtype) => {
301                // Try and match against the known extension DTypes.
302                if let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() {
303                    return Ok(match temporal {
304                        TemporalMetadata::Timestamp(unit, tz) => {
305                            DataType::Timestamp(ArrowTimeUnit::try_from(*unit)?, tz.clone())
306                        }
307                        TemporalMetadata::Date(unit) => match unit {
308                            TimeUnit::Days => DataType::Date32,
309                            TimeUnit::Milliseconds => DataType::Date64,
310                            TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
311                                vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id())
312                            }
313                        },
314                        TemporalMetadata::Time(unit) => match unit {
315                            TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second),
316                            TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond),
317                            TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond),
318                            TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond),
319                            TimeUnit::Days => {
320                                vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id())
321                            }
322                        },
323                    });
324                };
325
326                vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
327            }
328        })
329    }
330}
331
#[cfg(test)]
mod test {
    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::FieldRef;
    use arrow_schema::Fields;
    use arrow_schema::Schema;
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::DType;
    use crate::FieldName;
    use crate::FieldNames;
    use crate::Nullability;
    use crate::PType;
    use crate::StructFields;

    /// Scalar, string, binary and struct dtypes each convert to the expected Arrow type.
    #[test]
    fn test_dtype_conversion_success() {
        let null = DType::Null.to_arrow_dtype().unwrap();
        assert_eq!(null, DataType::Null);

        let boolean = DType::Bool(Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        assert_eq!(boolean, DataType::Boolean);

        let uint64 = DType::Primitive(PType::U64, Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        assert_eq!(uint64, DataType::UInt64);

        let utf8 = DType::Utf8(Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        assert_eq!(utf8, DataType::Utf8View);

        let binary = DType::Binary(Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        assert_eq!(binary, DataType::BinaryView);

        let struct_dtype = DType::struct_(
            [
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ],
            Nullability::NonNullable,
        );
        let expected_struct = DataType::Struct(Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]));
        assert_eq!(struct_dtype.to_arrow_dtype().unwrap(), expected_struct);
    }

    /// The element nullability of a Vortex list is reflected in the Arrow list field.
    #[test]
    fn infer_nullable_list_element() {
        // Builds a nullable list of i64 whose elements have the given nullability.
        let i64_list = |element_nullability: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
        };

        let arrow_list_non_nullable = i64_list(Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        let arrow_list_nullable = i64_list(Nullability::Nullable).to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
        );
    }

    /// Three-field struct shared by the schema conversion tests.
    #[fixture]
    fn the_struct() -> StructFields {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        StructFields::new(names, dtypes)
    }

    /// A non-nullable struct converts into a schema with one field per struct field.
    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(schema_nonnull.to_arrow_schema().unwrap(), expected);
    }

    /// A nullable top-level struct is rejected, so unwrapping the result panics.
    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        schema_null.to_arrow_schema().unwrap();
    }

    #[test]
    fn test_unicode_field_names_roundtrip() {
        // Regression test for https://github.com/vortex-data/vortex/issues/5979.

        // Unicode characters in field names should survive an Arrow roundtrip without
        // double-escaping.
        let unicode_field_name = "\u{5}=A";
        let field_dtype = DType::Primitive(PType::I8, Nullability::Nullable);
        let original_dtype = DType::struct_(
            [(unicode_field_name, field_dtype)],
            Nullability::NonNullable,
        );

        let arrow_dtype = original_dtype.to_arrow_dtype().unwrap();
        let roundtripped_dtype = DType::from_arrow((&arrow_dtype, Nullability::NonNullable));

        assert_eq!(original_dtype, roundtripped_dtype);
    }

    #[test]
    fn test_unicode_field_names_nested_roundtrip() {
        // Regression test for https://github.com/vortex-data/vortex/issues/5979.

        // Nested structs with unicode field names should also survive an Arrow roundtrip.
        let inner_field_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let inner_struct = DType::struct_(
            [("\u{6}=inner", inner_field_dtype)],
            Nullability::Nullable,
        );
        let original_dtype =
            DType::struct_([("\u{7}=outer", inner_struct)], Nullability::NonNullable);

        let arrow_dtype = original_dtype.to_arrow_dtype().unwrap();
        let roundtripped_dtype = DType::from_arrow((&arrow_dtype, Nullability::NonNullable));

        assert_eq!(original_dtype, roundtripped_dtype);
    }
}