vortex_array/expr/exprs/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_err;
12use vortex_proto::expr as pb;
13use vortex_vector::Datum;
14
15use crate::ArrayRef;
16use crate::compute::cast as compute_cast;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::ReduceCtx;
22use crate::expr::ReduceNode;
23use crate::expr::ReduceNodeRef;
24use crate::expr::StatsCatalog;
25use crate::expr::VTable;
26use crate::expr::VTableExt;
27use crate::expr::expression::Expression;
28use crate::expr::stats::Stat;
29
/// Expression vtable for casts.
///
/// The unit struct itself carries no state; the target [`DType`] of a given cast is
/// supplied through the vtable's `Options`.
pub struct Cast;
32
33impl VTable for Cast {
34    type Options = DType;
35
36    fn id(&self) -> ExprId {
37        ExprId::from("vortex.cast")
38    }
39
40    fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
41        Ok(Some(
42            pb::CastOpts {
43                target: Some(dtype.into()),
44            }
45            .encode_to_vec(),
46        ))
47    }
48
49    fn deserialize(&self, metadata: &[u8]) -> VortexResult<DType> {
50        pb::CastOpts::decode(metadata)?
51            .target
52            .as_ref()
53            .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
54            .try_into()
55    }
56
57    fn arity(&self, _options: &DType) -> Arity {
58        Arity::Exact(1)
59    }
60
61    fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
62        match child_idx {
63            0 => ChildName::from("input"),
64            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
65        }
66    }
67
68    fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
69        write!(f, "cast(")?;
70        expr.children()[0].fmt_sql(f)?;
71        write!(f, " as {}", dtype)?;
72        write!(f, ")")
73    }
74
75    fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
76        Ok(dtype.clone())
77    }
78
79    fn evaluate(
80        &self,
81        dtype: &DType,
82        expr: &Expression,
83        scope: &ArrayRef,
84    ) -> VortexResult<ArrayRef> {
85        let array = expr.children()[0].evaluate(scope)?;
86        compute_cast(&array, dtype).map_err(|e| {
87            e.with_context(format!(
88                "Failed to cast array of dtype {} to {}",
89                array.dtype(),
90                expr.deref()
91            ))
92        })
93    }
94
95    fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<Datum> {
96        let input = args
97            .datums
98            .pop()
99            .vortex_expect("missing input for Cast expression");
100        vortex_compute::cast::Cast::cast(&input, target_dtype)
101    }
102
103    fn reduce(
104        &self,
105        target_dtype: &DType,
106        node: &dyn ReduceNode,
107        _ctx: &dyn ReduceCtx,
108    ) -> VortexResult<Option<ReduceNodeRef>> {
109        // Collapse node if child is already the target type
110        let child = node.child(0);
111        if &child.node_dtype()? == target_dtype {
112            return Ok(Some(child));
113        }
114        Ok(None)
115    }
116
117    fn stat_expression(
118        &self,
119        dtype: &DType,
120        expr: &Expression,
121        stat: Stat,
122        catalog: &dyn StatsCatalog,
123    ) -> Option<Expression> {
124        match stat {
125            Stat::IsConstant
126            | Stat::IsSorted
127            | Stat::IsStrictSorted
128            | Stat::NaNCount
129            | Stat::Sum
130            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
131            Stat::Max | Stat::Min => {
132                // We cast min/max to the new type
133                expr.child(0)
134                    .stat_expression(stat, catalog)
135                    .map(|x| cast(x, dtype.clone()))
136            }
137            Stat::NullCount => {
138                // if !expr.data().is_nullable() {
139                // NOTE(ngates): we should decide on the semantics here. In theory, the null
140                //  count of something cast to non-nullable will be zero. But if we return
141                //  that we know this to be zero, then a pruning predicate may eliminate data
142                //  that would otherwise have caused the cast to error.
143                // return Some(lit(0u64));
144                // }
145                None
146            }
147        }
148    }
149
150    // This might apply a nullability
151    fn is_null_sensitive(&self, _instance: &DType) -> bool {
152        true
153    }
154}
155
156/// Creates an expression that casts values to a target data type.
157///
158/// Converts the input expression's values to the specified target type.
159///
160/// ```rust
161/// # use vortex_dtype::{DType, Nullability, PType};
162/// # use vortex_array::expr::{cast, root};
163/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
164/// ```
165pub fn cast(child: Expression, target: DType) -> Expression {
166    Cast.try_new_expr(target, [child])
167        .vortex_expect("Failed to create Cast expression")
168}
169
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let scope_dtype = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let resolved = cast(root(), target.clone())
            .return_dtype(&scope_dtype)
            .unwrap();

        assert_eq!(resolved, target);
    }

    /// Swapping in a single new child succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = original.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Casting an `i32` struct field to `i64` yields an array of the target dtype.
    #[test]
    fn evaluate() {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        let scope = StructArray::from_fields(&fields).unwrap().into_array();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let result = expr.evaluate(&scope).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// The display form is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let field_cast = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(field_cast.to_string(), "cast($.value as i64)");

        let root_cast = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(root_cast.to_string(), "cast($ as bool?)");
    }
}