1use std::sync::Arc;
14
15use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaBuilder, SchemaRef};
16use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
17
18use crate::datetime::arrow::{make_arrow_temporal_dtype, make_temporal_ext_dtype};
19use crate::datetime::is_temporal_ext_type;
20use crate::{DType, FieldName, Nullability, PType, StructDType};
21
/// Conversion from an Arrow value of type `T` into `Self` that has no way of
/// reporting failure through its return type.
///
/// Implementations that encounter an input they cannot represent have no
/// error channel available; see [`TryFromArrowType`] for the fallible variant.
pub trait FromArrowType<T>
where
    Self: Sized,
{
    /// Builds `Self` from the given Arrow `value`.
    fn from_arrow(value: T) -> Self;
}
27
28pub trait TryFromArrowType<T>: Sized {
30 fn try_from_arrow(value: T) -> VortexResult<Self>;
32}
33
34impl TryFromArrowType<&DataType> for PType {
35 fn try_from_arrow(value: &DataType) -> VortexResult<Self> {
36 match value {
37 DataType::Int8 => Ok(Self::I8),
38 DataType::Int16 => Ok(Self::I16),
39 DataType::Int32 => Ok(Self::I32),
40 DataType::Int64 => Ok(Self::I64),
41 DataType::UInt8 => Ok(Self::U8),
42 DataType::UInt16 => Ok(Self::U16),
43 DataType::UInt32 => Ok(Self::U32),
44 DataType::UInt64 => Ok(Self::U64),
45 DataType::Float16 => Ok(Self::F16),
46 DataType::Float32 => Ok(Self::F32),
47 DataType::Float64 => Ok(Self::F64),
48 _ => Err(vortex_err!(
49 "Arrow datatype {:?} cannot be converted to ptype",
50 value
51 )),
52 }
53 }
54}
55
56impl FromArrowType<SchemaRef> for DType {
57 fn from_arrow(value: SchemaRef) -> Self {
58 Self::from_arrow(value.as_ref())
59 }
60}
61
62impl FromArrowType<&Schema> for DType {
63 fn from_arrow(value: &Schema) -> Self {
64 Self::Struct(
65 Arc::new(StructDType::from_arrow(value.fields())),
66 Nullability::NonNullable, )
68 }
69}
70
71impl FromArrowType<&Fields> for StructDType {
72 fn from_arrow(value: &Fields) -> Self {
73 StructDType::from_iter(value.into_iter().map(|f| {
74 (
75 FieldName::from(f.name().as_str()),
76 DType::from_arrow(f.as_ref()),
77 )
78 }))
79 }
80}
81
82impl FromArrowType<&Field> for DType {
83 fn from_arrow(field: &Field) -> Self {
84 use crate::DType::*;
85
86 let nullability: Nullability = field.is_nullable().into();
87
88 if field.data_type().is_integer() || field.data_type().is_floating() {
89 return Primitive(
90 PType::try_from_arrow(field.data_type())
91 .vortex_expect("arrow float/integer to ptype"),
92 nullability,
93 );
94 }
95
96 match field.data_type() {
97 DataType::Null => Null,
98 DataType::Boolean => Bool(nullability),
99 DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Utf8(nullability),
100 DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary(nullability),
101 DataType::Date32
102 | DataType::Date64
103 | DataType::Time32(_)
104 | DataType::Time64(_)
105 | DataType::Timestamp(..) => Extension(Arc::new(
106 make_temporal_ext_dtype(field.data_type()).with_nullability(nullability),
107 )),
108 DataType::List(e) | DataType::LargeList(e) => {
109 List(Arc::new(Self::from_arrow(e.as_ref())), nullability)
110 }
111 DataType::Struct(f) => Struct(Arc::new(StructDType::from_arrow(f)), nullability),
112 _ => unimplemented!("Arrow data type not yet supported: {:?}", field.data_type()),
113 }
114 }
115}
116
117impl DType {
118 pub fn to_arrow_schema(&self) -> VortexResult<Schema> {
120 let DType::Struct(struct_dtype, nullable) = self else {
121 vortex_bail!("only DType::Struct can be converted to arrow schema");
122 };
123
124 if *nullable != Nullability::NonNullable {
125 vortex_bail!("top-level struct in Schema must be NonNullable");
126 }
127
128 let mut builder = SchemaBuilder::with_capacity(struct_dtype.names().len());
129 for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) {
130 builder.push(FieldRef::from(Field::new(
131 field_name.to_string(),
132 field_dtype.to_arrow_dtype()?,
133 field_dtype.is_nullable(),
134 )));
135 }
136
137 Ok(builder.finish())
138 }
139
140 pub fn to_arrow_dtype(&self) -> VortexResult<DataType> {
142 Ok(match self {
143 DType::Null => DataType::Null,
144 DType::Bool(_) => DataType::Boolean,
145 DType::Primitive(ptype, _) => match ptype {
146 PType::U8 => DataType::UInt8,
147 PType::U16 => DataType::UInt16,
148 PType::U32 => DataType::UInt32,
149 PType::U64 => DataType::UInt64,
150 PType::I8 => DataType::Int8,
151 PType::I16 => DataType::Int16,
152 PType::I32 => DataType::Int32,
153 PType::I64 => DataType::Int64,
154 PType::F16 => DataType::Float16,
155 PType::F32 => DataType::Float32,
156 PType::F64 => DataType::Float64,
157 },
158 DType::Utf8(_) => DataType::Utf8View,
159 DType::Binary(_) => DataType::BinaryView,
160 DType::Struct(struct_dtype, _) => {
161 let mut fields = Vec::with_capacity(struct_dtype.names().len());
162 for (field_name, field_dt) in struct_dtype.names().iter().zip(struct_dtype.fields())
163 {
164 fields.push(FieldRef::from(Field::new(
165 field_name.to_string(),
166 field_dt.to_arrow_dtype()?,
167 field_dt.is_nullable(),
168 )));
169 }
170
171 DataType::Struct(Fields::from(fields))
172 }
173 DType::List(l, _) => DataType::List(FieldRef::new(Field::new_list_field(
177 l.to_arrow_dtype()?,
178 l.nullability().into(),
179 ))),
180 DType::Extension(ext_dtype) => {
181 if is_temporal_ext_type(ext_dtype.id()) {
183 make_arrow_temporal_dtype(ext_dtype)
184 } else {
185 vortex_bail!("Unsupported extension type \"{}\"", ext_dtype.id())
186 }
187 }
188 })
189 }
190}
191
#[cfg(test)]
mod test {
    use arrow_schema::{DataType, Field, FieldRef, Fields, Schema};

    use super::*;
    use crate::{DType, ExtDType, ExtID, FieldName, FieldNames, Nullability, PType, StructDType};

    /// Fixture: a struct with a non-nullable bool, a non-nullable utf8 and a
    /// nullable i32 field, named `field_a`, `field_b` and `field_c`.
    fn the_struct() -> Arc<StructDType> {
        let names = FieldNames::from([
            FieldName::from("field_a"),
            FieldName::from("field_b"),
            FieldName::from("field_c"),
        ]);
        let dtypes = vec![
            DType::Bool(Nullability::NonNullable),
            DType::Utf8(Nullability::NonNullable),
            DType::Primitive(PType::I32, Nullability::Nullable),
        ];

        Arc::new(StructDType::new(names, dtypes))
    }

    #[test]
    fn test_dtype_conversion_success() {
        // Leaf dtypes and the Arrow type each one is expected to produce.
        let leaf_cases = [
            (DType::Null, DataType::Null),
            (DType::Bool(Nullability::NonNullable), DataType::Boolean),
            (
                DType::Primitive(PType::U64, Nullability::NonNullable),
                DataType::UInt64,
            ),
            (DType::Utf8(Nullability::NonNullable), DataType::Utf8View),
            (
                DType::Binary(Nullability::NonNullable),
                DataType::BinaryView,
            ),
        ];
        for (vortex_dtype, expected) in leaf_cases {
            assert_eq!(vortex_dtype.to_arrow_dtype().unwrap(), expected);
        }

        // Struct fields keep their names and per-field nullability.
        let struct_dtype = DType::Struct(
            Arc::new(StructDType::from_iter([
                ("field_a", DType::Bool(false.into())),
                ("field_b", DType::Utf8(true.into())),
            ])),
            Nullability::NonNullable,
        );
        let expected_struct = DataType::Struct(Fields::from(vec![
            FieldRef::from(Field::new("field_a", DataType::Boolean, false)),
            FieldRef::from(Field::new("field_b", DataType::Utf8View, true)),
        ]));
        assert_eq!(struct_dtype.to_arrow_dtype().unwrap(), expected_struct);
    }

    #[test]
    fn infer_nullable_list_element() {
        // Nullable list of i64 whose element nullability is the only thing that varies.
        let arrow_list_of_i64 = |element_nullability| {
            DType::List(
                Arc::new(DType::Primitive(PType::I64, element_nullability)),
                Nullability::Nullable,
            )
            .to_arrow_dtype()
            .unwrap()
        };

        let with_non_nullable_elements = arrow_list_of_i64(Nullability::NonNullable);
        let with_nullable_elements = arrow_list_of_i64(Nullability::Nullable);

        assert_ne!(with_non_nullable_elements, with_nullable_elements);
        assert_eq!(
            with_nullable_elements,
            DataType::new_list(DataType::Int64, true)
        );
        assert_eq!(
            with_non_nullable_elements,
            DataType::new_list(DataType::Int64, false)
        );
    }

    #[test]
    #[should_panic]
    fn test_dtype_conversion_panics() {
        // A non-temporal extension id has no Arrow mapping, so the unwrap must panic.
        let unknown_ext = ExtDType::new(
            ExtID::from("my-fake-ext-dtype"),
            Arc::new(DType::Utf8(Nullability::NonNullable)),
            None,
        );
        let _ = DType::Extension(Arc::new(unknown_ext))
            .to_arrow_dtype()
            .unwrap();
    }

    #[test]
    fn test_schema_conversion() {
        let top_level = DType::Struct(the_struct(), Nullability::NonNullable);

        let expected = Schema::new(Fields::from(vec![
            Field::new("field_a", DataType::Boolean, false),
            Field::new("field_b", DataType::Utf8View, false),
            Field::new("field_c", DataType::Int32, true),
        ]));

        assert_eq!(top_level.to_arrow_schema().unwrap(), expected);
    }

    #[test]
    #[should_panic]
    fn test_schema_conversion_panics() {
        // A nullable top-level struct is rejected, so the unwrap must panic.
        let nullable_top_level = DType::Struct(the_struct(), Nullability::Nullable);
        let _ = nullable_top_level.to_arrow_schema().unwrap();
    }
}