vortex_dtype/
arrow.rs

//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
//!
//! Apache Arrow's type system includes physical information, which can lead to ambiguities
//! because Vortex treats encodings as separate from logical types.
//!
//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm
//! in which every logical type is encoded as its simplest corresponding Arrow type. This
//! reflects the reality that most compute engines don't make use of the entire type range
//! arrow-rs supports.
//!
//! For this reason, it's recommended to do as much computation as possible within Vortex, and
//! then materialize an Arrow `ArrayRef` at the very end of the processing chain.
12
use std::sync::Arc;

use arrow_schema::{
    DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DataType, Field, FieldRef, Fields, Schema,
    SchemaBuilder, SchemaRef,
};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};

use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
use crate::datetime::is_temporal_ext_type;
use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
23
/// Infallible conversion from an Arrow value of type `T` into a Vortex type.
pub trait FromArrowType<T>: Sized {
    /// Build `Self` from the given Arrow value.
    fn from_arrow(arrow: T) -> Self;
}
29
30/// Trait for converting Vortex types to Arrow types.
31pub trait TryFromArrowType<T>: Sized {
32    /// Convert the Arrow type to a Vortex type.
33    fn try_from_arrow(value: T) -> VortexResult<Self>;
34}
35
36impl TryFromArrowType<&DataType> for PType {
37    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
38        match value {
39            DataType::Int8 => Ok(Self::I8),
40            DataType::Int16 => Ok(Self::I16),
41            DataType::Int32 => Ok(Self::I32),
42            DataType::Int64 => Ok(Self::I64),
43            DataType::UInt8 => Ok(Self::U8),
44            DataType::UInt16 => Ok(Self::U16),
45            DataType::UInt32 => Ok(Self::U32),
46            DataType::UInt64 => Ok(Self::U64),
47            DataType::Float16 => Ok(Self::F16),
48            DataType::Float32 => Ok(Self::F32),
49            DataType::Float64 => Ok(Self::F64),
50            _ => Err(vortex_err!(
51                "Arrow datatype {:?} cannot be converted to ptype",
52                value
53            )),
54        }
55    }
56}
57
58impl TryFromArrowType<&DataType> for DecimalDType {
59    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
60        match value {
61            DataType::Decimal128(precision, scale) => Ok(Self::new(*precision, *scale)),
62            DataType::Decimal256(precision, scale) => Ok(Self::new(*precision, *scale)),
63            _ => Err(vortex_err!(
64                "Arrow datatype {:?} cannot be converted to DecimalDType",
65                value
66            )),
67        }
68    }
69}
70
71impl FromArrowType<SchemaRef> for DType {
72    fn from_arrow(value: SchemaRef) -> Self {
73        Self::from_arrow(value.as_ref())
74    }
75}
76
77impl FromArrowType<&Schema> for DType {
78    fn from_arrow(value: &Schema) -> Self {
79        Self::Struct(
80            Arc::new(StructFields::from_arrow(value.fields())),
81            Nullability::NonNullable, // Must match From<RecordBatch> for Array
82        )
83    }
84}
85
86impl FromArrowType<&Fields> for StructFields {
87    fn from_arrow(value: &Fields) -> Self {
88        StructFields::from_iter(value.into_iter().map(|f| {
89            (
90                FieldName::from(f.name().as_str()),
91                DType::from_arrow(f.as_ref()),
92            )
93        }))
94    }
95}
96
97impl FromArrowType<(&DataType, Nullability)> for DType {
98    fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
99        use crate::DType::*;
100
101        if data_type.is_integer() || data_type.is_floating() {
102            return Primitive(
103                PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
104                nullability,
105            );
106        }
107
108        match data_type {
109            DataType::Null => Null,
110            DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
111                Decimal(DecimalDType::new(*precision, *scale), nullability)
112            }
113            DataType::Boolean => Bool(nullability),
114            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8(nullability),
115            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary(nullability),
116            DataType::Date32
117            | DataType::Date64
118            | DataType::Time32(_)
119            | DataType::Time64(_)
120            | DataType::Timestamp(..) => Extension(Arc::new(
121                make_temporal_ext_dtype(data_type).with_nullability(nullability),
122            )),
123            DataType::List(e) | DataType::LargeList(e) => {
124                List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
125            }
126            DataType::Struct(f) => Struct(Arc::new(StructFields::from_arrow(f)), nullability),
127            _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
128        }
129    }
130}
131
132impl FromArrowType<&Field> for DType {
133    fn from_arrow(field: &Field) -> Self {
134        Self::from_arrow((field.data_type(), field.is_nullable().into()))
135    }
136}
137
138impl DType {
139    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
140    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
141        let DType::Struct(struct_dtype, nullable) = self else {
142            vortex_bail!("only DType::Struct can be converted to arrow schema");
143        };
144
145        if *nullable != Nullability::NonNullable {
146            vortex_bail!("top-level struct in Schema must be NonNullable");
147        }
148
149        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
150        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
151            builder.push(FieldRef::from(Field::new(
152                field_name.to_string(),
153                field_dtype.to_arrow_dtype()?,
154                field_dtype.is_nullable(),
155            )));
156        }
157
158        Ok(builder.finish())
159    }
160
161    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
162    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
163        Ok(match self {
164            DType::Null => DataType::Null,
165            DType::Bool(_) => DataType::Boolean,
166            DType::Primitive(ptype, _) => match ptype {
167                PType::U8 => DataType::UInt8,
168                PType::U16 => DataType::UInt16,
169                PType::U32 => DataType::UInt32,
170                PType::U64 => DataType::UInt64,
171                PType::I8 => DataType::Int8,
172                PType::I16 => DataType::Int16,
173                PType::I32 => DataType::Int32,
174                PType::I64 => DataType::Int64,
175                PType::F16 => DataType::Float16,
176                PType::F32 => DataType::Float32,
177                PType::F64 => DataType::Float64,
178            },
179            DType::Decimal(dt, _) => {
180                if dt.scale() > DECIMAL128_MAX_SCALE {
181                    DataType::Decimal256(dt.precision(), dt.scale())
182                } else {
183                    DataType::Decimal128(dt.precision(), dt.scale())
184                }
185            }
186            DType::Utf8(_) => DataType::Utf8View,
187            DType::Binary(_) => DataType::BinaryView,
188            DType::Struct(struct_dtype, _) => {
189                let mut fields = Vec::with_capacity(struct_dtype.names().len());
190                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
191                {
192                    fields.push(FieldRef::from(Field::new(
193                        field_name.to_string(),
194                        field_dt.to_arrow_dtype()?,
195                        field_dt.is_nullable(),
196                    )));
197                }
198
199                DataType::Struct(Fields::from(fields))
200            }
201            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
202            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
203            // Arrow dtype because we do not how large our offsets are.
204            DType::List(l, _) => DataType::List(FieldRef::new(Field::new_list_field(
205                l.to_arrow_dtype()?,
206                l.nullability().into(),
207            ))),
208            DType::Extension(ext_dtype) => {
209                // Try and match against the known extension DTypes.
210                if is_temporal_ext_type(ext_dtype.id()) {
211                    make_arrow_temporal_dtype(ext_dtype)
212                } else {
213                    vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
214                }
215            }
216        })
217    }
218}
219
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        assert_eq!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Boolean
        );

        assert_eq!(
            DType::Primitive(PType::U64, Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::UInt64
        );

        assert_eq!(
            DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Decimal128(10, 2)
        );

        assert_eq!(
            DType::Utf8(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Utf8View
        );

        assert_eq!(
            DType::Binary(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::BinaryView
        );

        assert_eq!(
            DType::Struct(
                Arc::new(StructFields::from_iter([
                    ("field_a", DType::Bool(false.into())),
                    ("field_b", DType::Utf8(true.into()))
                ])),
                Nullability::NonNullable,
            )
            .to_arrow_dtype()
            .unwrap(),
            DataType::Struct(Fields::from(vec![
                FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
            ]))
        );
    }

    #[test]
    fn infer_nullable_list_element() {
        let list_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        let arrow_list_non_nullable = list_non_nullable.to_arrow_dtype().unwrap();

        let list_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let arrow_list_nullable = list_nullable.to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    #[test]
    fn test_dtype_conversion_unsupported_extension() {
        let ext = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )));
        // Unknown extension types are reported as an error, not a panic; asserting on the
        // Result avoids `#[should_panic]` passing on some unrelated panic.
        assert!(ext.to_arrow_dtype().is_err());
    }

    #[test]
    fn test_schema_conversion() {
        let struct_dtype = the_struct();
        let schema_nonnull = DType::Struct(struct_dtype, Nullability::NonNullable);

        assert_eq!(
            schema_nonnull.to_arrow_schema().unwrap(),
            Schema::new(Fields::from(vec![
                Field::new("field_a", DataType::Boolean, false),
                Field::new("field_b", DataType::Utf8View, false),
                Field::new("field_c", DataType::Int32, true),
            ]))
        );
    }

    #[test]
    fn test_schema_conversion_rejects_invalid_dtypes() {
        // A nullable top-level struct cannot become a schema.
        let schema_null = DType::Struct(the_struct(), Nullability::Nullable);
        assert!(schema_null.to_arrow_schema().is_err());

        // Only structs can become schemas.
        assert!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_schema()
                .is_err()
        );
    }

    #[test]
    fn test_schema_from_arrow() {
        let schema = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(
            DType::from_arrow(&schema),
            DType::Struct(the_struct(), Nullability::NonNullable)
        );
    }

    fn the_struct() -> Arc<StructFields> {
        Arc::new(StructFields::new(
            FieldNames::from([
                FieldName::from("field_a"),
                FieldName::from("field_b"),
                FieldName::from("field_c"),
            ]),
            vec![
                DType::Bool(Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ],
        ))
    }
}