1use std::sync::Arc;
14
15use arrow_schema::{
16 DECIMAL128_MAX_PRECISION, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef,
17};
18use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
19
20use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
21use crate::datetime::is_temporal_ext_type;
22use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
23
/// Infallible conversion from an Arrow value of type `A` into a Vortex type.
///
/// Some implementations in this module panic on Arrow inputs they do not support; use
/// [`TryFromArrowType`] when a recoverable error is required.
pub trait FromArrowType<A>
where
    Self: Sized,
{
    /// Build `Self` from the given Arrow value.
    fn from_arrow(value: A) -> Self;
}
29
30pub trait TryFromArrowType<T>: Sized {
32 fn try_from_arrow(value: T) -> VortexResult<Self>;
34}
35
36impl TryFromArrowType<&DataType> for PType {
37 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
38 match value {
39 DataType::Int8 => Ok(Self::I8),
40 DataType::Int16 => Ok(Self::I16),
41 DataType::Int32 => Ok(Self::I32),
42 DataType::Int64 => Ok(Self::I64),
43 DataType::UInt8 => Ok(Self::U8),
44 DataType::UInt16 => Ok(Self::U16),
45 DataType::UInt32 => Ok(Self::U32),
46 DataType::UInt64 => Ok(Self::U64),
47 DataType::Float16 => Ok(Self::F16),
48 DataType::Float32 => Ok(Self::F32),
49 DataType::Float64 => Ok(Self::F64),
50 _ => Err(vortex_err!(
51 "Arrow datatype {:?} cannot be converted to ptype",
52 value
53 )),
54 }
55 }
56}
57
58impl TryFromArrowType<&DataType> for DecimalDType {
59 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
60 match value {
61 DataType::Decimal128(precision, scale) => Ok(Self::new(*precision, *scale)),
62 DataType::Decimal256(precision, scale) => Ok(Self::new(*precision, *scale)),
63 _ => Err(vortex_err!(
64 "Arrow datatype {:?} cannot be converted to DecimalDType",
65 value
66 )),
67 }
68 }
69}
70
71impl FromArrowType<SchemaRef> for DType {
72 fn from_arrow(value: SchemaRef) -> Self {
73 Self::from_arrow(value.as_ref())
74 }
75}
76
77impl FromArrowType<&Schema> for DType {
78 fn from_arrow(value: &Schema) -> Self {
79 Self::Struct(
80 StructFields::from_arrow(value.fields()),
81 Nullability::NonNullable, )
83 }
84}
85
86impl FromArrowType<&Fields> for StructFields {
87 fn from_arrow(value: &Fields) -> Self {
88 StructFields::from_iter(value.into_iter().map(|f| {
89 (
90 FieldName::from(f.name().as_str()),
91 DType::from_arrow(f.as_ref()),
92 )
93 }))
94 }
95}
96
97impl FromArrowType<(&DataType, Nullability)> for DType {
98 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
99 use crate::DType::*;
100
101 if data_type.is_integer() || data_type.is_floating() {
102 return Primitive(
103 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
104 nullability,
105 );
106 }
107
108 match data_type {
109 DataType::Null => Null,
110 DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
111 Decimal(DecimalDType::new(*precision, *scale), nullability)
112 }
113 DataType::Boolean => Bool(nullability),
114 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8(nullability),
115 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary(nullability),
116 DataType::Date32
117 | DataType::Date64
118 | DataType::Time32(_)
119 | DataType::Time64(_)
120 | DataType::Timestamp(..) => Extension(Arc::new(
121 make_temporal_ext_dtype(data_type).with_nullability(nullability),
122 )),
123 DataType::List(e) | DataType::LargeList(e) => {
124 List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
125 }
126 DataType::Struct(f) => Struct(StructFields::from_arrow(f), nullability),
127 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
128 }
129 }
130}
131
132impl FromArrowType<&Field> for DType {
133 fn from_arrow(field: &Field) -> Self {
134 Self::from_arrow((field.data_type(), field.is_nullable().into()))
135 }
136}
137
138impl DType {
139 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
141 let DType::Struct(struct_dtype, nullable) = self else {
142 vortex_bail!("only DType::Struct can be converted to arrow schema");
143 };
144
145 if *nullable != Nullability::NonNullable {
146 vortex_bail!("top-level struct in Schema must be NonNullable");
147 }
148
149 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
150 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
151 builder.push(FieldRef::from(Field::new(
152 field_name.to_string(),
153 field_dtype.to_arrow_dtype()?,
154 field_dtype.is_nullable(),
155 )));
156 }
157
158 Ok(builder.finish())
159 }
160
161 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
163 Ok(match self {
164 DType::Null => DataType::Null,
165 DType::Bool(_) => DataType::Boolean,
166 DType::Primitive(ptype, _) => match ptype {
167 PType::U8 => DataType::UInt8,
168 PType::U16 => DataType::UInt16,
169 PType::U32 => DataType::UInt32,
170 PType::U64 => DataType::UInt64,
171 PType::I8 => DataType::Int8,
172 PType::I16 => DataType::Int16,
173 PType::I32 => DataType::Int32,
174 PType::I64 => DataType::Int64,
175 PType::F16 => DataType::Float16,
176 PType::F32 => DataType::Float32,
177 PType::F64 => DataType::Float64,
178 },
179 DType::Decimal(dt, _) => {
180 if dt.precision() > DECIMAL128_MAX_PRECISION {
181 DataType::Decimal256(dt.precision(), dt.scale())
182 } else {
183 DataType::Decimal128(dt.precision(), dt.scale())
184 }
185 }
186 DType::Utf8(_) => DataType::Utf8View,
187 DType::Binary(_) => DataType::BinaryView,
188 DType::Struct(struct_dtype, _) => {
189 let mut fields = Vec::with_capacity(struct_dtype.names().len());
190 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
191 {
192 fields.push(FieldRef::from(Field::new(
193 field_name.to_string(),
194 field_dt.to_arrow_dtype()?,
195 field_dt.is_nullable(),
196 )));
197 }
198
199 DataType::Struct(Fields::from(fields))
200 }
201 DType::List(l, _) => DataType::List(FieldRef::new(Field::new_list_field(
205 l.to_arrow_dtype()?,
206 l.nullability().into(),
207 ))),
208 DType::Extension(ext_dtype) => {
209 if is_temporal_ext_type(ext_dtype.id()) {
211 make_arrow_temporal_dtype(ext_dtype)
212 } else {
213 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
214 }
215 }
216 })
217 }
218}
219
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};
    use rstest::{fixture, rstest};

    use super::*;
    use crate::{
        DType, DecimalDType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType,
        StructFields,
    };

    #[test]
    fn test_dtype_conversion_success() {
        assert_eq!(DType::Null.to_arrow_dtype().unwrap(), DataType::Null);

        assert_eq!(
            DType::Bool(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Boolean
        );

        assert_eq!(
            DType::Primitive(PType::U64, Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::UInt64
        );

        assert_eq!(
            DType::Utf8(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Utf8View
        );

        assert_eq!(
            DType::Binary(Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::BinaryView
        );

        assert_eq!(
            DType::struct_(
                [
                    ("field_a", DType::Bool(false.into())),
                    ("field_b", DType::Utf8(true.into()))
                ],
                Nullability::NonNullable,
            )
            .to_arrow_dtype()
            .unwrap(),
            DataType::Struct(Fields::from(vec![
                FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
                FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
            ]))
        );
    }

    /// Decimals up to Arrow's Decimal128 precision limit stay 128-bit; wider ones use 256.
    #[test]
    fn test_decimal_conversion() {
        assert_eq!(
            DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Decimal128(10, 2)
        );

        assert_eq!(
            DType::Decimal(DecimalDType::new(40, 2), Nullability::Nullable)
                .to_arrow_dtype()
                .unwrap(),
            DataType::Decimal256(40, 2)
        );
    }

    #[test]
    fn test_ptype_from_arrow() {
        assert_eq!(
            PType::try_from_arrow(&DataType::Int32).unwrap(),
            PType::I32
        );
        assert_eq!(
            PType::try_from_arrow(&DataType::Float16).unwrap(),
            PType::F16
        );
        assert!(PType::try_from_arrow(&DataType::Utf8).is_err());
    }

    /// The Arrow -> Vortex direction takes nullability from the Arrow field.
    #[test]
    fn test_arrow_field_to_dtype() {
        assert_eq!(
            DType::from_arrow(&Field::new("a", DataType::LargeUtf8, true)),
            DType::Utf8(Nullability::Nullable)
        );
        assert_eq!(
            DType::from_arrow(&Field::new("b", DataType::UInt16, false)),
            DType::Primitive(PType::U16, Nullability::NonNullable)
        );
        assert_eq!(
            DType::from_arrow(&Field::new("c", DataType::Null, true)),
            DType::Null
        );
    }

    #[test]
    fn infer_nullable_list_element() {
        let list_non_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::NonNullable)),
            Nullability::Nullable,
        );

        let arrow_list_non_nullable = list_non_nullable.to_arrow_dtype().unwrap();

        let list_nullable = DType::List(
            Arc::new(DType::Primitive(PType::I64, Nullability::Nullable)),
            Nullability::Nullable,
        );
        let arrow_list_nullable = list_nullable.to_arrow_dtype().unwrap();

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        let _ = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )))
        .to_arrow_dtype()
        .unwrap();
    }

    #[fixture]
    fn the_struct() -> StructFields {
        StructFields::new(
            FieldNames::from([
                FieldName::from("field_a"),
                FieldName::from("field_b"),
                FieldName::from("field_c"),
            ]),
            vec![
                DType::Bool(Nullability::NonNullable),
                DType::Utf8(Nullability::NonNullable),
                DType::Primitive(PType::I32, Nullability::Nullable),
            ],
        )
    }

    #[rstest]
    fn test_schema_conversion(the_struct: StructFields) {
        let schema_nonnull = DType::Struct(the_struct, Nullability::NonNullable);

        assert_eq!(
            schema_nonnull.to_arrow_schema().unwrap(),
            Schema::new(Fields::from(vec![
                Field::new("field_a", DataType::Boolean, false),
                Field::new("field_b", DataType::Utf8View, false),
                Field::new("field_c", DataType::Int32, true),
            ]))
        );
    }

    /// Converting to an Arrow schema and back yields the original struct dtype.
    #[rstest]
    fn test_schema_round_trip(the_struct: StructFields) {
        let dtype = DType::Struct(the_struct, Nullability::NonNullable);
        let schema = dtype.to_arrow_schema().unwrap();

        assert_eq!(DType::from_arrow(&schema), dtype);
    }

    #[rstest]
    #[should_panic]
    fn test_schema_conversion_panics(the_struct: StructFields) {
        let schema_null = DType::Struct(the_struct, Nullability::Nullable);
        let _ = schema_null.to_arrow_schema().unwrap();
    }
}