vortex_scalar/
arrow.rs

1use std::sync::Arc;
2
3use arrow_array::*;
4use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type};
5use vortex_dtype::{DType, PType};
6use vortex_error::{VortexError, vortex_bail, vortex_err};
7
8use crate::Scalar;
9use crate::decimal::DecimalValue;
10
/// Expands to `Ok(Arc::new(..))` holding a single-element Arrow scalar of array type `$AR`.
///
/// * `$V`  - an expression evaluating to an `Option` of the value to wrap.
/// * `$AR` - the Arrow array type that provides `new_scalar` and `new_null`.
///
/// A `Some` value is passed to `<$AR>::new_scalar`; `None` becomes a one-element null array
/// wrapped in `arrow_array::Scalar`, so both cases yield the same scalar wrapper type.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {
        Ok(std::sync::Arc::new(match $V {
            Some(present) => <$AR>::new_scalar(present),
            None => arrow_array::Scalar::new(<$AR>::new_null(1)),
        }))
    };
}
19
20impl TryFrom<&Scalar> for Arc<dyn Datum> {
21    type Error = VortexError;
22
23    fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
24        match value.dtype() {
25            DType::Null => Ok(Arc::new(NullArray::new(1))),
26            DType::Bool(_) => value_to_arrow_scalar!(value.as_bool().value(), BooleanArray),
27            DType::Primitive(ptype, ..) => {
28                let scalar = value.as_primitive();
29                Ok(match ptype {
30                    PType::U8 => scalar
31                        .typed_value()
32                        .map(|i| Arc::new(UInt8Array::new_scalar(i)) as Arc<dyn Datum>)
33                        .unwrap_or_else(|| Arc::new(UInt8Array::new_null(1))),
34                    PType::U16 => scalar
35                        .typed_value()
36                        .map(|i| Arc::new(UInt16Array::new_scalar(i)) as Arc<dyn Datum>)
37                        .unwrap_or_else(|| Arc::new(UInt16Array::new_null(1))),
38                    PType::U32 => scalar
39                        .typed_value()
40                        .map(|i| Arc::new(UInt32Array::new_scalar(i)) as Arc<dyn Datum>)
41                        .unwrap_or_else(|| Arc::new(UInt32Array::new_null(1))),
42                    PType::U64 => scalar
43                        .typed_value()
44                        .map(|i| Arc::new(UInt64Array::new_scalar(i)) as Arc<dyn Datum>)
45                        .unwrap_or_else(|| Arc::new(UInt64Array::new_null(1))),
46                    PType::I8 => scalar
47                        .typed_value()
48                        .map(|i| Arc::new(Int8Array::new_scalar(i)) as Arc<dyn Datum>)
49                        .unwrap_or_else(|| Arc::new(Int8Array::new_null(1))),
50                    PType::I16 => scalar
51                        .typed_value()
52                        .map(|i| Arc::new(Int16Array::new_scalar(i)) as Arc<dyn Datum>)
53                        .unwrap_or_else(|| Arc::new(Int16Array::new_null(1))),
54                    PType::I32 => scalar
55                        .typed_value()
56                        .map(|i| Arc::new(Int32Array::new_scalar(i)) as Arc<dyn Datum>)
57                        .unwrap_or_else(|| Arc::new(Int32Array::new_null(1))),
58                    PType::I64 => scalar
59                        .typed_value()
60                        .map(|i| Arc::new(Int64Array::new_scalar(i)) as Arc<dyn Datum>)
61                        .unwrap_or_else(|| Arc::new(Int64Array::new_null(1))),
62                    PType::F16 => scalar
63                        .typed_value()
64                        .map(|i| Arc::new(Float16Array::new_scalar(i)) as Arc<dyn Datum>)
65                        .unwrap_or_else(|| Arc::new(Float16Array::new_null(1))),
66                    PType::F32 => scalar
67                        .typed_value()
68                        .map(|i| Arc::new(Float32Array::new_scalar(i)) as Arc<dyn Datum>)
69                        .unwrap_or_else(|| Arc::new(Float32Array::new_null(1))),
70                    PType::F64 => scalar
71                        .typed_value()
72                        .map(|i| Arc::new(Float64Array::new_scalar(i)) as Arc<dyn Datum>)
73                        .unwrap_or_else(|| Arc::new(Float64Array::new_null(1))),
74                })
75            }
76            DType::Decimal(..) => match value.as_decimal().decimal_value() {
77                // TODO(joe): replace with decimal32, etc.
78                Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(*v as i128))),
79                Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(*v as i128))),
80                Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(*v as i128))),
81                Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(*v as i128))),
82                Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(*v128))),
83                Some(DecimalValue::I256(v256)) => {
84                    Ok(Arc::new(Decimal256Array::new_scalar((*v256).into())))
85                }
86                None => Ok(Arc::new(arrow_array::Scalar::new(
87                    Decimal128Array::new_null(1),
88                ))),
89            },
90            DType::Utf8(_) => {
91                value_to_arrow_scalar!(value.as_utf8().value(), StringViewArray)
92            }
93            DType::Binary(_) => {
94                value_to_arrow_scalar!(value.as_binary().value(), BinaryViewArray)
95            }
96            DType::Struct(..) => {
97                todo!("struct scalar conversion")
98            }
99            DType::List(..) => {
100                todo!("list scalar conversion")
101            }
102            DType::Extension(ext) => {
103                if is_temporal_ext_type(ext.id()) {
104                    let metadata = TemporalMetadata::try_from(ext.as_ref())?;
105                    let storage_scalar = value.as_extension().storage();
106                    let primitive = storage_scalar
107                        .as_primitive_opt()
108                        .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
109
110                    return match metadata {
111                        TemporalMetadata::Time(u) => match u {
112                            TimeUnit::Ns => value_to_arrow_scalar!(
113                                primitive.as_::<i64>()?,
114                                Time64NanosecondArray
115                            ),
116                            TimeUnit::Us => value_to_arrow_scalar!(
117                                primitive.as_::<i64>()?,
118                                Time64MicrosecondArray
119                            ),
120                            TimeUnit::Ms => value_to_arrow_scalar!(
121                                primitive.as_::<i32>()?,
122                                Time32MillisecondArray
123                            ),
124                            TimeUnit::S => {
125                                value_to_arrow_scalar!(primitive.as_::<i32>()?, Time32SecondArray)
126                            }
127                            TimeUnit::D => {
128                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
129                            }
130                        },
131                        TemporalMetadata::Date(u) => match u {
132                            TimeUnit::Ms => {
133                                value_to_arrow_scalar!(primitive.as_::<i64>()?, Date64Array)
134                            }
135                            TimeUnit::D => {
136                                value_to_arrow_scalar!(primitive.as_::<i32>()?, Date32Array)
137                            }
138                            _ => vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id()),
139                        },
140                        TemporalMetadata::Timestamp(u, _) => match u {
141                            TimeUnit::Ns => value_to_arrow_scalar!(
142                                primitive.as_::<i64>()?,
143                                TimestampNanosecondArray
144                            ),
145                            TimeUnit::Us => value_to_arrow_scalar!(
146                                primitive.as_::<i64>()?,
147                                TimestampMicrosecondArray
148                            ),
149                            TimeUnit::Ms => value_to_arrow_scalar!(
150                                primitive.as_::<i64>()?,
151                                TimestampMillisecondArray
152                            ),
153                            TimeUnit::S => value_to_arrow_scalar!(
154                                primitive.as_::<i64>()?,
155                                TimestampSecondArray
156                            ),
157                            TimeUnit::D => {
158                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
159                            }
160                        },
161                    };
162                }
163
164                todo!("Non temporal extension scalar conversion")
165            }
166        }
167    }
168}