1use std::sync::Arc;
17
18use arrow_schema::DataType;
19use arrow_schema::Field;
20use arrow_schema::FieldRef;
21use arrow_schema::Fields;
22use arrow_schema::Schema;
23use arrow_schema::SchemaBuilder;
24use arrow_schema::SchemaRef;
25use vortex_error::VortexExpect;
26use vortex_error::VortexResult;
27use vortex_error::vortex_bail;
28use vortex_error::vortex_err;
29
30use crate::DType;
31use crate::DecimalDType;
32use crate::FieldName;
33use crate::Nullability;
34use crate::PType;
35use crate::StructFields;
36use crate::datetime::arrow::make_arrow_temporal_dtype;
37use crate::datetime::arrow::make_temporal_ext_dtype;
38use crate::datetime::is_temporal_ext_type;
39
/// Conversion from an Arrow type description into a Vortex type that yields `Self` directly
/// instead of a `Result`.
///
/// Implementations may still panic on input they cannot represent; use [`TryFromArrowType`]
/// when the conversion should report failure as an error instead.
pub trait FromArrowType<Source>: Sized {
    /// Builds `Self` from the Arrow value `source`.
    fn from_arrow(source: Source) -> Self;
}
45
46pub trait TryFromArrowType<T>: Sized {
48 fn try_from_arrow(value: T) -> VortexResult<Self>;
50}
51
52impl TryFromArrowType<&DataType> for PType {
53 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
54 match value {
55 DataType::Int8 => Ok(Self::I8),
56 DataType::Int16 => Ok(Self::I16),
57 DataType::Int32 => Ok(Self::I32),
58 DataType::Int64 => Ok(Self::I64),
59 DataType::UInt8 => Ok(Self::U8),
60 DataType::UInt16 => Ok(Self::U16),
61 DataType::UInt32 => Ok(Self::U32),
62 DataType::UInt64 => Ok(Self::U64),
63 DataType::Float16 => Ok(Self::F16),
64 DataType::Float32 => Ok(Self::F32),
65 DataType::Float64 => Ok(Self::F64),
66 _ => Err(vortex_err!(
67 "Arrow datatype {:?} cannot be converted to ptype",
68 value
69 )),
70 }
71 }
72}
73
74impl TryFromArrowType<&DataType> for DecimalDType {
75 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
76 match value {
77 DataType::Decimal32(precision, scale)
78 | DataType::Decimal64(precision, scale)
79 | DataType::Decimal128(precision, scale)
80 | DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
81
82 _ => Err(vortex_err!(
83 "Arrow datatype {:?} cannot be converted to DecimalDType",
84 value
85 )),
86 }
87 }
88}
89
90impl FromArrowType<SchemaRef> for DType {
91 fn from_arrow(value: SchemaRef) -> Self {
92 Self::from_arrow(value.as_ref())
93 }
94}
95
96impl FromArrowType<&Schema> for DType {
97 fn from_arrow(value: &Schema) -> Self {
98 Self::Struct(
99 StructFields::from_arrow(value.fields()),
100 Nullability::NonNullable, )
102 }
103}
104
105impl FromArrowType<&Fields> for StructFields {
106 fn from_arrow(value: &Fields) -> Self {
107 StructFields::from_iter(value.into_iter().map(|f| {
108 (
109 FieldName::from(f.name().as_str()),
110 DType::from_arrow(f.as_ref()),
111 )
112 }))
113 }
114}
115
116impl FromArrowType<(&DataType, Nullability)> for DType {
117 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
118 if data_type.is_integer() || data_type.is_floating() {
119 return DType::Primitive(
120 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
121 nullability,
122 );
123 }
124
125 match data_type {
126 DataType::Null => DType::Null,
127 DataType::Decimal32(precision, scale)
128 | DataType::Decimal64(precision, scale)
129 | DataType::Decimal128(precision, scale)
130 | DataType::Decimal256(precision, scale) => {
131 DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
132 }
133 DataType::Boolean => DType::Bool(nullability),
134 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
135 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
136 DType::Binary(nullability)
137 }
138 DataType::Date32
139 | DataType::Date64
140 | DataType::Time32(_)
141 | DataType::Time64(_)
142 | DataType::Timestamp(..) => DType::Extension(Arc::new(
143 make_temporal_ext_dtype(data_type).with_nullability(nullability),
144 )),
145 DataType::List(e)
146 | DataType::LargeList(e)
147 | DataType::ListView(e)
148 | DataType::LargeListView(e) => {
149 DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
150 }
151 DataType::FixedSizeList(e, size) => DType::FixedSizeList(
152 Arc::new(Self::from_arrow(e.as_ref())),
153 *size as u32,
154 nullability,
155 ),
156 DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
157 DataType::Dictionary(_, value_type) => {
158 Self::from_arrow((value_type.as_ref(), nullability))
159 }
160 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
161 }
162 }
163}
164
165impl FromArrowType<&Field> for DType {
166 fn from_arrow(field: &Field) -> Self {
167 Self::from_arrow((field.data_type(), field.is_nullable().into()))
168 }
169}
170
171impl DType {
172 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
174 let DType::Struct(struct_dtype, nullable) = self else {
175 vortex_bail!("only DType::Struct can be converted to arrow schema");
176 };
177
178 if *nullable != Nullability::NonNullable {
179 vortex_bail!("top-level struct in Schema must be NonNullable");
180 }
181
182 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
183 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
184 builder.push(FieldRef::from(Field::new(
185 field_name.to_string(),
186 field_dtype.to_arrow_dtype()?,
187 field_dtype.is_nullable(),
188 )));
189 }
190
191 Ok(builder.finish())
192 }
193
194 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
196 Ok(match self {
197 DType::Null => DataType::Null,
198 DType::Bool(_) => DataType::Boolean,
199 DType::Primitive(ptype, _) => match ptype {
200 PType::U8 => DataType::UInt8,
201 PType::U16 => DataType::UInt16,
202 PType::U32 => DataType::UInt32,
203 PType::U64 => DataType::UInt64,
204 PType::I8 => DataType::Int8,
205 PType::I16 => DataType::Int16,
206 PType::I32 => DataType::Int32,
207 PType::I64 => DataType::Int64,
208 PType::F16 => DataType::Float16,
209 PType::F32 => DataType::Float32,
210 PType::F64 => DataType::Float64,
211 },
212 DType::Decimal(dt, _) => {
213 let precision = dt.precision();
214 let scale = dt.scale();
215
216 match precision {
217 0..=38 => DataType::Decimal128(precision, scale),
224 39.. => DataType::Decimal256(precision, scale),
226 }
227 }
228 DType::Utf8(_) => DataType::Utf8View,
229 DType::Binary(_) => DataType::BinaryView,
230 DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
234 elem_dtype.to_arrow_dtype()?,
235 elem_dtype.nullability().into(),
236 ))),
237 DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
238 FieldRef::new(Field::new_list_field(
239 elem_dtype.to_arrow_dtype()?,
240 elem_dtype.nullability().into(),
241 )),
242 *size as i32,
243 ),
244 DType::Struct(struct_dtype, _) => {
245 let mut fields = Vec::with_capacity(struct_dtype.names().len());
246 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
247 {
248 fields.push(FieldRef::from(Field::new(
249 field_name.to_string(),
250 field_dt.to_arrow_dtype()?,
251 field_dt.is_nullable(),
252 )));
253 }
254
255 DataType::Struct(Fields::from(fields))
256 }
257 DType::Extension(ext_dtype) => {
258 if is_temporal_ext_type(ext_dtype.id()) {
260 make_arrow_temporal_dtype(ext_dtype)
261 } else {
262 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
263 }
264 }
265 })
266 }
267}
268
#[cfg(test)]
mod test {
    use arrow_schema::DataType;
    use arrow_schema::Field;
    use arrow_schema::FieldRef;
    use arrow_schema::Fields;
    use arrow_schema::Schema;
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::DType;
    use crate::ExtDType;
    use crate::ExtID;
    use crate::FieldName;
    use crate::FieldNames;
    use crate::Nullability;
    use crate::PType;
    use crate::StructFields;

    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        assert_eq!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Boolean
        );

        assert_eq!(
            DType::Primitive(PType::U64, Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::UInt64
        );

        assert_eq!(
            DType::Utf8(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Utf8View
        );

        assert_eq!(
            DType::Binary(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::BinaryView
        );

        assert_eq!(
            DType::struct_(
                [
                    ("field_a", DType::Bool(false.into())),
                    ("field_b", DType::Utf8(true.into()))
                ],
                Nullability::NonNullable,
            )
            .to_arrow_dtype()
            .unwrap(),
            DataType::Struct(Fields::from(vec![
                FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
            ]))
        );
    }

    #[test]
    fn test_decimal_width_selection() {
        // Precision 38 is the largest that fits Decimal128; anything above needs Decimal256.
        assert_eq!(
            DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Decimal128(38, 2)
        );
        assert_eq!(
            DType::Decimal(DecimalDType::new(39, 2), Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Decimal256(39, 2)
        );
    }

    #[test]
    fn test_fixed_size_list_roundtrip() {
        let dtype = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            4,
            Nullability::Nullable,
        );

        let arrow = dtype.to_arrow_dtype().unwrap();
        assert_eq!(
            arrow,
            DataType::FixedSizeList(
                Arc::new(Field::new_list_field(DataType::Int32, false)),
                4
            )
        );
        assert_eq!(DType::from_arrow((&arrow, Nullability::Nullable)), dtype);
    }

    #[test]
    fn infer_nullable_list_element() {
        let list_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        let arrow_list_non_nullable = list_non_nullable.to_arrow_dtype().unwrap();

        let list_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let arrow_list_nullable = list_nullable.to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, true))),
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::List(Arc::new(Field::new_list_field(DataType::Int64, false))),
        );
    }

    /// Non-temporal extension types must be rejected with an error rather than converted.
    /// Asserting on the returned `Err` (instead of `#[should_panic]`) ensures the test cannot
    /// pass because of an unrelated panic while building the input.
    #[test]
    fn test_dtype_conversion_errors() {
        let result = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )))
        .to_arrow_dtype();

        let err = result.expect_err("non-temporal extension types must not convert to Arrow");
        assert!(err.to_string().contains("Unsupported extension type"));
    }

    #[fixture]
    fn the_struct() -> StructFields {
        StructFields::new(
            FieldNames::from([
                FieldName::from("field_a"),
                FieldName::from("field_b"),
                FieldName::from("field_c"),
            ]),
            vec![
                DType::Bool(Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ],
        )
    }

    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        assert_eq!(
            schema_nonnull.to_arrow_schema().unwrap(),
            Schema::new(Fields::from(vec![
                Field::new("field_a", DataType::Boolean, false),
                Field::new("field_b", DataType::Utf8View, false),
                Field::new("field_c", DataType::Int32, true),
            ]))
        );
    }

    #[rstest]
    fn test_schema_roundtrip(the_struct: StructFields) {
        let dtype = DType::Struct(the_struct, Nullability::NonNullable);
        let schema = dtype.to_arrow_schema().unwrap();
        assert_eq!(DType::from_arrow(&schema), dtype);
    }

    /// A nullable top-level struct, or any non-struct dtype, cannot become an Arrow schema.
    #[rstest]
    fn test_schema_conversion_errors(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        assert!(schema_null.to_arrow_schema().is_err());

        assert!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_schema()
                .is_err()
        );
    }
}