1use std::sync::Arc;
17
18use arrow_schema::DataType;
19use arrow_schema::Field;
20use arrow_schema::FieldRef;
21use arrow_schema::Fields;
22use arrow_schema::Schema;
23use arrow_schema::SchemaBuilder;
24use arrow_schema::SchemaRef;
25use arrow_schema::TimeUnit as ArrowTimeUnit;
26use vortex_error::VortexError;
27use vortex_error::VortexExpect;
28use vortex_error::VortexResult;
29use vortex_error::vortex_bail;
30use vortex_error::vortex_err;
31use vortex_error::vortex_panic;
32
33use crate::DType;
34use crate::DecimalDType;
35use crate::FieldName;
36use crate::Nullability;
37use crate::PType;
38use crate::StructFields;
39use crate::datetime::AnyTemporal;
40use crate::datetime::Date;
41use crate::datetime::TemporalMetadata;
42use crate::datetime::Time;
43use crate::datetime::TimeUnit;
44use crate::datetime::Timestamp;
45
/// Infallible conversion from an Arrow-side value of type `T` into a Vortex-side type.
///
/// Implement this when every possible `T` has a corresponding `Self`; use
/// `TryFromArrowType` when the conversion can be rejected.
pub trait FromArrowType<T>: Sized {
    /// Builds `Self` from the given Arrow value.
    fn from_arrow(value: T) -> Self;
}
51
52pub trait TryFromArrowType<T>: Sized {
54 fn try_from_arrow(value: T) -> VortexResult<Self>;
56}
57
58impl TryFromArrowType<&DataType> for PType {
59 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
60 match value {
61 DataType::Int8 => Ok(Self::I8),
62 DataType::Int16 => Ok(Self::I16),
63 DataType::Int32 => Ok(Self::I32),
64 DataType::Int64 => Ok(Self::I64),
65 DataType::UInt8 => Ok(Self::U8),
66 DataType::UInt16 => Ok(Self::U16),
67 DataType::UInt32 => Ok(Self::U32),
68 DataType::UInt64 => Ok(Self::U64),
69 DataType::Float16 => Ok(Self::F16),
70 DataType::Float32 => Ok(Self::F32),
71 DataType::Float64 => Ok(Self::F64),
72 _ => Err(vortex_err!(
73 "Arrow datatype {:?} cannot be converted to ptype",
74 value
75 )),
76 }
77 }
78}
79
80impl TryFromArrowType<&DataType> for DecimalDType {
81 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
82 match value {
83 DataType::Decimal32(precision, scale)
84 | DataType::Decimal64(precision, scale)
85 | DataType::Decimal128(precision, scale)
86 | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
87
88 _ => Err(vortex_err!(
89 "Arrow datatype {:?} cannot be converted to DecimalDType",
90 value
91 )),
92 }
93 }
94}
95
96impl From<&ArrowTimeUnit> for TimeUnit {
97 fn from(value: &ArrowTimeUnit) -> Self {
98 (*value).into()
99 }
100}
101
102impl From<ArrowTimeUnit> for TimeUnit {
103 fn from(value: ArrowTimeUnit) -> Self {
104 match value {
105 ArrowTimeUnit::Second => Self::Seconds,
106 ArrowTimeUnit::Millisecond => Self::Milliseconds,
107 ArrowTimeUnit::Microsecond => Self::Microseconds,
108 ArrowTimeUnit::Nanosecond => Self::Nanoseconds,
109 }
110 }
111}
112
113impl TryFrom<TimeUnit> for ArrowTimeUnit {
114 type Error = VortexError;
115
116 fn try_from(value: TimeUnit) -> VortexResult<Self> {
117 Ok(match value {
118 TimeUnit::Seconds => Self::Second,
119 TimeUnit::Milliseconds => Self::Millisecond,
120 TimeUnit::Microseconds => Self::Microsecond,
121 TimeUnit::Nanoseconds => Self::Nanosecond,
122 _ => vortex_bail!("Cannot convert {value} to Arrow TimeUnit"),
123 })
124 }
125}
126
127impl FromArrowType<SchemaRef> for DType {
128 fn from_arrow(value: SchemaRef) -> Self {
129 Self::from_arrow(value.as_ref())
130 }
131}
132
133impl FromArrowType<&Schema> for DType {
134 fn from_arrow(value: &Schema) -> Self {
135 Self::Struct(
136 StructFields::from_arrow(value.fields()),
137 Nullability::NonNullable, )
139 }
140}
141
142impl FromArrowType<&Fields> for StructFields {
143 fn from_arrow(value: &Fields) -> Self {
144 StructFields::from_iter(value.into_iter().map(|f| {
145 (
146 FieldName::from(f.name().as_str()),
147 DType::from_arrow(f.as_ref()),
148 )
149 }))
150 }
151}
152
153impl FromArrowType<(&DataType, Nullability)> for DType {
154 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
155 if data_type.is_integer() || data_type.is_floating() {
156 return DType::Primitive(
157 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
158 nullability,
159 );
160 }
161
162 match data_type {
163 DataType::Null => DType::Null,
164 DataType::Decimal32(precision, scale)
165 | DataType::Decimal64(precision, scale)
166 | DataType::Decimal128(precision, scale)
167 | DataType::Decimal256(precision, scale) => {
168 DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
169 }
170 DataType::Boolean => DType::Bool(nullability),
171 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
172 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
173 DType::Binary(nullability)
174 }
175 DataType::Date32 => DType::Extension(Date::new(TimeUnit::Days, nullability).erased()),
176 DataType::Date64 => {
177 DType::Extension(Date::new(TimeUnit::Milliseconds, nullability).erased())
178 }
179 DataType::Time32(unit) => {
180 DType::Extension(Time::new(unit.into(), nullability).erased())
181 }
182 DataType::Time64(unit) => {
183 DType::Extension(Time::new(unit.into(), nullability).erased())
184 }
185 DataType::Timestamp(unit, tz) => DType::Extension(
186 Timestamp::new_with_tz(unit.into(), tz.clone(), nullability).erased(),
187 ),
188 DataType::List(e)
189 | DataType::LargeList(e)
190 | DataType::ListView(e)
191 | DataType::LargeListView(e) => {
192 DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
193 }
194 DataType::FixedSizeList(e, size) => DType::FixedSizeList(
195 Arc::new(Self::from_arrow(e.as_ref())),
196 *size as u32,
197 nullability,
198 ),
199 DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
200 DataType::Dictionary(_, value_type) => {
201 Self::from_arrow((value_type.as_ref(), nullability))
202 }
203 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
204 }
205 }
206}
207
208impl FromArrowType<&Field> for DType {
209 fn from_arrow(field: &Field) -> Self {
210 Self::from_arrow((field.data_type(), field.is_nullable().into()))
211 }
212}
213
214impl DType {
215 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
217 let DType::Struct(struct_dtype, nullable) = self else {
218 vortex_bail!("only DType::Struct can be converted to arrow schema");
219 };
220
221 if *nullable != Nullability::NonNullable {
222 vortex_bail!("top-level struct in Schema must be NonNullable");
223 }
224
225 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
226 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
227 builder.push(FieldRef::from(Field::new(
228 field_name.as_ref(),
229 field_dtype.to_arrow_dtype()?,
230 field_dtype.is_nullable(),
231 )));
232 }
233
234 Ok(builder.finish())
235 }
236
237 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
239 Ok(match self {
240 DType::Null => DataType::Null,
241 DType::Bool(_) => DataType::Boolean,
242 DType::Primitive(ptype, _) => match ptype {
243 PType::U8 => DataType::UInt8,
244 PType::U16 => DataType::UInt16,
245 PType::U32 => DataType::UInt32,
246 PType::U64 => DataType::UInt64,
247 PType::I8 => DataType::Int8,
248 PType::I16 => DataType::Int16,
249 PType::I32 => DataType::Int32,
250 PType::I64 => DataType::Int64,
251 PType::F16 => DataType::Float16,
252 PType::F32 => DataType::Float32,
253 PType::F64 => DataType::Float64,
254 },
255 DType::Decimal(dt, _) => {
256 let precision = dt.precision();
257 let scale = dt.scale();
258
259 match precision {
260 0..=38 => DataType::Decimal128(precision, scale),
267 39.. => DataType::Decimal256(precision, scale),
269 }
270 }
271 DType::Utf8(_) => DataType::Utf8View,
272 DType::Binary(_) => DataType::BinaryView,
273 DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
277 elem_dtype.to_arrow_dtype()?,
278 elem_dtype.nullability().into(),
279 ))),
280 DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
281 FieldRef::new(Field::new_list_field(
282 elem_dtype.to_arrow_dtype()?,
283 elem_dtype.nullability().into(),
284 )),
285 *size as i32,
286 ),
287 DType::Struct(struct_dtype, _) => {
288 let mut fields = Vec::with_capacity(struct_dtype.names().len());
289 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
290 {
291 fields.push(FieldRef::from(Field::new(
292 field_name.as_ref(),
293 field_dt.to_arrow_dtype()?,
294 field_dt.is_nullable(),
295 )));
296 }
297
298 DataType::Struct(Fields::from(fields))
299 }
300 DType::Extension(ext_dtype) => {
301 if let Some(temporal) = ext_dtype.metadata_opt::<AnyTemporal>() {
303 return Ok(match temporal {
304 TemporalMetadata::Timestamp(unit, tz) => {
305 DataType::Timestamp(ArrowTimeUnit::try_from(*unit)?, tz.clone())
306 }
307 TemporalMetadata::Date(unit) => match unit {
308 TimeUnit::Days => DataType::Date32,
309 TimeUnit::Milliseconds => DataType::Date64,
310 TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
311 vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id())
312 }
313 },
314 TemporalMetadata::Time(unit) => match unit {
315 TimeUnit::Seconds => DataType::Time32(ArrowTimeUnit::Second),
316 TimeUnit::Milliseconds => DataType::Time32(ArrowTimeUnit::Millisecond),
317 TimeUnit::Microseconds => DataType::Time64(ArrowTimeUnit::Microsecond),
318 TimeUnit::Nanoseconds => DataType::Time64(ArrowTimeUnit::Nanosecond),
319 TimeUnit::Days => {
320 vortex_panic!(InvalidArgument: "Invalid TimeUnit {} for {}", unit, ext_dtype.id())
321 }
322 },
323 });
324 };
325
326 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
327 }
328 })
329 }
330}
331
#[cfg(test)]
mod test {
    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::FieldRef;
    use arrow_schema::Fields;
    use arrow_schema::Schema;
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::DType;
    use crate::FieldName;
    use crate::FieldNames;
    use crate::Nullability;
    use crate::PType;
    use crate::StructFields;

    /// Scalar and struct dtypes convert to the expected Arrow data types.
    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        let scalar_cases = [
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (
                DType::Binary(Nullability::NonNullable),
                DataType::BinaryView,
            ),
        ];
        for (dtype, expected) in scalar_cases {
            assert_eq!(dtype.to_arrow_dtype().unwrap(), expected);
        }

        let struct_dtype = DType::struct_(
            [
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ],
            Nullability::NonNullable,
        );
        let expected_fields = Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]);
        assert_eq!(
            struct_dtype.to_arrow_dtype().unwrap(),
            DataType::Struct(expected_fields)
        );
    }

    /// The nullability of a list's element dtype is carried onto the Arrow list field.
    #[test]
    fn infer_nullable_list_element() {
        let to_arrow_list = |element_nullability: Nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let arrow_list_non_nullable = to_arrow_list(Nullability::NonNullable);
        let arrow_list_nullable = to_arrow_list(Nullability::Nullable);

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
        );
    }

    /// Three-field struct shared by the schema conversion tests.
    #[fixture]
    fn the_struct() -> StructFields {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        StructFields::new(names, dtypes)
    }

    /// A non-nullable top-level struct converts to a schema with matching fields.
    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(schema_nonnull.to_arrow_schema().unwrap(), expected);
    }

    /// A nullable top-level struct is rejected, so unwrapping the result panics.
    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        schema_null.to_arrow_schema().unwrap();
    }

    /// A field name containing a control character survives a round trip through Arrow.
    #[test]
    fn test_unicode_field_names_roundtrip() {
        let unicode_field_name = "\u{5}=A";
        let field_dtype = DType::Primitive(PType::I8, Nullability::Nullable);
        let original_dtype =
            DType::struct_([(unicode_field_name, field_dtype)], Nullability::NonNullable);

        let arrow_dtype = original_dtype.to_arrow_dtype().unwrap();
        let roundtripped_dtype = DType::from_arrow((&arrow_dtype, Nullability::NonNullable));

        assert_eq!(original_dtype, roundtripped_dtype);
    }

    /// Control-character field names also survive a round trip when structs are nested.
    #[test]
    fn test_unicode_field_names_nested_roundtrip() {
        let inner_field_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let inner_struct =
            DType::struct_([("\u{6}=inner", inner_field_dtype)], Nullability::Nullable);
        let original_dtype =
            DType::struct_([("\u{7}=outer", inner_struct)], Nullability::NonNullable);

        let arrow_dtype = original_dtype.to_arrow_dtype().unwrap();
        let roundtripped_dtype = DType::from_arrow((&arrow_dtype, Nullability::NonNullable));

        assert_eq!(original_dtype, roundtripped_dtype);
    }
}