vortex_scalar/arrow/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use arrow_array::*;
7use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type};
8use vortex_dtype::{DType, PType};
9use vortex_error::{VortexError, vortex_bail, vortex_err};
10
11use crate::Scalar;
12use crate::decimal::DecimalValue;
13
/// Builds an `Ok(Arc<...>)` holding an Arrow scalar of array type `$AR`.
///
/// `$V` must evaluate to an `Option`: `Some(v)` is passed to `<$AR>::new_scalar`,
/// while `None` yields a one-element null array of `$AR` wrapped in
/// `arrow_array::Scalar`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {
        Ok(std::sync::Arc::new(match $V {
            Some(present) => <$AR>::new_scalar(present),
            None => arrow_array::Scalar::new(<$AR>::new_null(1)),
        }))
    };
}
22
23impl TryFrom<&Scalar> for Arc<dyn Datum> {
24    type Error = VortexError;
25
26    fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
27        match value.dtype() {
28            DType::Null => Ok(Arc::new(NullArray::new(1))),
29            DType::Bool(_) => value_to_arrow_scalar!(value.as_bool().value(), BooleanArray),
30            DType::Primitive(ptype, ..) => {
31                let scalar = value.as_primitive();
32                Ok(match ptype {
33                    PType::U8 => scalar
34                        .typed_value()
35                        .map(|i| Arc::new(UInt8Array::new_scalar(i)) as Arc<dyn Datum>)
36                        .unwrap_or_else(|| Arc::new(UInt8Array::new_null(1))),
37                    PType::U16 => scalar
38                        .typed_value()
39                        .map(|i| Arc::new(UInt16Array::new_scalar(i)) as Arc<dyn Datum>)
40                        .unwrap_or_else(|| Arc::new(UInt16Array::new_null(1))),
41                    PType::U32 => scalar
42                        .typed_value()
43                        .map(|i| Arc::new(UInt32Array::new_scalar(i)) as Arc<dyn Datum>)
44                        .unwrap_or_else(|| Arc::new(UInt32Array::new_null(1))),
45                    PType::U64 => scalar
46                        .typed_value()
47                        .map(|i| Arc::new(UInt64Array::new_scalar(i)) as Arc<dyn Datum>)
48                        .unwrap_or_else(|| Arc::new(UInt64Array::new_null(1))),
49                    PType::I8 => scalar
50                        .typed_value()
51                        .map(|i| Arc::new(Int8Array::new_scalar(i)) as Arc<dyn Datum>)
52                        .unwrap_or_else(|| Arc::new(Int8Array::new_null(1))),
53                    PType::I16 => scalar
54                        .typed_value()
55                        .map(|i| Arc::new(Int16Array::new_scalar(i)) as Arc<dyn Datum>)
56                        .unwrap_or_else(|| Arc::new(Int16Array::new_null(1))),
57                    PType::I32 => scalar
58                        .typed_value()
59                        .map(|i| Arc::new(Int32Array::new_scalar(i)) as Arc<dyn Datum>)
60                        .unwrap_or_else(|| Arc::new(Int32Array::new_null(1))),
61                    PType::I64 => scalar
62                        .typed_value()
63                        .map(|i| Arc::new(Int64Array::new_scalar(i)) as Arc<dyn Datum>)
64                        .unwrap_or_else(|| Arc::new(Int64Array::new_null(1))),
65                    PType::F16 => scalar
66                        .typed_value()
67                        .map(|i| Arc::new(Float16Array::new_scalar(i)) as Arc<dyn Datum>)
68                        .unwrap_or_else(|| Arc::new(Float16Array::new_null(1))),
69                    PType::F32 => scalar
70                        .typed_value()
71                        .map(|i| Arc::new(Float32Array::new_scalar(i)) as Arc<dyn Datum>)
72                        .unwrap_or_else(|| Arc::new(Float32Array::new_null(1))),
73                    PType::F64 => scalar
74                        .typed_value()
75                        .map(|i| Arc::new(Float64Array::new_scalar(i)) as Arc<dyn Datum>)
76                        .unwrap_or_else(|| Arc::new(Float64Array::new_null(1))),
77                })
78            }
79            DType::Decimal(..) => match value.as_decimal().decimal_value() {
80                // TODO(joe): replace with decimal32, etc.
81                Some(DecimalValue::I8(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
82                Some(DecimalValue::I16(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
83                Some(DecimalValue::I32(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
84                Some(DecimalValue::I64(v)) => Ok(Arc::new(Decimal128Array::new_scalar(v as i128))),
85                Some(DecimalValue::I128(v128)) => Ok(Arc::new(Decimal128Array::new_scalar(v128))),
86                Some(DecimalValue::I256(v256)) => {
87                    Ok(Arc::new(Decimal256Array::new_scalar(v256.into())))
88                }
89                None => Ok(Arc::new(arrow_array::Scalar::new(
90                    Decimal128Array::new_null(1),
91                ))),
92            },
93            DType::Utf8(_) => {
94                value_to_arrow_scalar!(value.as_utf8().value(), StringViewArray)
95            }
96            DType::Binary(_) => {
97                value_to_arrow_scalar!(value.as_binary().value(), BinaryViewArray)
98            }
99            DType::Struct(..) => {
100                todo!("struct scalar conversion")
101            }
102            DType::List(..) => {
103                todo!("list scalar conversion")
104            }
105            DType::FixedSizeList(..) => {
106                todo!("fixed-size list scalar conversion")
107            }
108            DType::Extension(ext) => {
109                if is_temporal_ext_type(ext.id()) {
110                    let metadata = TemporalMetadata::try_from(ext.as_ref())?;
111                    let storage_scalar = value.as_extension().storage();
112                    let primitive = storage_scalar
113                        .as_primitive_opt()
114                        .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
115
116                    return match metadata {
117                        TemporalMetadata::Time(u) => match u {
118                            TimeUnit::Nanoseconds => value_to_arrow_scalar!(
119                                primitive.as_::<i64>(),
120                                Time64NanosecondArray
121                            ),
122                            TimeUnit::Microseconds => value_to_arrow_scalar!(
123                                primitive.as_::<i64>(),
124                                Time64MicrosecondArray
125                            ),
126                            TimeUnit::Milliseconds => value_to_arrow_scalar!(
127                                primitive.as_::<i32>(),
128                                Time32MillisecondArray
129                            ),
130                            TimeUnit::Seconds => {
131                                value_to_arrow_scalar!(primitive.as_::<i32>(), Time32SecondArray)
132                            }
133                            TimeUnit::Days => {
134                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
135                            }
136                        },
137                        TemporalMetadata::Date(u) => match u {
138                            TimeUnit::Milliseconds => {
139                                value_to_arrow_scalar!(primitive.as_::<i64>(), Date64Array)
140                            }
141                            TimeUnit::Days => {
142                                value_to_arrow_scalar!(primitive.as_::<i32>(), Date32Array)
143                            }
144                            TimeUnit::Nanoseconds | TimeUnit::Microseconds | TimeUnit::Seconds => {
145                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
146                            }
147                        },
148                        TemporalMetadata::Timestamp(u, _) => match u {
149                            TimeUnit::Nanoseconds => value_to_arrow_scalar!(
150                                primitive.as_::<i64>(),
151                                TimestampNanosecondArray
152                            ),
153                            TimeUnit::Microseconds => value_to_arrow_scalar!(
154                                primitive.as_::<i64>(),
155                                TimestampMicrosecondArray
156                            ),
157                            TimeUnit::Milliseconds => value_to_arrow_scalar!(
158                                primitive.as_::<i64>(),
159                                TimestampMillisecondArray
160                            ),
161                            TimeUnit::Seconds => {
162                                value_to_arrow_scalar!(primitive.as_::<i64>(), TimestampSecondArray)
163                            }
164                            TimeUnit::Days => {
165                                vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
166                            }
167                        },
168                    };
169                }
170
171                todo!("Non temporal extension scalar conversion")
172            }
173        }
174    }
175}
176
177#[cfg(test)]
178mod tests;