vortex_array/arrays/extension/compute/to_arrow.rs

1use std::sync::Arc;
2
3use arrow_array::{
4    ArrayRef, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray,
5    Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
6    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
7};
8use arrow_schema::DataType;
9use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type};
10use vortex_dtype::{DType, NativePType};
11use vortex_error::{VortexResult, vortex_bail};
12
13use crate::Array;
14use crate::arrays::{ExtensionArray, ExtensionEncoding, TemporalArray};
15use crate::canonical::ToCanonical;
16use crate::compute::{ToArrowFn, to_arrow, try_cast};
17
18impl ToArrowFn<&ExtensionArray> for ExtensionEncoding {
19    fn to_arrow(
20        &self,
21        array: &ExtensionArray,
22        data_type: &DataType,
23    ) -> VortexResult<Option<ArrayRef>> {
24        // NOTE(ngates): this is really gross... but I guess it's ok given how tightly integrated
25        //  we are with Arrow.
26        if is_temporal_ext_type(array.id()) {
27            temporal_to_arrow(TemporalArray::try_from(array.to_array())?).map(Some)
28        } else {
29            // Convert storage array directly into arrow, losing type information
30            // that will let us round-trip.
31            // TODO(aduffy): https://github.com/spiraldb/vortex/issues/1167
32            to_arrow(array.storage(), data_type).map(Some)
33        }
34    }
35}
36
37fn temporal_to_arrow(temporal_array: TemporalArray) -> VortexResult<ArrayRef> {
38    macro_rules! extract_temporal_values {
39        ($values:expr, $prim:ty) => {{
40            let temporal_values = try_cast(
41                $values,
42                &DType::Primitive(<$prim as NativePType>::PTYPE, $values.dtype().nullability()),
43            )?
44            .to_primitive()?;
45            let nulls = temporal_values.validity_mask()?.to_null_buffer();
46            let scalars = temporal_values.into_buffer().into_arrow_scalar_buffer();
47
48            (scalars, nulls)
49        }};
50    }
51
52    Ok(match temporal_array.temporal_metadata() {
53        TemporalMetadata::Date(time_unit) => match time_unit {
54            TimeUnit::D => {
55                let (scalars, nulls) =
56                    extract_temporal_values!(temporal_array.temporal_values(), i32);
57                Arc::new(Date32Array::new(scalars, nulls))
58            }
59            TimeUnit::Ms => {
60                let (scalars, nulls) =
61                    extract_temporal_values!(temporal_array.temporal_values(), i64);
62                Arc::new(Date64Array::new(scalars, nulls))
63            }
64            _ => vortex_bail!(
65                "Invalid TimeUnit {time_unit} for {}",
66                temporal_array.ext_dtype().id()
67            ),
68        },
69        TemporalMetadata::Time(time_unit) => match time_unit {
70            TimeUnit::S => {
71                let (scalars, nulls) =
72                    extract_temporal_values!(temporal_array.temporal_values(), i32);
73                Arc::new(Time32SecondArray::new(scalars, nulls))
74            }
75            TimeUnit::Ms => {
76                let (scalars, nulls) =
77                    extract_temporal_values!(temporal_array.temporal_values(), i32);
78                Arc::new(Time32MillisecondArray::new(scalars, nulls))
79            }
80            TimeUnit::Us => {
81                let (scalars, nulls) =
82                    extract_temporal_values!(temporal_array.temporal_values(), i64);
83                Arc::new(Time64MicrosecondArray::new(scalars, nulls))
84            }
85            TimeUnit::Ns => {
86                let (scalars, nulls) =
87                    extract_temporal_values!(temporal_array.temporal_values(), i64);
88                Arc::new(Time64NanosecondArray::new(scalars, nulls))
89            }
90            _ => vortex_bail!(
91                "Invalid TimeUnit {time_unit} for {}",
92                temporal_array.ext_dtype().id()
93            ),
94        },
95        TemporalMetadata::Timestamp(time_unit, _) => {
96            let (scalars, nulls) = extract_temporal_values!(temporal_array.temporal_values(), i64);
97            match time_unit {
98                TimeUnit::Ns => Arc::new(TimestampNanosecondArray::new(scalars, nulls)),
99                TimeUnit::Us => Arc::new(TimestampMicrosecondArray::new(scalars, nulls)),
100                TimeUnit::Ms => Arc::new(TimestampMillisecondArray::new(scalars, nulls)),
101                TimeUnit::S => Arc::new(TimestampSecondArray::new(scalars, nulls)),
102                _ => vortex_bail!(
103                    "Invalid TimeUnit {time_unit} for {}",
104                    temporal_array.ext_dtype().id()
105                ),
106            }
107        }
108    })
109}