vortex_dtype/
arrow.rs

1//! Convert between Vortex [`crate::DType`] and Apache Arrow [`arrow_schema::DataType`].
2//!
3//! Apache Arrow's type system includes physical information, which could lead to ambiguities as
4//! Vortex treats encodings as separate from logical types.
5//!
6//! [`DType::to_arrow_schema`] and its sibling [`DType::to_arrow_dtype`] use a simple algorithm,
7//! where every logical type is encoded in its simplest corresponding Arrow type. This reflects the
8//! reality that most compute engines don't make use of the entire type range arrow-rs supports.
9//!
10//! For this reason, it's recommended to do as much computation as possible within Vortex, and then
11//! materialize an Arrow ArrayRef at the very end of the processing chain.
12
13use std::sync::Arc;
14
15use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef};
16use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
17
18use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
19use crate::datetime::is_temporal_ext_type;
20use crate::{DType, FieldName, Nullability, PType, StructDType};
21
/// Trait for infallibly converting Arrow types to Vortex types.
///
/// The conversion returns `Self` directly, so it has no channel for reporting an unsupported
/// input as an error. See [`TryFromArrowType`] for the fallible counterpart.
pub trait FromArrowType<T>: Sized {
    /// Convert the Arrow type to a Vortex type.
    fn from_arrow(value: T) -> Self;
}
27
/// Trait for fallibly converting Arrow types to Vortex types.
///
/// This is the counterpart of [`FromArrowType`] for conversions that are not defined for every
/// input: the result is wrapped in a [`VortexResult`] rather than returned directly.
pub trait TryFromArrowType<T>: Sized {
    /// Attempt to convert the Arrow type to a Vortex type.
    ///
    /// # Errors
    ///
    /// Implementations are expected to return an error when `value` cannot be represented as
    /// `Self`.
    fn try_from_arrow(value: T) -> VortexResult<Self>;
}
33
34impl TryFromArrowType<&DataType> for PType {
35    fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
36        match value {
37            DataType::Int8 => Ok(Self::I8),
38            DataType::Int16 => Ok(Self::I16),
39            DataType::Int32 => Ok(Self::I32),
40            DataType::Int64 => Ok(Self::I64),
41            DataType::UInt8 => Ok(Self::U8),
42            DataType::UInt16 => Ok(Self::U16),
43            DataType::UInt32 => Ok(Self::U32),
44            DataType::UInt64 => Ok(Self::U64),
45            DataType::Float16 => Ok(Self::F16),
46            DataType::Float32 => Ok(Self::F32),
47            DataType::Float64 => Ok(Self::F64),
48            _ => Err(vortex_err!(
49                "Arrow datatype {:?} cannot be converted to ptype",
50                value
51            )),
52        }
53    }
54}
55
56impl FromArrowType<SchemaRef> for DType {
57    fn from_arrow(value: SchemaRef) -> Self {
58        Self::from_arrow(value.as_ref())
59    }
60}
61
62impl FromArrowType<&Schema> for DType {
63    fn from_arrow(value: &Schema) -> Self {
64        Self::Struct(
65            Arc::new(StructDType::from_arrow(value.fields())),
66            Nullability::NonNullable, // Must match From<RecordBatch> for Array
67        )
68    }
69}
70
71impl FromArrowType<&Fields> for StructDType {
72    fn from_arrow(value: &Fields) -> Self {
73        StructDType::from_iter(value.into_iter().map(|f| {
74            (
75                FieldName::from(f.name().as_str()),
76                DType::from_arrow(f.as_ref()),
77            )
78        }))
79    }
80}
81
82impl FromArrowType<&Field> for DType {
83    fn from_arrow(field: &Field) -> Self {
84        use crate::DType::*;
85
86        let nullability: Nullability = field.is_nullable().into();
87
88        if field.data_type().is_integer() || field.data_type().is_floating() {
89            return Primitive(
90                PType::try_from_arrow(field.data_type())
91                    .vortex_expect("arrow float/integer to ptype"),
92                nullability,
93            );
94        }
95
96        match field.data_type() {
97            DataType::Null => Null,
98            DataType::Boolean => Bool(nullability),
99            DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8(nullability),
100            DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary(nullability),
101            DataType::Date32
102            | DataType::Date64
103            | DataType::Time32(_)
104            | DataType::Time64(_)
105            | DataType::Timestamp(..) => Extension(Arc::new(
106                make_temporal_ext_dtype(field.data_type()).with_nullability(nullability),
107            )),
108            DataType::List(e) | DataType::LargeList(e) => {
109                List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
110            }
111            DataType::Struct(f) => Struct(Arc::new(StructDType::from_arrow(f)), nullability),
112            _ => unimplemented!("Arrow data type not yet supported: {:?}", field.data_type()),
113        }
114    }
115}
116
117impl DType {
118    /// Convert a Vortex [`DType`] into an Arrow [`Schema`].
119    pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
120        let DType::Struct(struct_dtype, nullable) = self else {
121            vortex_bail!("only DType::Struct can be converted to arrow schema");
122        };
123
124        if *nullable != Nullability::NonNullable {
125            vortex_bail!("top-level struct in Schema must be NonNullable");
126        }
127
128        let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
129        for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
130            builder.push(FieldRef::from(Field::new(
131                field_name.to_string(),
132                field_dtype.to_arrow_dtype()?,
133                field_dtype.is_nullable(),
134            )));
135        }
136
137        Ok(builder.finish())
138    }
139
140    /// Returns the Arrow [`DataType`] that best corresponds to this Vortex [`DType`].
141    pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
142        Ok(match self {
143            DType::Null => DataType::Null,
144            DType::Bool(_) => DataType::Boolean,
145            DType::Primitive(ptype, _) => match ptype {
146                PType::U8 => DataType::UInt8,
147                PType::U16 => DataType::UInt16,
148                PType::U32 => DataType::UInt32,
149                PType::U64 => DataType::UInt64,
150                PType::I8 => DataType::Int8,
151                PType::I16 => DataType::Int16,
152                PType::I32 => DataType::Int32,
153                PType::I64 => DataType::Int64,
154                PType::F16 => DataType::Float16,
155                PType::F32 => DataType::Float32,
156                PType::F64 => DataType::Float64,
157            },
158            DType::Utf8(_) => DataType::Utf8View,
159            DType::Binary(_) => DataType::BinaryView,
160            DType::Struct(struct_dtype, _) => {
161                let mut fields = Vec::with_capacity(struct_dtype.names().len());
162                for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
163                {
164                    fields.push(FieldRef::from(Field::new(
165                        field_name.to_string(),
166                        field_dt.to_arrow_dtype()?,
167                        field_dt.is_nullable(),
168                    )));
169                }
170
171                DataType::Struct(Fields::from(fields))
172            }
173            // There are four kinds of lists: List (32-bit offsets), Large List (64-bit), List View
174            // (32-bit), Large List View (64-bit). We cannot both guarantee zero-copy and commit to an
175            // Arrow dtype because we do not how large our offsets are.
176            DType::List(l, _) => DataType::List(FieldRef::new(Field::new_list_field(
177                l.to_arrow_dtype()?,
178                l.nullability().into(),
179            ))),
180            DType::Extension(ext_dtype) => {
181                // Try and match against the known extension DTypes.
182                if is_temporal_ext_type(ext_dtype.id()) {
183                    make_arrow_temporal_dtype(ext_dtype)
184                } else {
185                    vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
186                }
187            }
188        })
189    }
190}
191
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructDType};

    /// Three-field struct shared by the schema tests: a non-nullable bool, a non-nullable
    /// string and a nullable `i32`.
    fn the_struct() -> Arc<StructDType> {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        Arc::new(StructDType::new(names, dtypes))
    }

    /// Every supported dtype converts to the expected Arrow type.
    #[test]
    fn test_dtype_conversion_success() {
        let struct_dtype = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ])),
            Nullability::NonNullable,
        );
        let struct_arrow = DataType::Struct(Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]));

        let cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (
                DType::Binary(Nullability::NonNullable),
                DataType::BinaryView,
            ),
            (struct_dtype, struct_arrow),
        ];

        for (dtype, expected) in cases {
            assert_eq!(dtype.to_arrow_dtype().unwrap(), expected);
        }
    }

    /// The nullability of a list's element dtype is reflected in the Arrow list field.
    #[test]
    fn infer_nullable_list_element() {
        // Nullable list of i64 whose elements have the given nullability.
        let list_of_i64 = |element: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element)),
                Nullability::Nullable,
            )
        };

        let arrow_list_non_nullable = list_of_i64(Nullability::NonNullable)
            .to_arrow_dtype()
            .unwrap();
        let arrow_list_nullable = list_of_i64(Nullability::Nullable)
            .to_arrow_dtype()
            .unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    /// Unwrapping the conversion of an unknown extension dtype panics.
    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        let fake_extension = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )));

        let _ = fake_extension.to_arrow_dtype().unwrap();
    }

    /// A non-nullable struct converts to a schema with one field per struct field.
    #[test]
    fn test_schema_conversion() {
        let schema_nonnull = DType::Struct(the_struct(), Nullability::NonNullable);
        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(schema_nonnull.to_arrow_schema().unwrap(), expected);
    }

    /// Unwrapping the schema conversion of a nullable top-level struct panics.
    #[test]
    #[should_panic]
    fn test_schema_conversion_panics() {
        let schema_null = DType::Struct(the_struct(), Nullability::Nullable);
        let _ = schema_null.to_arrow_schema().unwrap();
    }
}