vortex_array/scalar_fn/fns/cast/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::AnyColumnar;
18use crate::ArrayRef;
19use crate::CanonicalView;
20use crate::ColumnarView;
21use crate::ExecutionCtx;
22use crate::arrays::BoolVTable;
23use crate::arrays::ConstantArray;
24use crate::arrays::ConstantVTable;
25use crate::arrays::DecimalVTable;
26use crate::arrays::ExtensionVTable;
27use crate::arrays::FixedSizeListVTable;
28use crate::arrays::ListViewVTable;
29use crate::arrays::NullVTable;
30use crate::arrays::PrimitiveVTable;
31use crate::arrays::StructVTable;
32use crate::arrays::VarBinViewVTable;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::StatsCatalog;
36use crate::expr::cast;
37use crate::expr::expression::Expression;
38use crate::expr::lit;
39use crate::expr::stats::Stat;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ReduceCtx;
44use crate::scalar_fn::ReduceNode;
45use crate::scalar_fn::ReduceNodeRef;
46use crate::scalar_fn::ScalarFnId;
47use crate::scalar_fn::ScalarFnVTable;
48
49#[derive(Clone)]
51pub struct Cast;
52
53impl ScalarFnVTable for Cast {
54 type Options = DType;
55
56 fn id(&self) -> ScalarFnId {
57 ScalarFnId::from("vortex.cast")
58 }
59
60 fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
61 Ok(Some(
62 pb::CastOpts {
63 target: Some(dtype.try_into()?),
64 }
65 .encode_to_vec(),
66 ))
67 }
68
69 fn deserialize(
70 &self,
71 _metadata: &[u8],
72 session: &VortexSession,
73 ) -> VortexResult<Self::Options> {
74 let proto = pb::CastOpts::decode(_metadata)?.target;
75 DType::from_proto(
76 proto
77 .as_ref()
78 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
79 session,
80 )
81 }
82
83 fn arity(&self, _options: &DType) -> Arity {
84 Arity::Exact(1)
85 }
86
87 fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
88 match child_idx {
89 0 => ChildName::from("input"),
90 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
91 }
92 }
93
94 fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
95 write!(f, "cast(")?;
96 expr.children()[0].fmt_sql(f)?;
97 write!(f, " as {}", dtype)?;
98 write!(f, ")")
99 }
100
101 fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
102 Ok(dtype.clone())
103 }
104
105 fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
106 let input = args
107 .inputs
108 .pop()
109 .vortex_expect("missing input for Cast expression");
110
111 let Some(columnar) = input.as_opt::<AnyColumnar>() else {
112 return input
113 .execute::<ArrayRef>(args.ctx)?
114 .cast(target_dtype.clone());
115 };
116
117 match columnar {
118 ColumnarView::Canonical(canonical) => {
119 match cast_canonical(canonical.clone(), target_dtype, args.ctx)? {
120 Some(result) => Ok(result),
121 None => vortex_bail!(
122 "No CastKernel to cast canonical array {} from {} to {}",
123 canonical.as_ref().encoding_id(),
124 canonical.as_ref().dtype(),
125 target_dtype,
126 ),
127 }
128 }
129 ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
130 Some(result) => Ok(result),
131 None => vortex_bail!(
132 "No CastReduce to cast constant array from {} to {}",
133 constant.dtype(),
134 target_dtype,
135 ),
136 },
137 }
138 }
139
140 fn reduce(
141 &self,
142 target_dtype: &DType,
143 node: &dyn ReduceNode,
144 _ctx: &dyn ReduceCtx,
145 ) -> VortexResult<Option<ReduceNodeRef>> {
146 let child = node.child(0);
148 if &child.node_dtype()? == target_dtype {
149 return Ok(Some(child));
150 }
151 Ok(None)
152 }
153
154 fn stat_expression(
155 &self,
156 dtype: &DType,
157 expr: &Expression,
158 stat: Stat,
159 catalog: &dyn StatsCatalog,
160 ) -> Option<Expression> {
161 match stat {
162 Stat::IsConstant
163 | Stat::IsSorted
164 | Stat::IsStrictSorted
165 | Stat::NaNCount
166 | Stat::Sum
167 | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
168 Stat::Max | Stat::Min => {
169 expr.child(0)
171 .stat_expression(stat, catalog)
172 .map(|x| cast(x, dtype.clone()))
173 }
174 Stat::NullCount => {
175 None
183 }
184 }
185 }
186
187 fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
188 Ok(Some(if dtype.is_nullable() {
189 expression.child(0).validity()?
190 } else {
191 lit(true)
192 }))
193 }
194
195 fn is_null_sensitive(&self, _instance: &DType) -> bool {
197 true
198 }
199}
200
201fn cast_canonical(
204 canonical: CanonicalView<'_>,
205 dtype: &DType,
206 ctx: &mut ExecutionCtx,
207) -> VortexResult<Option<ArrayRef>> {
208 match canonical {
209 CanonicalView::Null(a) => <NullVTable as CastReduce>::cast(a, dtype),
210 CanonicalView::Bool(a) => <BoolVTable as CastReduce>::cast(a, dtype),
211 CanonicalView::Primitive(a) => <PrimitiveVTable as CastKernel>::cast(a, dtype, ctx),
212 CanonicalView::Decimal(a) => <DecimalVTable as CastKernel>::cast(a, dtype, ctx),
213 CanonicalView::VarBinView(a) => <VarBinViewVTable as CastReduce>::cast(a, dtype),
214 CanonicalView::List(a) => <ListViewVTable as CastReduce>::cast(a, dtype),
215 CanonicalView::FixedSizeList(a) => <FixedSizeListVTable as CastReduce>::cast(a, dtype),
216 CanonicalView::Struct(a) => <StructVTable as CastKernel>::cast(a, dtype, ctx),
217 CanonicalView::Extension(a) => <ExtensionVTable as CastReduce>::cast(a, dtype),
218 }
219}
220
221fn cast_constant(array: &ConstantArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
223 <ConstantVTable as CastReduce>::cast(array, dtype)
224}
225
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::cast;
    use crate::expr::get_item;
    use crate::expr::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let resolved = cast(root(), target.clone()).return_dtype(&scope).unwrap();

        assert_eq!(resolved, target);
    }

    /// Replacing the single child of a cast succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = original.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Applying a cast of an `i32` field to `i64` yields an `i64` array.
    #[test]
    fn evaluate() {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        let test_array = StructArray::from_fields(&fields).unwrap().into_array();

        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr: Expression = cast(get_item("a", root()), target.clone());
        let result = test_array.apply(&expr).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Cast expressions display as `cast(<input> as <dtype>)`.
    #[test]
    fn test_display() {
        let cases = [
            (
                cast(
                    get_item("value", root()),
                    DType::Primitive(PType::I64, Nullability::NonNullable),
                ),
                "cast($.value as i64)",
            ),
            (
                cast(root(), DType::Bool(Nullability::Nullable)),
                "cast($ as bool?)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}