Skip to main content

vortex_array/expr/exprs/cast/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_bail;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_session::VortexSession;
17
18use crate::AnyColumnar;
19use crate::ArrayRef;
20use crate::CanonicalView;
21use crate::ColumnarView;
22use crate::ExecutionCtx;
23use crate::arrays::BoolVTable;
24use crate::arrays::ConstantArray;
25use crate::arrays::ConstantVTable;
26use crate::arrays::DecimalVTable;
27use crate::arrays::ExtensionVTable;
28use crate::arrays::FixedSizeListVTable;
29use crate::arrays::ListViewVTable;
30use crate::arrays::NullVTable;
31use crate::arrays::PrimitiveVTable;
32use crate::arrays::StructVTable;
33use crate::arrays::VarBinViewVTable;
34use crate::builtins::ArrayBuiltins;
35use crate::expr::Arity;
36use crate::expr::ChildName;
37use crate::expr::ExecutionArgs;
38use crate::expr::ExprId;
39use crate::expr::ReduceCtx;
40use crate::expr::ReduceNode;
41use crate::expr::ReduceNodeRef;
42use crate::expr::StatsCatalog;
43use crate::expr::VTable;
44use crate::expr::VTableExt;
45use crate::expr::expression::Expression;
46use crate::expr::lit;
47use crate::expr::stats::Stat;
48
/// Expression vtable for casting values to a target data type.
///
/// The struct itself carries no state; the target [`DType`] is supplied as the
/// expression's options.
pub struct Cast;
51
52impl VTable for Cast {
53    type Options = DType;
54
55    fn id(&self) -> ExprId {
56        ExprId::from("vortex.cast")
57    }
58
59    fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
60        Ok(Some(
61            pb::CastOpts {
62                target: Some(dtype.try_into()?),
63            }
64            .encode_to_vec(),
65        ))
66    }
67
68    fn deserialize(
69        &self,
70        _metadata: &[u8],
71        session: &VortexSession,
72    ) -> VortexResult<Self::Options> {
73        let proto = pb::CastOpts::decode(_metadata)?.target;
74        DType::from_proto(
75            proto
76                .as_ref()
77                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
78            session,
79        )
80    }
81
82    fn arity(&self, _options: &DType) -> Arity {
83        Arity::Exact(1)
84    }
85
86    fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
87        match child_idx {
88            0 => ChildName::from("input"),
89            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
90        }
91    }
92
93    fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
94        write!(f, "cast(")?;
95        expr.children()[0].fmt_sql(f)?;
96        write!(f, " as {}", dtype)?;
97        write!(f, ")")
98    }
99
100    fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
101        Ok(dtype.clone())
102    }
103
104    fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
105        let input = args
106            .inputs
107            .pop()
108            .vortex_expect("missing input for Cast expression");
109
110        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
111            return input
112                .execute::<ArrayRef>(args.ctx)?
113                .cast(target_dtype.clone());
114        };
115
116        match columnar {
117            ColumnarView::Canonical(canonical) => {
118                match cast_canonical(canonical.clone(), target_dtype, args.ctx)? {
119                    Some(result) => Ok(result),
120                    None => vortex_bail!(
121                        "No CastKernel to cast canonical array {} from {} to {}",
122                        canonical.as_ref().encoding_id(),
123                        canonical.as_ref().dtype(),
124                        target_dtype,
125                    ),
126                }
127            }
128            ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
129                Some(result) => Ok(result),
130                None => vortex_bail!(
131                    "No CastReduce to cast constant array from {} to {}",
132                    constant.dtype(),
133                    target_dtype,
134                ),
135            },
136        }
137    }
138
139    fn reduce(
140        &self,
141        target_dtype: &DType,
142        node: &dyn ReduceNode,
143        _ctx: &dyn ReduceCtx,
144    ) -> VortexResult<Option<ReduceNodeRef>> {
145        // Collapse node if child is already the target type
146        let child = node.child(0);
147        if &child.node_dtype()? == target_dtype {
148            return Ok(Some(child));
149        }
150        Ok(None)
151    }
152
153    fn stat_expression(
154        &self,
155        dtype: &DType,
156        expr: &Expression,
157        stat: Stat,
158        catalog: &dyn StatsCatalog,
159    ) -> Option<Expression> {
160        match stat {
161            Stat::IsConstant
162            | Stat::IsSorted
163            | Stat::IsStrictSorted
164            | Stat::NaNCount
165            | Stat::Sum
166            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
167            Stat::Max | Stat::Min => {
168                // We cast min/max to the new type
169                expr.child(0)
170                    .stat_expression(stat, catalog)
171                    .map(|x| cast(x, dtype.clone()))
172            }
173            Stat::NullCount => {
174                // if !expr.data().is_nullable() {
175                // NOTE(ngates): we should decide on the semantics here. In theory, the null
176                //  count of something cast to non-nullable will be zero. But if we return
177                //  that we know this to be zero, then a pruning predicate may eliminate data
178                //  that would otherwise have caused the cast to error.
179                // return Some(lit(0u64));
180                // }
181                None
182            }
183        }
184    }
185
186    fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
187        Ok(Some(if dtype.is_nullable() {
188            expression.child(0).validity()?
189        } else {
190            lit(true)
191        }))
192    }
193
194    // This might apply a nullability
195    fn is_null_sensitive(&self, _instance: &DType) -> bool {
196        true
197    }
198}
199
200/// Cast a canonical array to the target dtype by dispatching to the appropriate
201/// [`CastReduce`] or [`CastKernel`] for each canonical encoding.
202fn cast_canonical(
203    canonical: CanonicalView<'_>,
204    dtype: &DType,
205    ctx: &mut ExecutionCtx,
206) -> VortexResult<Option<ArrayRef>> {
207    match canonical {
208        CanonicalView::Null(a) => <NullVTable as CastReduce>::cast(a, dtype),
209        CanonicalView::Bool(a) => <BoolVTable as CastReduce>::cast(a, dtype),
210        CanonicalView::Primitive(a) => <PrimitiveVTable as CastKernel>::cast(a, dtype, ctx),
211        CanonicalView::Decimal(a) => <DecimalVTable as CastKernel>::cast(a, dtype, ctx),
212        CanonicalView::VarBinView(a) => <VarBinViewVTable as CastReduce>::cast(a, dtype),
213        CanonicalView::List(a) => <ListViewVTable as CastReduce>::cast(a, dtype),
214        CanonicalView::FixedSizeList(a) => <FixedSizeListVTable as CastReduce>::cast(a, dtype),
215        CanonicalView::Struct(a) => <StructVTable as CastKernel>::cast(a, dtype, ctx),
216        CanonicalView::Extension(a) => <ExtensionVTable as CastReduce>::cast(a, dtype),
217    }
218}
219
220/// Cast a constant array by dispatching to its [`CastReduce`] implementation.
221fn cast_constant(array: &ConstantArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
222    <ConstantVTable as CastReduce>::cast(array, dtype)
223}
224
225/// Creates an expression that casts values to a target data type.
226///
227/// Converts the input expression's values to the specified target type.
228///
229/// ```rust
230/// # use vortex_dtype::{DType, Nullability, PType};
231/// # use vortex_array::expr::{cast, root};
232/// let expr = cast(root(), DType::Primitive(PType::I64, Nullability::NonNullable));
233/// ```
234pub fn cast(child: Expression, target: DType) -> Expression {
235    Cast.try_new_expr(target, [child])
236        .vortex_expect("Failed to create Cast expression")
237}
238
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect as _;

    use super::cast;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::root::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let resolved = cast(root(), target.clone()).return_dtype(&scope).unwrap();

        assert_eq!(resolved, target);
    }

    /// Replacing the single child of a cast succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = original.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Applying a cast over an `i32` struct field yields an `i64` array.
    #[test]
    fn evaluate() {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        let input = StructArray::from_fields(&fields).unwrap().into_array();
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let output = input.apply(&expr).unwrap();

        assert_eq!(output.dtype(), &target);
    }

    /// The display form is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let field_cast = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(field_cast.to_string(), "cast($.value as i64)");

        let root_cast = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(root_cast.to_string(), "cast($ as bool?)");
    }
}