1use std::sync::Arc;
14
15use arrow_schema::{
16 DECIMAL128_MAX_SCALE, DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef,
17};
18use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
19
20use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
21use crate::datetime::is_temporal_ext_type;
22use crate::{DType, DecimalDType, FieldName, Nullability, PType, StructFields};
23
/// Infallible conversion from an Arrow-side value of type `T` into `Self`.
///
/// The method returns `Self` directly, so an implementation has no way to
/// report an unsupported input other than panicking.
pub trait FromArrowType<T>: Sized {
    /// Builds `Self` from `value`.
    fn from_arrow(value: T) -> Self;
}
29
/// Fallible conversion from an Arrow-side value of type `T` into `Self`.
pub trait TryFromArrowType<T>: Sized {
    /// Attempts to build `Self` from `value`, returning an error when the
    /// input has no corresponding representation.
    fn try_from_arrow(value: T) -> VortexResult<Self>;
}
35
36impl TryFromArrowType<&DataType> for PType {
37 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
38 match value {
39 DataType::Int8 => Ok(Self::I8),
40 DataType::Int16 => Ok(Self::I16),
41 DataType::Int32 => Ok(Self::I32),
42 DataType::Int64 => Ok(Self::I64),
43 DataType::UInt8 => Ok(Self::U8),
44 DataType::UInt16 => Ok(Self::U16),
45 DataType::UInt32 => Ok(Self::U32),
46 DataType::UInt64 => Ok(Self::U64),
47 DataType::Float16 => Ok(Self::F16),
48 DataType::Float32 => Ok(Self::F32),
49 DataType::Float64 => Ok(Self::F64),
50 _ => Err(vortex_err!(
51 "Arrow datatype {:?} cannot be converted to ptype",
52 value
53 )),
54 }
55 }
56}
57
58impl TryFromArrowType<&DataType> for DecimalDType {
59 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
60 match value {
61 DataType::Decimal128(precision, scale) => Ok(Self::new(*precision, *scale)),
62 DataType::Decimal256(precision, scale) => Ok(Self::new(*precision, *scale)),
63 _ => Err(vortex_err!(
64 "Arrow datatype {:?} cannot be converted to DecimalDType",
65 value
66 )),
67 }
68 }
69}
70
71impl FromArrowType<SchemaRef> for DType {
72 fn from_arrow(value: SchemaRef) -> Self {
73 Self::from_arrow(value.as_ref())
74 }
75}
76
77impl FromArrowType<&Schema> for DType {
78 fn from_arrow(value: &Schema) -> Self {
79 Self::Struct(
80 Arc::new(StructFields::from_arrow(value.fields())),
81 Nullability::NonNullable, )
83 }
84}
85
86impl FromArrowType<&Fields> for StructFields {
87 fn from_arrow(value: &Fields) -> Self {
88 StructFields::from_iter(value.into_iter().map(|f| {
89 (
90 FieldName::from(f.name().as_str()),
91 DType::from_arrow(f.as_ref()),
92 )
93 }))
94 }
95}
96
97impl FromArrowType<(&DataType, Nullability)> for DType {
98 fn from_arrow((data_type, nullability): (&DataType, Nullability)) -> Self {
99 use crate::DType::*;
100
101 if data_type.is_integer() || data_type.is_floating() {
102 return Primitive(
103 PType::try_from_arrow(data_type).vortex_expect("arrow float/integer to ptype"),
104 nullability,
105 );
106 }
107
108 match data_type {
109 DataType::Null => Null,
110 DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => {
111 Decimal(DecimalDType::new(*precision, *scale), nullability)
112 }
113 DataType::Boolean => Bool(nullability),
114 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8(nullability),
115 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary(nullability),
116 DataType::Date32
117 | DataType::Date64
118 | DataType::Time32(_)
119 | DataType::Time64(_)
120 | DataType::Timestamp(..) => Extension(Arc::new(
121 make_temporal_ext_dtype(data_type).with_nullability(nullability),
122 )),
123 DataType::List(e) | DataType::LargeList(e) => {
124 List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
125 }
126 DataType::Struct(f) => Struct(Arc::new(StructFields::from_arrow(f)), nullability),
127 _ => unimplemented!("Arrow data type not yet supported: {:?}", data_type),
128 }
129 }
130}
131
132impl FromArrowType<&Field> for DType {
133 fn from_arrow(field: &Field) -> Self {
134 Self::from_arrow((field.data_type(), field.is_nullable().into()))
135 }
136}
137
138impl DType {
139 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
141 let DType::Struct(struct_dtype, nullable) = self else {
142 vortex_bail!("only DType::Struct can be converted to arrow schema");
143 };
144
145 if *nullable != Nullability::NonNullable {
146 vortex_bail!("top-level struct in Schema must be NonNullable");
147 }
148
149 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
150 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
151 builder.push(FieldRef::from(Field::new(
152 field_name.to_string(),
153 field_dtype.to_arrow_dtype()?,
154 field_dtype.is_nullable(),
155 )));
156 }
157
158 Ok(builder.finish())
159 }
160
161 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
163 Ok(match self {
164 DType::Null => DataType::Null,
165 DType::Bool(_) => DataType::Boolean,
166 DType::Primitive(ptype, _) => match ptype {
167 PType::U8 => DataType::UInt8,
168 PType::U16 => DataType::UInt16,
169 PType::U32 => DataType::UInt32,
170 PType::U64 => DataType::UInt64,
171 PType::I8 => DataType::Int8,
172 PType::I16 => DataType::Int16,
173 PType::I32 => DataType::Int32,
174 PType::I64 => DataType::Int64,
175 PType::F16 => DataType::Float16,
176 PType::F32 => DataType::Float32,
177 PType::F64 => DataType::Float64,
178 },
179 DType::Decimal(dt, _) => {
180 if dt.scale() > DECIMAL128_MAX_SCALE {
181 DataType::Decimal256(dt.precision(), dt.scale())
182 } else {
183 DataType::Decimal128(dt.precision(), dt.scale())
184 }
185 }
186 DType::Utf8(_) => DataType::Utf8View,
187 DType::Binary(_) => DataType::BinaryView,
188 DType::Struct(struct_dtype, _) => {
189 let mut fields = Vec::with_capacity(struct_dtype.names().len());
190 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
191 {
192 fields.push(FieldRef::from(Field::new(
193 field_name.to_string(),
194 field_dt.to_arrow_dtype()?,
195 field_dt.is_nullable(),
196 )));
197 }
198
199 DataType::Struct(Fields::from(fields))
200 }
201 DType::List(l, _) => DataType::List(FieldRef::new(Field::new_list_field(
205 l.to_arrow_dtype()?,
206 l.nullability().into(),
207 ))),
208 DType::Extension(ext_dtype) => {
209 if is_temporal_ext_type(ext_dtype.id()) {
211 make_arrow_temporal_dtype(ext_dtype)
212 } else {
213 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
214 }
215 }
216 })
217 }
218}
219
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructFields};

    /// Three-field struct (bool, utf8, nullable i32) shared by the schema tests.
    fn the_struct() -> Arc<StructFields> {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        Arc::new(StructFields::new(names, dtypes))
    }

    #[test]
    fn test_dtype_conversion_success() {
        // Leaf dtypes paired with the Arrow type each one must convert to.
        let leaf_cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (
                DType::Binary(Nullability::NonNullable),
                DataType::BinaryView,
            ),
        ];
        for (vortex_dtype, expected_arrow) in leaf_cases {
            assert_eq!(vortex_dtype.to_arrow_dtype().unwrap(), expected_arrow);
        }

        // Nested struct: per-field nullability must carry over to Arrow.
        let struct_dtype = DType::Struct(
            Arc::new(StructFields::from_iter([
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ])),
            Nullability::NonNullable,
        );
        let expected_struct = DataType::Struct(Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]));
        assert_eq!(struct_dtype.to_arrow_dtype().unwrap(), expected_struct);
    }

    #[test]
    fn infer_nullable_list_element() {
        // Builds a nullable list of i64 with the given element nullability
        // and converts it to Arrow.
        let arrow_list_of_i64 = |element_nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let arrow_list_non_nullable = arrow_list_of_i64(Nullability::NonNullable);
        let arrow_list_nullable = arrow_list_of_i64(Nullability::Nullable);

        assert_ne!(arrow_list_non_nullable, arrow_list_nullable);
        assert_eq!(
            arrow_list_nullable,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            arrow_list_non_nullable,
            DataType::new_list(DataType::Int64, false)
        );
    }

    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        // A non-temporal extension type has no Arrow mapping, so unwrapping
        // the conversion result must panic.
        let unknown_extension = DType::Extension(Arc::new(ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        )));

        let _ = unknown_extension.to_arrow_dtype().unwrap();
    }

    #[test]
    fn test_schema_conversion() {
        let schema_nonnull = DType::Struct(the_struct(), Nullability::NonNullable);

        let expected_schema = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(schema_nonnull.to_arrow_schema().unwrap(), expected_schema);
    }

    #[test]
    #[should_panic]
    fn test_schema_conversion_panics() {
        // A nullable top-level struct cannot become a schema.
        let schema_null = DType::Struct(the_struct(), Nullability::Nullable);
        let _ = schema_null.to_arrow_schema().unwrap();
    }
}