vortex_array/expr/exprs/cast/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_session::VortexSession;
17
18use crate::AnyColumnar;
19use crate::ArrayRef;
20use crate::CanonicalView;
21use crate::ColumnarView;
22use crate::ExecutionCtx;
23use crate::arrays::BoolVTable;
24use crate::arrays::ConstantArray;
25use crate::arrays::ConstantVTable;
26use crate::arrays::DecimalVTable;
27use crate::arrays::ExtensionVTable;
28use crate::arrays::FixedSizeListVTable;
29use crate::arrays::ListViewVTable;
30use crate::arrays::NullVTable;
31use crate::arrays::PrimitiveVTable;
32use crate::arrays::StructVTable;
33use crate::arrays::VarBinViewVTable;
34use crate::builtins::ArrayBuiltins;
35use crate::expr::Arity;
36use crate::expr::ChildName;
37use crate::expr::ExecutionArgs;
38use crate::expr::ExprId;
39use crate::expr::ReduceCtx;
40use crate::expr::ReduceNode;
41use crate::expr::ReduceNodeRef;
42use crate::expr::StatsCatalog;
43use crate::expr::VTable;
44use crate::expr::VTableExt;
45use crate::expr::expression::Expression;
46use crate::expr::lit;
47use crate::expr::stats::Stat;
48
/// Expression vtable that casts its single input to a target [`DType`].
///
/// Identified by `vortex.cast`; use [`cast`] to build an expression instance.
pub struct Cast;

52impl VTable for Cast {
53 type Options = DType;
54
55 fn id(&self) -> ExprId {
56 ExprId::from("vortex.cast")
57 }
58
59 fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
60 Ok(Some(
61 pb::CastOpts {
62 target: Some(dtype.try_into()?),
63 }
64 .encode_to_vec(),
65 ))
66 }
67
68 fn deserialize(
69 &self,
70 _metadata: &[u8],
71 session: &VortexSession,
72 ) -> VortexResult<Self::Options> {
73 let proto = pb::CastOpts::decode(_metadata)?.target;
74 DType::from_proto(
75 proto
76 .as_ref()
77 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
78 session,
79 )
80 }
81
82 fn arity(&self, _options: &DType) -> Arity {
83 Arity::Exact(1)
84 }
85
86 fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
87 match child_idx {
88 0 => ChildName::from("input"),
89 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
90 }
91 }
92
93 fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
94 write!(f, "cast(")?;
95 expr.children()[0].fmt_sql(f)?;
96 write!(f, " as {}", dtype)?;
97 write!(f, ")")
98 }
99
100 fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
101 Ok(dtype.clone())
102 }
103
104 fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
105 let input = args
106 .inputs
107 .pop()
108 .vortex_expect("missing input for Cast expression");
109
110 let Some(columnar) = input.as_opt::<AnyColumnar>() else {
111 return input
112 .execute::<ArrayRef>(args.ctx)?
113 .cast(target_dtype.clone());
114 };
115
116 match columnar {
117 ColumnarView::Canonical(canonical) => {
118 match cast_canonical(canonical.clone(), target_dtype, args.ctx)? {
119 Some(result) => Ok(result),
120 None => vortex_bail!(
121 "No CastKernel to cast canonical array {} from {} to {}",
122 canonical.as_ref().encoding_id(),
123 canonical.as_ref().dtype(),
124 target_dtype,
125 ),
126 }
127 }
128 ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
129 Some(result) => Ok(result),
130 None => vortex_bail!(
131 "No CastReduce to cast constant array from {} to {}",
132 constant.dtype(),
133 target_dtype,
134 ),
135 },
136 }
137 }
138
139 fn reduce(
140 &self,
141 target_dtype: &DType,
142 node: &dyn ReduceNode,
143 _ctx: &dyn ReduceCtx,
144 ) -> VortexResult<Option<ReduceNodeRef>> {
145 let child = node.child(0);
147 if &child.node_dtype()? == target_dtype {
148 return Ok(Some(child));
149 }
150 Ok(None)
151 }
152
153 fn stat_expression(
154 &self,
155 dtype: &DType,
156 expr: &Expression,
157 stat: Stat,
158 catalog: &dyn StatsCatalog,
159 ) -> Option<Expression> {
160 match stat {
161 Stat::IsConstant
162 | Stat::IsSorted
163 | Stat::IsStrictSorted
164 | Stat::NaNCount
165 | Stat::Sum
166 | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
167 Stat::Max | Stat::Min => {
168 expr.child(0)
170 .stat_expression(stat, catalog)
171 .map(|x| cast(x, dtype.clone()))
172 }
173 Stat::NullCount => {
174 None
182 }
183 }
184 }
185
186 fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
187 Ok(Some(if dtype.is_nullable() {
188 expression.child(0).validity()?
189 } else {
190 lit(true)
191 }))
192 }
193
194 fn is_null_sensitive(&self, _instance: &DType) -> bool {
196 true
197 }
198}
199
200fn cast_canonical(
203 canonical: CanonicalView<'_>,
204 dtype: &DType,
205 ctx: &mut ExecutionCtx,
206) -> VortexResult<Option<ArrayRef>> {
207 match canonical {
208 CanonicalView::Null(a) => <NullVTable as CastReduce>::cast(a, dtype),
209 CanonicalView::Bool(a) => <BoolVTable as CastReduce>::cast(a, dtype),
210 CanonicalView::Primitive(a) => <PrimitiveVTable as CastKernel>::cast(a, dtype, ctx),
211 CanonicalView::Decimal(a) => <DecimalVTable as CastKernel>::cast(a, dtype, ctx),
212 CanonicalView::VarBinView(a) => <VarBinViewVTable as CastReduce>::cast(a, dtype),
213 CanonicalView::List(a) => <ListViewVTable as CastReduce>::cast(a, dtype),
214 CanonicalView::FixedSizeList(a) => <FixedSizeListVTable as CastReduce>::cast(a, dtype),
215 CanonicalView::Struct(a) => <StructVTable as CastKernel>::cast(a, dtype, ctx),
216 CanonicalView::Extension(a) => <ExtensionVTable as CastReduce>::cast(a, dtype),
217 }
218}
219
220fn cast_constant(array: &ConstantArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
222 <ConstantVTable as CastReduce>::cast(array, dtype)
223}
224
225pub fn cast(child: Expression, target: DType) -> Expression {
235 Cast.try_new_expr(target, [child])
236 .vortex_expect("Failed to create Cast expression")
237}
238
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let target = DType::Bool(Nullability::NonNullable);
        let scope = test_harness::struct_dtype();
        let actual = cast(root(), target.clone()).return_dtype(&scope).unwrap();
        assert_eq!(actual, target);
    }

    /// Replacing the single child of a cast succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        original
            .with_children(vec![root()])
            .vortex_expect("operation should succeed in test");
    }

    /// Casting an `i32` struct field to `i64` yields an array with the target dtype.
    #[test]
    fn evaluate() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        let input = StructArray::from_fields(&fields).unwrap().into_array();

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let output = input.apply(&expr).unwrap();

        assert_eq!(output.dtype(), &target);
    }

    /// Display renders `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let cases = [
            (
                cast(
                    get_item("value", root()),
                    DType::Primitive(PType::I64, Nullability::NonNullable),
                ),
                "cast($.value as i64)",
            ),
            (
                cast(root(), DType::Bool(Nullability::Nullable)),
                "cast($ as bool?)",
            ),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}