1use std::sync::Arc;
17
18use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef};
19use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
20
21use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
22use crate::datetime::is_temporal_ext_type;
23use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
24
/// Infallible conversion from an Arrow type (or an Arrow type paired with extra context)
/// into `Self`.
pub trait FromArrowType<T>
where
    Self: Sized,
{
    /// Builds `Self` from the Arrow-side `value`.
    fn from_arrow(value: T) -> Self;
}
30
31pub trait TryFromArrowType<T>: Sized {
33 fn try_from_arrow(value: T) -> VortexResult<Self>;
35}
36
37impl TryFromArrowType<&DataType> for PType {
38 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
39 match value {
40 DataType::Int8 => Ok(Self::I8),
41 DataType::Int16 => Ok(Self::I16),
42 DataType::Int32 => Ok(Self::I32),
43 DataType::Int64 => Ok(Self::I64),
44 DataType::UInt8 => Ok(Self::U8),
45 DataType::UInt16 => Ok(Self::U16),
46 DataType::UInt32 => Ok(Self::U32),
47 DataType::UInt64 => Ok(Self::U64),
48 DataType::Float16 => Ok(Self::F16),
49 DataType::Float32 => Ok(Self::F32),
50 DataType::Float64 => Ok(Self::F64),
51 _ => Err(vortex_err!(
52 "Arrow datatype {:?} cannot be converted to ptype",
53 value
54 )),
55 }
56 }
57}
58
59impl TryFromArrowType<&DataType> for DecimalDType {
60 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
61 match value {
62 DataType::Decimal32(precision, scale)
63 | DataType::Decimal64(precision, scale)
64 | DataType::Decimal128(precision, scale)
65 | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
66
67 _ => Err(vortex_err!(
68 "Arrow datatype {:?} cannot be converted to DecimalDType",
69 value
70 )),
71 }
72 }
73}
74
75impl FromArrowType<SchemaRef> for DType {
76 fn from_arrow(value: SchemaRef) -> Self {
77 Self::from_arrow(value.as_ref())
78 }
79}
80
81impl FromArrowType<&Schema> for DType {
82 fn from_arrow(value: &Schema) -> Self {
83 Self::Struct(
84 StructFields::from_arrow(value.fields()),
85 Nullability::NonNullable, )
87 }
88}
89
90impl FromArrowType<&Fields> for StructFields {
91 fn from_arrow(value: &Fields) -> Self {
92 StructFields::from_iter(value.into_iter().map(|f| {
93 (
94 FieldName::from(f.name().as_str()),
95 DType::from_arrow(f.as_ref()),
96 )
97 }))
98 }
99}
100
101impl FromArrowType<(&DataType, Nullability)> for DType {
102 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
103 if data_type.is_integer() || data_type.is_floating() {
104 return DType::Primitive(
105 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
106 nullability,
107 );
108 }
109
110 match data_type {
111 DataType::Null => DType::Null,
112 DataType::Decimal32(precision, scale)
113 | DataType::Decimal64(precision, scale)
114 | DataType::Decimal128(precision, scale)
115 | DataType::Decimal256(precision, scale) => {
116 DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
117 }
118 DataType::Boolean => DType::Bool(nullability),
119 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
120 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
121 DType::Binary(nullability)
122 }
123 DataType::Date32
124 | DataType::Date64
125 | DataType::Time32(_)
126 | DataType::Time64(_)
127 | DataType::Timestamp(..) => DType::Extension(Arc::new(
128 make_temporal_ext_dtype(data_type).with_nullability(nullability),
129 )),
130 DataType::List(e)
131 | DataType::LargeList(e)
132 | DataType::ListView(e)
133 | DataType::LargeListView(e) => {
134 DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
135 }
136 DataType::FixedSizeList(e, size) => DType::FixedSizeList(
137 Arc::new(Self::from_arrow(e.as_ref())),
138 *size as u32,
139 nullability,
140 ),
141 DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
142 DataType::Dictionary(_, value_type) => {
143 Self::from_arrow((value_type.as_ref(), nullability))
144 }
145 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
146 }
147 }
148}
149
150impl FromArrowType<&Field> for DType {
151 fn from_arrow(field: &Field) -> Self {
152 Self::from_arrow((field.data_type(), field.is_nullable().into()))
153 }
154}
155
156impl DType {
157 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
159 let DType::Struct(struct_dtype, nullable) = self else {
160 vortex_bail!("only DType::Struct can be converted to arrow schema");
161 };
162
163 if *nullable != Nullability::NonNullable {
164 vortex_bail!("top-level struct in Schema must be NonNullable");
165 }
166
167 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
168 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
169 builder.push(FieldRef::from(Field::new(
170 field_name.to_string(),
171 field_dtype.to_arrow_dtype()?,
172 field_dtype.is_nullable(),
173 )));
174 }
175
176 Ok(builder.finish())
177 }
178
179 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
181 Ok(match self {
182 DType::Null => DataType::Null,
183 DType::Bool(_) => DataType::Boolean,
184 DType::Primitive(ptype, _) => match ptype {
185 PType::U8 => DataType::UInt8,
186 PType::U16 => DataType::UInt16,
187 PType::U32 => DataType::UInt32,
188 PType::U64 => DataType::UInt64,
189 PType::I8 => DataType::Int8,
190 PType::I16 => DataType::Int16,
191 PType::I32 => DataType::Int32,
192 PType::I64 => DataType::Int64,
193 PType::F16 => DataType::Float16,
194 PType::F32 => DataType::Float32,
195 PType::F64 => DataType::Float64,
196 },
197 DType::Decimal(dt, _) => {
198 let precision = dt.precision();
199 let scale = dt.scale();
200
201 match precision {
202 0..=38 => DataType::Decimal128(precision, scale),
209 39.. => DataType::Decimal256(precision, scale),
211 }
212 }
213 DType::Utf8(_) => DataType::Utf8View,
214 DType::Binary(_) => DataType::BinaryView,
215 DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
219 elem_dtype.to_arrow_dtype()?,
220 elem_dtype.nullability().into(),
221 ))),
222 DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
223 FieldRef::new(Field::new_list_field(
224 elem_dtype.to_arrow_dtype()?,
225 elem_dtype.nullability().into(),
226 )),
227 *size as i32,
228 ),
229 DType::Struct(struct_dtype, _) => {
230 let mut fields = Vec::with_capacity(struct_dtype.names().len());
231 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
232 {
233 fields.push(FieldRef::from(Field::new(
234 field_name.to_string(),
235 field_dt.to_arrow_dtype()?,
236 field_dt.is_nullable(),
237 )));
238 }
239
240 DataType::Struct(Fields::from(fields))
241 }
242 DType::Extension(ext_dtype) => {
243 if is_temporal_ext_type(ext_dtype.id()) {
245 make_arrow_temporal_dtype(ext_dtype)
246 } else {
247 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
248 }
249 }
250 })
251 }
252}
253
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    /// Checks that scalar dtypes map to their Arrow counterparts and that strings and binaries
    /// use the view layouts. Also checks that structs keep per-field nullability.
    #[test]
    fn test_dtype_conversion_success() {
        let scalar_cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (DType::Binary(Nullability::NonNullable), DataType::BinaryView),
        ];
        for (dtype, expected) in scalar_cases {
            assert_eq!(dtype.to_arrow_dtype().unwrap(), expected);
        }

        let struct_dtype = DType::struct_(
            [
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ],
            Nullability::NonNullable,
        );
        let expected_struct = DataType::Struct(Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]));
        assert_eq!(struct_dtype.to_arrow_dtype().unwrap(), expected_struct);
    }

    /// Checks that the element nullability of a list is reflected on the Arrow list field.
    #[test]
    fn infer_nullable_list_element() {
        let nullable_list_of_i64 = |element_nullability: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let arrow_list_non_nullable = nullable_list_of_i64(Nullability::NonNullable);
        let arrow_list_nullable = nullable_list_of_i64(Nullability::Nullable);

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);

        let expected_nullable =
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true)));
        let expected_non_nullable =
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false)));
        assert_eq!(arrow_list_nullable, expected_nullable);
        assert_eq!(arrow_list_non_nullable, expected_non_nullable);
    }

    /// Checks that a non-temporal extension type cannot be converted to Arrow.
    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        let unknown_ext = ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        );
        DType::Extension(Arc::new(unknown_ext))
            .to_arrow_dtype()
            .unwrap();
    }

    /// Provides a three-field struct with mixed field nullability.
    #[fixture]
    fn the_struct() -> StructFields {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];
        StructFields::new(names, dtypes)
    }

    /// Checks that a non-nullable struct converts into a schema with per-field nullability.
    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));
        let actual = DType::Struct(the_struct, Nullability::NonNullable)
            .to_arrow_schema()
            .unwrap();
        assert_eq!(actual, expected);
    }

    /// Checks that a nullable top-level struct is rejected as a schema.
    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        DType::Struct(the_struct, Nullability::Nullable)
            .to_arrow_schema()
            .unwrap();
    }
}