1use std::sync::Arc;
2
3use arrow_array::*;
4use vortex_dtype::datetime::{TemporalMetadata, TimeUnit, is_temporal_ext_type};
5use vortex_dtype::{DType, PType};
6use vortex_error::{VortexError, vortex_bail, vortex_err};
7
8use crate::Scalar;
9
/// Turns an `Option` of an Arrow native value into `Ok(Arc<arrow_array::Scalar<$AR>>)`.
///
/// `Some(v)` is converted with `<$AR>::new_scalar(v)`. `None` becomes a one-element
/// null array of type `$AR`, also wrapped in `arrow_array::Scalar`. Both outcomes are
/// therefore flagged as scalars when used as an Arrow `Datum`.
///
/// The value expression `$V` is evaluated exactly once. The null array is only built
/// when `$V` is `None`.
macro_rules! value_to_arrow_scalar {
    ($V:expr, $AR:ty) => {{
        let datum = match $V {
            Some(v) => <$AR>::new_scalar(v),
            None => arrow_array::Scalar::new(<$AR>::new_null(1)),
        };
        Ok(std::sync::Arc::new(datum))
    }};
}
18
19impl TryFrom<&Scalar> for Arc<dyn Datum> {
20 type Error = VortexError;
21
22 fn try_from(value: &Scalar) -> Result<Arc<dyn Datum>, Self::Error> {
23 match value.dtype() {
24 DType::Null => Ok(Arc::new(NullArray::new(1))),
25 DType::Bool(_) => value_to_arrow_scalar!(value.as_bool().value(), BooleanArray),
26 DType::Primitive(ptype, ..) => {
27 let scalar = value.as_primitive();
28 Ok(match ptype {
29 PType::U8 => scalar
30 .typed_value()
31 .map(|i| Arc::new(UInt8Array::new_scalar(i)) as Arc<dyn Datum>)
32 .unwrap_or_else(|| Arc::new(UInt8Array::new_null(1))),
33 PType::U16 => scalar
34 .typed_value()
35 .map(|i| Arc::new(UInt16Array::new_scalar(i)) as Arc<dyn Datum>)
36 .unwrap_or_else(|| Arc::new(UInt16Array::new_null(1))),
37 PType::U32 => scalar
38 .typed_value()
39 .map(|i| Arc::new(UInt32Array::new_scalar(i)) as Arc<dyn Datum>)
40 .unwrap_or_else(|| Arc::new(UInt32Array::new_null(1))),
41 PType::U64 => scalar
42 .typed_value()
43 .map(|i| Arc::new(UInt64Array::new_scalar(i)) as Arc<dyn Datum>)
44 .unwrap_or_else(|| Arc::new(UInt64Array::new_null(1))),
45 PType::I8 => scalar
46 .typed_value()
47 .map(|i| Arc::new(Int8Array::new_scalar(i)) as Arc<dyn Datum>)
48 .unwrap_or_else(|| Arc::new(Int8Array::new_null(1))),
49 PType::I16 => scalar
50 .typed_value()
51 .map(|i| Arc::new(Int16Array::new_scalar(i)) as Arc<dyn Datum>)
52 .unwrap_or_else(|| Arc::new(Int16Array::new_null(1))),
53 PType::I32 => scalar
54 .typed_value()
55 .map(|i| Arc::new(Int32Array::new_scalar(i)) as Arc<dyn Datum>)
56 .unwrap_or_else(|| Arc::new(Int32Array::new_null(1))),
57 PType::I64 => scalar
58 .typed_value()
59 .map(|i| Arc::new(Int64Array::new_scalar(i)) as Arc<dyn Datum>)
60 .unwrap_or_else(|| Arc::new(Int64Array::new_null(1))),
61 PType::F16 => scalar
62 .typed_value()
63 .map(|i| Arc::new(Float16Array::new_scalar(i)) as Arc<dyn Datum>)
64 .unwrap_or_else(|| Arc::new(Float16Array::new_null(1))),
65 PType::F32 => scalar
66 .typed_value()
67 .map(|i| Arc::new(Float32Array::new_scalar(i)) as Arc<dyn Datum>)
68 .unwrap_or_else(|| Arc::new(Float32Array::new_null(1))),
69 PType::F64 => scalar
70 .typed_value()
71 .map(|i| Arc::new(Float64Array::new_scalar(i)) as Arc<dyn Datum>)
72 .unwrap_or_else(|| Arc::new(Float64Array::new_null(1))),
73 })
74 }
75 DType::Utf8(_) => {
76 value_to_arrow_scalar!(value.as_utf8().value(), StringViewArray)
77 }
78 DType::Binary(_) => {
79 value_to_arrow_scalar!(value.as_binary().value(), BinaryViewArray)
80 }
81 DType::Struct(..) => {
82 todo!("struct scalar conversion")
83 }
84 DType::List(..) => {
85 todo!("list scalar conversion")
86 }
87 DType::Extension(ext) => {
88 if is_temporal_ext_type(ext.id()) {
89 let metadata = TemporalMetadata::try_from(ext.as_ref())?;
90 let storage_scalar = value.as_extension().storage();
91 let primitive = storage_scalar
92 .as_primitive_opt()
93 .ok_or_else(|| vortex_err!("Expected primitive scalar"))?;
94
95 return match metadata {
96 TemporalMetadata::Time(u) => match u {
97 TimeUnit::Ns => value_to_arrow_scalar!(
98 primitive.as_::<i64>()?,
99 Time64NanosecondArray
100 ),
101 TimeUnit::Us => value_to_arrow_scalar!(
102 primitive.as_::<i64>()?,
103 Time64MicrosecondArray
104 ),
105 TimeUnit::Ms => value_to_arrow_scalar!(
106 primitive.as_::<i32>()?,
107 Time32MillisecondArray
108 ),
109 TimeUnit::S => {
110 value_to_arrow_scalar!(primitive.as_::<i32>()?, Time32SecondArray)
111 }
112 TimeUnit::D => {
113 vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
114 }
115 },
116 TemporalMetadata::Date(u) => match u {
117 TimeUnit::Ms => {
118 value_to_arrow_scalar!(primitive.as_::<i64>()?, Date64Array)
119 }
120 TimeUnit::D => {
121 value_to_arrow_scalar!(primitive.as_::<i32>()?, Date32Array)
122 }
123 _ => vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id()),
124 },
125 TemporalMetadata::Timestamp(u, _) => match u {
126 TimeUnit::Ns => value_to_arrow_scalar!(
127 primitive.as_::<i64>()?,
128 TimestampNanosecondArray
129 ),
130 TimeUnit::Us => value_to_arrow_scalar!(
131 primitive.as_::<i64>()?,
132 TimestampMicrosecondArray
133 ),
134 TimeUnit::Ms => value_to_arrow_scalar!(
135 primitive.as_::<i64>()?,
136 TimestampMillisecondArray
137 ),
138 TimeUnit::S => value_to_arrow_scalar!(
139 primitive.as_::<i64>()?,
140 TimestampSecondArray
141 ),
142 TimeUnit::D => {
143 vortex_bail!("Unsupported TimeUnit {u} for {}", ext.id())
144 }
145 },
146 };
147 }
148
149 todo!("Non temporal extension scalar conversion")
150 }
151 }
152 }
153}