vortex_array/expr/exprs/cast/
mod.rs1use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_vector::Vector;
15
16use crate::ArrayRef;
17use crate::compute::cast as compute_cast;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::ExpressionView;
22use crate::expr::StatsCatalog;
23use crate::expr::VTable;
24use crate::expr::VTableExt;
25use crate::expr::expression::Expression;
26use crate::expr::stats::Stat;
27
28pub struct Cast;
30
31impl VTable for Cast {
32 type Instance = DType;
33
34 fn id(&self) -> ExprId {
35 ExprId::from("vortex.cast")
36 }
37
38 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
39 Ok(Some(
40 pb::CastOpts {
41 target: Some(instance.into()),
42 }
43 .encode_to_vec(),
44 ))
45 }
46
47 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
48 Ok(Some(
49 pb::CastOpts::decode(metadata)?
50 .target
51 .as_ref()
52 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
53 .try_into()?,
54 ))
55 }
56
57 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
58 if expr.children().len() != 1 {
59 vortex_bail!(
60 "Cast expression requires exactly 1 child, got {}",
61 expr.children().len()
62 );
63 }
64 Ok(())
65 }
66
67 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
68 match child_idx {
69 0 => ChildName::from("input"),
70 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
71 }
72 }
73
74 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
75 write!(f, "cast(")?;
76 expr.children()[0].fmt_sql(f)?;
77 write!(f, " as {}", expr.data())?;
78 write!(f, ")")
79 }
80
81 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
82 write!(f, "{}", instance)
83 }
84
85 fn return_dtype(&self, expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
86 Ok(expr.data().clone())
87 }
88
89 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90 let array = expr.children()[0].evaluate(scope)?;
91 compute_cast(&array, expr.data()).map_err(|e| {
92 e.with_context(format!(
93 "Failed to cast array of dtype {} to {}",
94 array.dtype(),
95 expr.deref()
96 ))
97 })
98 }
99
100 fn stat_expression(
101 &self,
102 expr: &ExpressionView<Self>,
103 stat: Stat,
104 catalog: &dyn StatsCatalog,
105 ) -> Option<Expression> {
106 match stat {
107 Stat::IsConstant
108 | Stat::IsSorted
109 | Stat::IsStrictSorted
110 | Stat::NaNCount
111 | Stat::Sum
112 | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
113 Stat::Max | Stat::Min => {
114 expr.child(0)
116 .stat_expression(stat, catalog)
117 .map(|x| cast(x, expr.data().clone()))
118 }
119 Stat::NullCount => {
120 None
128 }
129 }
130 }
131
132 fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<Vector> {
133 let input = args
134 .vectors
135 .pop()
136 .vortex_expect("missing input for Cast expression");
137 vortex_compute::cast::Cast::cast(&input, target_dtype)
138 }
139
140 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
142 true
143 }
144}
145
146pub fn cast(child: Expression, target: DType) -> Expression {
156 Cast.try_new_expr(target, [child])
157 .vortex_expect("Failed to create Cast expression")
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexUnwrap as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    /// The return dtype is the cast target, regardless of the input scope.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    /// Replacing the child keeps the cast target intact.
    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = expr.with_children(vec![root()]).vortex_unwrap();
        assert_eq!(replaced.to_string(), "cast($ as bool?)");
    }

    /// Evaluating a cast yields an array of the target dtype with the same number of rows.
    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: Expression = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value as i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($ as bool?)");
    }
}