1use std::sync::Arc;
17
18use arrow_schema::{
19 DECIMAL128_MAX_PRECISION, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef,
20};
21use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
22
23use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
24use crate::datetime::is_temporal_ext_type;
25use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
26
/// Infallible conversion from an Arrow type `T` into `Self`.
///
/// Implementations in this module panic when `T` has no Vortex equivalent; use
/// [`TryFromArrowType`] when the failure should be reported as an error instead.
pub trait FromArrowType<T>
where
    Self: Sized,
{
    /// Converts `value` into `Self`.
    fn from_arrow(value: T) -> Self;
}
32
33pub trait TryFromArrowType<T>: Sized {
35 fn try_from_arrow(value: T) -> VortexResult<Self>;
37}
38
39impl TryFromArrowType<&DataType> for PType {
40 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
41 match value {
42 DataType::Int8 => Ok(Self::I8),
43 DataType::Int16 => Ok(Self::I16),
44 DataType::Int32 => Ok(Self::I32),
45 DataType::Int64 => Ok(Self::I64),
46 DataType::UInt8 => Ok(Self::U8),
47 DataType::UInt16 => Ok(Self::U16),
48 DataType::UInt32 => Ok(Self::U32),
49 DataType::UInt64 => Ok(Self::U64),
50 DataType::Float16 => Ok(Self::F16),
51 DataType::Float32 => Ok(Self::F32),
52 DataType::Float64 => Ok(Self::F64),
53 _ => Err(vortex_err!(
54 "Arrow datatype {:?} cannot be converted to ptype",
55 value
56 )),
57 }
58 }
59}
60
61impl TryFromArrowType<&DataType> for DecimalDType {
62 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
63 match value {
64 DataType::Decimal128(precision, scale) => Self::try_new(*precision, *scale),
65 DataType::Decimal256(precision, scale) => Self::try_new(*precision, *scale),
66 _ => Err(vortex_err!(
67 "Arrow datatype {:?} cannot be converted to DecimalDType",
68 value
69 )),
70 }
71 }
72}
73
74impl FromArrowType<SchemaRef> for DType {
75 fn from_arrow(value: SchemaRef) -> Self {
76 Self::from_arrow(value.as_ref())
77 }
78}
79
80impl FromArrowType<&Schema> for DType {
81 fn from_arrow(value: &Schema) -> Self {
82 Self::Struct(
83 StructFields::from_arrow(value.fields()),
84 Nullability::NonNullable, )
86 }
87}
88
89impl FromArrowType<&Fields> for StructFields {
90 fn from_arrow(value: &Fields) -> Self {
91 StructFields::from_iter(value.into_iter().map(|f| {
92 (
93 FieldName::from(f.name().as_str()),
94 DType::from_arrow(f.as_ref()),
95 )
96 }))
97 }
98}
99
100impl FromArrowType<(&DataType, Nullability)> for DType {
101 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
102 if data_type.is_integer() || data_type.is_floating() {
103 return DType::Primitive(
104 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
105 nullability,
106 );
107 }
108
109 match data_type {
110 DataType::Null => DType::Null,
111 DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
112 DType::Decimal(DecimalDType::new(*precision, *scale), nullability)
113 }
114 DataType::Boolean => DType::Bool(nullability),
115 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DType::Utf8(nullability),
116 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => {
117 DType::Binary(nullability)
118 }
119 DataType::Date32
120 | DataType::Date64
121 | DataType::Time32(_)
122 | DataType::Time64(_)
123 | DataType::Timestamp(..) => DType::Extension(Arc::new(
124 make_temporal_ext_dtype(data_type).with_nullability(nullability),
125 )),
126 DataType::List(e) | DataType::LargeList(e) => {
127 DType::List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
128 }
129 DataType::FixedSizeList(e, size) => DType::FixedSizeList(
130 Arc::new(Self::from_arrow(e.as_ref())),
131 *size as u32,
132 nullability,
133 ),
134 DataType::Struct(f) => DType::Struct(StructFields::from_arrow(f), nullability),
135 DataType::Dictionary(_, value_type) => {
136 Self::from_arrow((value_type.as_ref(), nullability))
137 }
138 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
139 }
140 }
141}
142
143impl FromArrowType<&Field> for DType {
144 fn from_arrow(field: &Field) -> Self {
145 Self::from_arrow((field.data_type(), field.is_nullable().into()))
146 }
147}
148
149impl DType {
150 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
152 let DType::Struct(struct_dtype, nullable) = self else {
153 vortex_bail!("only DType::Struct can be converted to arrow schema");
154 };
155
156 if *nullable != Nullability::NonNullable {
157 vortex_bail!("top-level struct in Schema must be NonNullable");
158 }
159
160 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
161 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
162 builder.push(FieldRef::from(Field::new(
163 field_name.to_string(),
164 field_dtype.to_arrow_dtype()?,
165 field_dtype.is_nullable(),
166 )));
167 }
168
169 Ok(builder.finish())
170 }
171
172 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
174 Ok(match self {
175 DType::Null => DataType::Null,
176 DType::Bool(_) => DataType::Boolean,
177 DType::Primitive(ptype, _) => match ptype {
178 PType::U8 => DataType::UInt8,
179 PType::U16 => DataType::UInt16,
180 PType::U32 => DataType::UInt32,
181 PType::U64 => DataType::UInt64,
182 PType::I8 => DataType::Int8,
183 PType::I16 => DataType::Int16,
184 PType::I32 => DataType::Int32,
185 PType::I64 => DataType::Int64,
186 PType::F16 => DataType::Float16,
187 PType::F32 => DataType::Float32,
188 PType::F64 => DataType::Float64,
189 },
190 DType::Decimal(dt, _) => {
191 if dt.precision() > DECIMAL128_MAX_PRECISION {
192 DataType::Decimal256(dt.precision(), dt.scale())
193 } else {
194 DataType::Decimal128(dt.precision(), dt.scale())
195 }
196 }
197 DType::Utf8(_) => DataType::Utf8View,
198 DType::Binary(_) => DataType::BinaryView,
199 DType::Struct(struct_dtype, _) => {
200 let mut fields = Vec::with_capacity(struct_dtype.names().len());
201 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
202 {
203 fields.push(FieldRef::from(Field::new(
204 field_name.to_string(),
205 field_dt.to_arrow_dtype()?,
206 field_dt.is_nullable(),
207 )));
208 }
209
210 DataType::Struct(Fields::from(fields))
211 }
212 DType::List(elem_dtype, _) => DataType::List(FieldRef::new(Field::new_list_field(
216 elem_dtype.to_arrow_dtype()?,
217 elem_dtype.nullability().into(),
218 ))),
219 DType::FixedSizeList(elem_dtype, size, _) => DataType::FixedSizeList(
220 FieldRef::new(Field::new_list_field(
221 elem_dtype.to_arrow_dtype()?,
222 elem_dtype.nullability().into(),
223 )),
224 *size as i32,
225 ),
226 DType::Extension(ext_dtype) => {
227 if is_temporal_ext_type(ext_dtype.id()) {
229 make_arrow_temporal_dtype(ext_dtype)
230 } else {
231 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
232 }
233 }
234 })
235 }
236}
237
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        assert_eq!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Boolean
        );

        assert_eq!(
            DType::Primitive(PType::U64, Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::UInt64
        );

        assert_eq!(
            DType::Utf8(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Utf8View
        );

        assert_eq!(
            DType::Binary(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::BinaryView
        );

        assert_eq!(
            DType::struct_(
                [
                    ("field_a", DType::Bool(false.into())),
                    ("field_b", DType::Utf8(true.into()))
                ],
                Nullability::NonNullable,
            )
            .to_arrow_dtype()
            .unwrap(),
            DataType::Struct(Fields::from(vec![
                FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
            ]))
        );
    }

    #[test]
    fn infer_nullable_list_element() {
        let list_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        let arrow_list_non_nullable = list_non_nullable.to_arrow_dtype().unwrap();

        let list_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let arrow_list_nullable = list_nullable.to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    /// A fixed-size list keeps its size and element nullability in both directions.
    #[test]
    fn test_fixed_size_list_conversion() {
        let fsl = DType::FixedSizeList(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            3,
            Nullability::Nullable,
        );

        let arrow = fsl.to_arrow_dtype().unwrap();
        assert_eq!(
            arrow,
            DataType::FixedSizeList(Arc::new(Field::new_list_field(DataType::Int32, false)), 3)
        );
        assert_eq!(DType::from_arrow((&arrow, Nullability::Nullable)), fsl);
    }

    /// Non-temporal extension types have no Arrow equivalent. Conversion must return an error;
    /// asserting on the `Result` avoids passing on an unrelated panic.
    #[test]
    fn test_dtype_conversion_unsupported_extension() {
        let result = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )))
        .to_arrow_dtype();
        assert!(result.is_err());
    }

    #[fixture]
    fn the_struct() -> StructFields {
        StructFields::new(
            FieldNames::from([
                FieldName::from("field_a"),
                FieldName::from("field_b"),
                FieldName::from("field_c"),
            ]),
            vec![
                DType::Bool(Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ],
        )
    }

    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        assert_eq!(
            schema_nonnull.to_arrow_schema().unwrap(),
            Schema::new(Fields::from(vec![
                Field::new("field_a", DataType::Boolean, false),
                Field::new("field_b", DataType::Utf8View, false),
                Field::new("field_c", DataType::Int32, true),
            ]))
        );
    }

    /// An Arrow schema has no top-level validity, so a nullable struct must be rejected with
    /// an error rather than converted.
    #[rstest]
    fn test_schema_conversion_nullable_struct_fails(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        assert!(schema_null.to_arrow_schema().is_err());
    }
}