vortex_scalar/
arrow.rs

1use std::sync::Arc;
2
3use arrow_array::*;
4use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type};
5use vortex_dtype::{DType, PType};
6use vortex_error::{VortexError, vortex_bail, vortex_err};
7
8use crate::Scalar;
9
/// Builds an `Ok(Arc<arrow_array::Scalar<$AR>>)` from an `Option` holding a native value.
///
/// * `$V`  - expression producing the `Option`; it is evaluated exactly once.
/// * `$AR` - Arrow array type providing `new_scalar` and `new_null`.
///
/// `Some(v)` becomes `<$AR>::new_scalar(v)`. `None` becomes a one-element null array of the
/// same type, wrapped in `arrow_array::Scalar`. Both outcomes are therefore scalar datums.
/// The `Arc` is left concrete; a caller expecting `Arc<dyn Datum>` unsizes it.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {
        Ok(std::sync::Arc::new(match $V {
            Some(native) => <$AR>::new_scalar(native),
            None => arrow_array::Scalar::new(<$AR>::new_null(1)),
        }))
    };
}
18
19impl TryFrom<&Scalar> for Arc<dyn Datum> {
20    type Error = VortexError;
21
22    fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
23        match value.dtype() {
24            DType::Null => Ok(Arc::new(NullArray::new(1))),
25            DType::Bool(_) => value_to_arrow_scalar!(value.as_bool().value(), BooleanArray),
26            DType::Primitive(ptype, ..) => {
27                let scalar = value.as_primitive();
28                Ok(match ptype {
29                    PType::U8 => scalar
30                        .typed_value()
31                        .map(|i| Arc::new(UInt8Array::new_scalar(i)) as Arc<dyn Datum>)
32                        .unwrap_or_else(|| Arc::new(UInt8Array::new_null(1))),
33                    PType::U16 => scalar
34                        .typed_value()
35                        .map(|i| Arc::new(UInt16Array::new_scalar(i)) as Arc<dyn Datum>)
36                        .unwrap_or_else(|| Arc::new(UInt16Array::new_null(1))),
37                    PType::U32 => scalar
38                        .typed_value()
39                        .map(|i| Arc::new(UInt32Array::new_scalar(i)) as Arc<dyn Datum>)
40                        .unwrap_or_else(|| Arc::new(UInt32Array::new_null(1))),
41                    PType::U64 => scalar
42                        .typed_value()
43                        .map(|i| Arc::new(UInt64Array::new_scalar(i)) as Arc<dyn Datum>)
44                        .unwrap_or_else(|| Arc::new(UInt64Array::new_null(1))),
45                    PType::I8 => scalar
46                        .typed_value()
47                        .map(|i| Arc::new(Int8Array::new_scalar(i)) as Arc<dyn Datum>)
48                        .unwrap_or_else(|| Arc::new(Int8Array::new_null(1))),
49                    PType::I16 => scalar
50                        .typed_value()
51                        .map(|i| Arc::new(Int16Array::new_scalar(i)) as Arc<dyn Datum>)
52                        .unwrap_or_else(|| Arc::new(Int16Array::new_null(1))),
53                    PType::I32 => scalar
54                        .typed_value()
55                        .map(|i| Arc::new(Int32Array::new_scalar(i)) as Arc<dyn Datum>)
56                        .unwrap_or_else(|| Arc::new(Int32Array::new_null(1))),
57                    PType::I64 => scalar
58                        .typed_value()
59                        .map(|i| Arc::new(Int64Array::new_scalar(i)) as Arc<dyn Datum>)
60                        .unwrap_or_else(|| Arc::new(Int64Array::new_null(1))),
61                    PType::F16 => scalar
62                        .typed_value()
63                        .map(|i| Arc::new(Float16Array::new_scalar(i)) as Arc<dyn Datum>)
64                        .unwrap_or_else(|| Arc::new(Float16Array::new_null(1))),
65                    PType::F32 => scalar
66                        .typed_value()
67                        .map(|i| Arc::new(Float32Array::new_scalar(i)) as Arc<dyn Datum>)
68                        .unwrap_or_else(|| Arc::new(Float32Array::new_null(1))),
69                    PType::F64 => scalar
70                        .typed_value()
71                        .map(|i| Arc::new(Float64Array::new_scalar(i)) as Arc<dyn Datum>)
72                        .unwrap_or_else(|| Arc::new(Float64Array::new_null(1))),
73                })
74            }
75            DType::Utf8(_) => {
76                value_to_arrow_scalar!(value.as_utf8().value(), StringViewArray)
77            }
78            DType::Binary(_) => {
79                value_to_arrow_scalar!(value.as_binary().value(), BinaryViewArray)
80            }
81            DType::Struct(..) => {
82                todo!("struct scalar conversion")
83            }
84            DType::List(..) => {
85                todo!("list scalar conversion")
86            }
87            DType::Extension(ext) => {
88                if is_temporal_ext_type(ext.id()) {
89                    let metadata = TemporalMetadata::try_from(ext.as_ref())?;
90                    let storage_scalar = value.as_extension().storage();
91                    let primitive = storage_scalar
92                        .as_primitive_opt()
93                        .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
94
95                    return match metadata {
96                        TemporalMetadata::Time(u) => match u {
97                            TimeUnit::Ns => value_to_arrow_scalar!(
98                                primitive.as_::<i64>()?,
99                                Time64NanosecondArray
100                            ),
101                            TimeUnit::Us => value_to_arrow_scalar!(
102                                primitive.as_::<i64>()?,
103                                Time64MicrosecondArray
104                            ),
105                            TimeUnit::Ms => value_to_arrow_scalar!(
106                                primitive.as_::<i32>()?,
107                                Time32MillisecondArray
108                            ),
109                            TimeUnit::S => {
110                                value_to_arrow_scalar!(primitive.as_::<i32>()?, Time32SecondArray)
111                            }
112                            TimeUnit::D => {
113                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
114                            }
115                        },
116                        TemporalMetadata::Date(u) => match u {
117                            TimeUnit::Ms => {
118                                value_to_arrow_scalar!(primitive.as_::<i64>()?, Date64Array)
119                            }
120                            TimeUnit::D => {
121                                value_to_arrow_scalar!(primitive.as_::<i32>()?, Date32Array)
122                            }
123                            _ => vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id()),
124                        },
125                        TemporalMetadata::Timestamp(u, _) => match u {
126                            TimeUnit::Ns => value_to_arrow_scalar!(
127                                primitive.as_::<i64>()?,
128                                TimestampNanosecondArray
129                            ),
130                            TimeUnit::Us => value_to_arrow_scalar!(
131                                primitive.as_::<i64>()?,
132                                TimestampMicrosecondArray
133                            ),
134                            TimeUnit::Ms => value_to_arrow_scalar!(
135                                primitive.as_::<i64>()?,
136                                TimestampMillisecondArray
137                            ),
138                            TimeUnit::S => value_to_arrow_scalar!(
139                                primitive.as_::<i64>()?,
140                                TimestampSecondArray
141                            ),
142                            TimeUnit::D => {
143                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
144                            }
145                        },
146                    };
147                }
148
149                todo!("Non temporal extension scalar conversion")
150            }
151        }
152    }
153}