vortex_array/expr/exprs/
cast.rs1use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_vector::Datum;
14
15use crate::ArrayRef;
16use crate::compute::cast as compute_cast;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::ReduceCtx;
22use crate::expr::ReduceNode;
23use crate::expr::ReduceNodeRef;
24use crate::expr::StatsCatalog;
25use crate::expr::VTable;
26use crate::expr::VTableExt;
27use crate::expr::expression::Expression;
28use crate::expr::stats::Stat;
29
/// Expression vtable for casting the output of a single child expression to a target
/// [`DType`].
///
/// The target dtype is carried as the expression's options; use [`cast`] to build one.
pub struct Cast;

33impl VTable for Cast {
34 type Options = DType;
35
36 fn id(&self) -> ExprId {
37 ExprId::from("vortex.cast")
38 }
39
40 fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
41 Ok(Some(
42 pb::CastOpts {
43 target: Some(dtype.into()),
44 }
45 .encode_to_vec(),
46 ))
47 }
48
49 fn deserialize(&self, metadata: &[u8]) -> VortexResult<DType> {
50 pb::CastOpts::decode(metadata)?
51 .target
52 .as_ref()
53 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
54 .try_into()
55 }
56
57 fn arity(&self, _options: &DType) -> Arity {
58 Arity::Exact(1)
59 }
60
61 fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
62 match child_idx {
63 0 => ChildName::from("input"),
64 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
65 }
66 }
67
68 fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
69 write!(f, "cast(")?;
70 expr.children()[0].fmt_sql(f)?;
71 write!(f, " as {}", dtype)?;
72 write!(f, ")")
73 }
74
75 fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
76 Ok(dtype.clone())
77 }
78
79 fn evaluate(
80 &self,
81 dtype: &DType,
82 expr: &Expression,
83 scope: &ArrayRef,
84 ) -> VortexResult<ArrayRef> {
85 let array = expr.children()[0].evaluate(scope)?;
86 compute_cast(&array, dtype).map_err(|e| {
87 e.with_context(format!(
88 "Failed to cast array of dtype {} to {}",
89 array.dtype(),
90 expr.deref()
91 ))
92 })
93 }
94
95 fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<Datum> {
96 let input = args
97 .datums
98 .pop()
99 .vortex_expect("missing input for Cast expression");
100 vortex_compute::cast::Cast::cast(&input, target_dtype)
101 }
102
103 fn reduce(
104 &self,
105 target_dtype: &DType,
106 node: &dyn ReduceNode,
107 _ctx: &dyn ReduceCtx,
108 ) -> VortexResult<Option<ReduceNodeRef>> {
109 let child = node.child(0);
111 if &child.node_dtype()? == target_dtype {
112 return Ok(Some(child));
113 }
114 Ok(None)
115 }
116
117 fn stat_expression(
118 &self,
119 dtype: &DType,
120 expr: &Expression,
121 stat: Stat,
122 catalog: &dyn StatsCatalog,
123 ) -> Option<Expression> {
124 match stat {
125 Stat::IsConstant
126 | Stat::IsSorted
127 | Stat::IsStrictSorted
128 | Stat::NaNCount
129 | Stat::Sum
130 | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
131 Stat::Max | Stat::Min => {
132 expr.child(0)
134 .stat_expression(stat, catalog)
135 .map(|x| cast(x, dtype.clone()))
136 }
137 Stat::NullCount => {
138 None
146 }
147 }
148 }
149
150 fn is_null_sensitive(&self, _instance: &DType) -> bool {
152 true
153 }
154}
155
156pub fn cast(child: Expression, target: DType) -> Expression {
166 Cast.try_new_expr(target, [child])
167 .vortex_expect("Failed to create Cast expression")
168}
169
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // Substitute a child that renders differently from the original one, so the
        // assertion actually proves the replacement took effect.
        let replaced = expr
            .with_children(vec![get_item("value", root())])
            .vortex_expect("operation should succeed in test");
        assert_eq!(replaced.to_string(), "cast($.value as bool?)");
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: Expression = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        // Every input row must survive the cast.
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value as i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($ as bool?)");
    }
}