vortex_array/expr/exprs/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Deref;
6
7use prost::Message;
8use vortex_dtype::{DType, FieldPath};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::ArrayRef;
13use crate::compute::cast as compute_cast;
14use crate::expr::expression::Expression;
15use crate::expr::{ChildName, ExprId, ExpressionView, StatsCatalog, VTable, VTableExt};
16
/// A cast expression that converts values to a target data type.
///
/// `Cast` is a stateless marker type: its [`VTable`] implementation stores the
/// target [`DType`] as the per-expression instance data and registers the
/// expression under the `vortex.cast` id. The expression has a single child,
/// named `input`, whose values are converted.
pub struct Cast;
19
20impl VTable for Cast {
21    type Instance = DType;
22
23    fn id(&self) -> ExprId {
24        ExprId::from("vortex.cast")
25    }
26
27    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28        Ok(Some(
29            pb::CastOpts {
30                target: Some(instance.into()),
31            }
32            .encode_to_vec(),
33        ))
34    }
35
36    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37        Ok(Some(
38            pb::CastOpts::decode(metadata)?
39                .target
40                .as_ref()
41                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
42                .try_into()?,
43        ))
44    }
45
46    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
47        if expr.children().len() != 1 {
48            vortex_bail!(
49                "Cast expression requires exactly 1 child, got {}",
50                expr.children().len()
51            );
52        }
53        Ok(())
54    }
55
56    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
57        match child_idx {
58            0 => ChildName::from("input"),
59            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
60        }
61    }
62
63    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
64        write!(f, "cast(")?;
65        expr.children()[0].fmt_sql(f)?;
66        write!(f, " as {}", expr.data())?;
67        write!(f, ")")
68    }
69
70    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{}", instance)
72    }
73
74    fn return_dtype(&self, expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
75        Ok(expr.data().clone())
76    }
77
78    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
79        let array = expr.children()[0].evaluate(scope)?;
80        compute_cast(&array, expr.data()).map_err(|e| {
81            e.with_context(format!(
82                "Failed to cast array of dtype {} to {}",
83                array.dtype(),
84                expr.deref()
85            ))
86        })
87    }
88
89    fn stat_max(
90        &self,
91        expr: &ExpressionView<Self>,
92        catalog: &mut dyn StatsCatalog,
93    ) -> Option<Expression> {
94        expr.children()[0].stat_max(catalog)
95    }
96
97    fn stat_min(
98        &self,
99        expr: &ExpressionView<Self>,
100        catalog: &mut dyn StatsCatalog,
101    ) -> Option<Expression> {
102        expr.children()[0].stat_min(catalog)
103    }
104
105    fn stat_nan_count(
106        &self,
107        expr: &ExpressionView<Self>,
108        catalog: &mut dyn StatsCatalog,
109    ) -> Option<Expression> {
110        expr.children()[0].stat_nan_count(catalog)
111    }
112
113    fn stat_field_path(&self, expr: &ExpressionView<Self>) -> Option<FieldPath> {
114        expr.children()[0].stat_field_path()
115    }
116}
117
118/// Creates an expression that casts values to a target data type.
119///
120/// Converts the input expression's values to the specified target type.
121///
122/// ```rust
123/// # use vortex_dtype::{DType, Nullability, PType};
124/// # use vortex_array::expr::{cast, root};
125/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
126/// ```
127pub fn cast(child: Expression, target: DType) -> Expression {
128    Cast.try_new_expr(target, [child])
129        .vortex_expect("Failed to create Cast expression")
130}
131
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexUnwrap as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::{Expression, test_harness};

    /// The reported return dtype is the cast target.
    #[test]
    fn dtype() {
        let scope_dtype = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let cast_expr = cast(root(), target.clone());

        assert_eq!(cast_expr.return_dtype(&scope_dtype).unwrap(), target);
    }

    /// Swapping in a single new child must succeed.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let new_children = vec![root()];
        original.with_children(new_children).vortex_unwrap();
    }

    /// Casting an `i32` struct field to `i64` yields an `i64` array.
    #[test]
    fn evaluate() {
        let scope = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let casted = expr.evaluate(&scope).unwrap();

        assert_eq!(casted.dtype(), &target);
    }

    /// The display form is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let cases = [
            (
                cast(
                    get_item("value", root()),
                    DType::Primitive(PType::I64, Nullability::NonNullable),
                ),
                "cast($.value as i64)",
            ),
            (
                cast(root(), DType::Bool(Nullability::Nullable)),
                "cast($ as bool?)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}