vortex_array/expr/exprs/cast/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_vector::Vector;
15
16use crate::ArrayRef;
17use crate::compute::cast as compute_cast;
18use crate::expr::ChildName;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::ExpressionView;
22use crate::expr::StatsCatalog;
23use crate::expr::VTable;
24use crate::expr::VTableExt;
25use crate::expr::expression::Expression;
26use crate::expr::stats::Stat;
27
/// Expression vtable that converts the values of its single child to a target [`DType`].
///
/// The target dtype is carried as the expression's instance data.
pub struct Cast;
30
31impl VTable for Cast {
32    type Instance = DType;
33
34    fn id(&self) -> ExprId {
35        ExprId::from("vortex.cast")
36    }
37
38    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
39        Ok(Some(
40            pb::CastOpts {
41                target: Some(instance.into()),
42            }
43            .encode_to_vec(),
44        ))
45    }
46
47    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
48        Ok(Some(
49            pb::CastOpts::decode(metadata)?
50                .target
51                .as_ref()
52                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
53                .try_into()?,
54        ))
55    }
56
57    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
58        if expr.children().len() != 1 {
59            vortex_bail!(
60                "Cast expression requires exactly 1 child, got {}",
61                expr.children().len()
62            );
63        }
64        Ok(())
65    }
66
67    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
68        match child_idx {
69            0 => ChildName::from("input"),
70            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
71        }
72    }
73
74    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
75        write!(f, "cast(")?;
76        expr.children()[0].fmt_sql(f)?;
77        write!(f, " as {}", expr.data())?;
78        write!(f, ")")
79    }
80
81    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
82        write!(f, "{}", instance)
83    }
84
85    fn return_dtype(&self, expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
86        Ok(expr.data().clone())
87    }
88
89    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90        let array = expr.children()[0].evaluate(scope)?;
91        compute_cast(&array, expr.data()).map_err(|e| {
92            e.with_context(format!(
93                "Failed to cast array of dtype {} to {}",
94                array.dtype(),
95                expr.deref()
96            ))
97        })
98    }
99
100    fn stat_expression(
101        &self,
102        expr: &ExpressionView<Self>,
103        stat: Stat,
104        catalog: &dyn StatsCatalog,
105    ) -> Option<Expression> {
106        match stat {
107            Stat::IsConstant
108            | Stat::IsSorted
109            | Stat::IsStrictSorted
110            | Stat::NaNCount
111            | Stat::Sum
112            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
113            Stat::Max | Stat::Min => {
114                // We cast min/max to the new type
115                expr.child(0)
116                    .stat_expression(stat, catalog)
117                    .map(|x| cast(x, expr.data().clone()))
118            }
119            Stat::NullCount => {
120                // if !expr.data().is_nullable() {
121                // NOTE(ngates): we should decide on the semantics here. In theory, the null
122                //  count of something cast to non-nullable will be zero. But if we return
123                //  that we know this to be zero, then a pruning predicate may eliminate data
124                //  that would otherwise have caused the cast to error.
125                // return Some(lit(0u64));
126                // }
127                None
128            }
129        }
130    }
131
132    fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<Vector> {
133        let input = args
134            .vectors
135            .pop()
136            .vortex_expect("missing input for Cast expression");
137        vortex_compute::cast::Cast::cast(&input, target_dtype)
138    }
139
140    // This might apply a nullability
141    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
142        true
143    }
144}
145
146/// Creates an expression that casts values to a target data type.
147///
148/// Converts the input expression's values to the specified target type.
149///
150/// ```rust
151/// # use vortex_dtype::{DType, Nullability, PType};
152/// # use vortex_array::expr::{cast, root};
153/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
154/// ```
155pub fn cast(child: Expression, target: DType) -> Expression {
156    Cast.try_new_expr(target, [child])
157        .vortex_expect("Failed to create Cast expression")
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexUnwrap as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    /// The non-nullable `i64` dtype shared by several tests.
    fn i64_non_nullable() -> DType {
        DType::Primitive(PType::I64, Nullability::NonNullable)
    }

    /// The return dtype of a cast is its target, regardless of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);
        let actual = cast(root(), target.clone()).return_dtype(&scope).unwrap();
        assert_eq!(actual, target);
    }

    /// Rebuilding a cast with a fresh single child succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let _rebuilt = original.with_children(vec![root()]).vortex_unwrap();
    }

    /// Evaluating a cast of an `i32` field yields an array of the target dtype.
    #[test]
    fn evaluate() {
        let scope = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let cast_a: Expression = cast(get_item("a", root()), i64_non_nullable());
        let evaluated = cast_a.evaluate(&scope).unwrap();

        assert_eq!(evaluated.dtype(), &i64_non_nullable());
    }

    /// The SQL-like display form is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let cast_field = cast(get_item("value", root()), i64_non_nullable());
        assert_eq!(cast_field.to_string(), "cast($.value as i64)");

        let cast_root = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(cast_root.to_string(), "cast($ as bool?)");
    }
}