Skip to main content

vortex_array/scalar_fn/fns/cast/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_err;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::AnyColumnar;
18use crate::ArrayRef;
19use crate::CanonicalView;
20use crate::ColumnarView;
21use crate::ExecutionCtx;
22use crate::arrays::BoolVTable;
23use crate::arrays::ConstantArray;
24use crate::arrays::ConstantVTable;
25use crate::arrays::DecimalVTable;
26use crate::arrays::ExtensionVTable;
27use crate::arrays::FixedSizeListVTable;
28use crate::arrays::ListViewVTable;
29use crate::arrays::NullVTable;
30use crate::arrays::PrimitiveVTable;
31use crate::arrays::StructVTable;
32use crate::arrays::VarBinViewVTable;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::StatsCatalog;
36use crate::expr::cast;
37use crate::expr::expression::Expression;
38use crate::expr::lit;
39use crate::expr::stats::Stat;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ReduceCtx;
44use crate::scalar_fn::ReduceNode;
45use crate::scalar_fn::ReduceNodeRef;
46use crate::scalar_fn::ScalarFnId;
47use crate::scalar_fn::ScalarFnVTable;
48
/// A cast expression that converts values to a target data type.
///
/// The target [`DType`] is not stored on this type; it is supplied as the
/// `Options` of the [`ScalarFnVTable`] implementation, and the function takes
/// exactly one child (the value being cast).
#[derive(Clone)]
pub struct Cast;
52
53impl ScalarFnVTable for Cast {
54    type Options = DType;
55
56    fn id(&self) -> ScalarFnId {
57        ScalarFnId::from("vortex.cast")
58    }
59
60    fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
61        Ok(Some(
62            pb::CastOpts {
63                target: Some(dtype.try_into()?),
64            }
65            .encode_to_vec(),
66        ))
67    }
68
69    fn deserialize(
70        &self,
71        _metadata: &[u8],
72        session: &VortexSession,
73    ) -> VortexResult<Self::Options> {
74        let proto = pb::CastOpts::decode(_metadata)?.target;
75        DType::from_proto(
76            proto
77                .as_ref()
78                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
79            session,
80        )
81    }
82
83    fn arity(&self, _options: &DType) -> Arity {
84        Arity::Exact(1)
85    }
86
87    fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
88        match child_idx {
89            0 => ChildName::from("input"),
90            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
91        }
92    }
93
94    fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
95        write!(f, "cast(")?;
96        expr.children()[0].fmt_sql(f)?;
97        write!(f, " as {}", dtype)?;
98        write!(f, ")")
99    }
100
101    fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
102        Ok(dtype.clone())
103    }
104
105    fn execute(&self, target_dtype: &DType, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
106        let input = args
107            .inputs
108            .pop()
109            .vortex_expect("missing input for Cast expression");
110
111        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
112            return input
113                .execute::<ArrayRef>(args.ctx)?
114                .cast(target_dtype.clone());
115        };
116
117        match columnar {
118            ColumnarView::Canonical(canonical) => {
119                match cast_canonical(canonical.clone(), target_dtype, args.ctx)? {
120                    Some(result) => Ok(result),
121                    None => vortex_bail!(
122                        "No CastKernel to cast canonical array {} from {} to {}",
123                        canonical.as_ref().encoding_id(),
124                        canonical.as_ref().dtype(),
125                        target_dtype,
126                    ),
127                }
128            }
129            ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
130                Some(result) => Ok(result),
131                None => vortex_bail!(
132                    "No CastReduce to cast constant array from {} to {}",
133                    constant.dtype(),
134                    target_dtype,
135                ),
136            },
137        }
138    }
139
140    fn reduce(
141        &self,
142        target_dtype: &DType,
143        node: &dyn ReduceNode,
144        _ctx: &dyn ReduceCtx,
145    ) -> VortexResult<Option<ReduceNodeRef>> {
146        // Collapse node if child is already the target type
147        let child = node.child(0);
148        if &child.node_dtype()? == target_dtype {
149            return Ok(Some(child));
150        }
151        Ok(None)
152    }
153
154    fn stat_expression(
155        &self,
156        dtype: &DType,
157        expr: &Expression,
158        stat: Stat,
159        catalog: &dyn StatsCatalog,
160    ) -> Option<Expression> {
161        match stat {
162            Stat::IsConstant
163            | Stat::IsSorted
164            | Stat::IsStrictSorted
165            | Stat::NaNCount
166            | Stat::Sum
167            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
168            Stat::Max | Stat::Min => {
169                // We cast min/max to the new type
170                expr.child(0)
171                    .stat_expression(stat, catalog)
172                    .map(|x| cast(x, dtype.clone()))
173            }
174            Stat::NullCount => {
175                // if !expr.data().is_nullable() {
176                // NOTE(ngates): we should decide on the semantics here. In theory, the null
177                //  count of something cast to non-nullable will be zero. But if we return
178                //  that we know this to be zero, then a pruning predicate may eliminate data
179                //  that would otherwise have caused the cast to error.
180                // return Some(lit(0u64));
181                // }
182                None
183            }
184        }
185    }
186
187    fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
188        Ok(Some(if dtype.is_nullable() {
189            expression.child(0).validity()?
190        } else {
191            lit(true)
192        }))
193    }
194
195    // This might apply a nullability
196    fn is_null_sensitive(&self, _instance: &DType) -> bool {
197        true
198    }
199}
200
201/// Cast a canonical array to the target dtype by dispatching to the appropriate
202/// [`CastReduce`] or [`CastKernel`] for each canonical encoding.
203fn cast_canonical(
204    canonical: CanonicalView<'_>,
205    dtype: &DType,
206    ctx: &mut ExecutionCtx,
207) -> VortexResult<Option<ArrayRef>> {
208    match canonical {
209        CanonicalView::Null(a) => <NullVTable as CastReduce>::cast(a, dtype),
210        CanonicalView::Bool(a) => <BoolVTable as CastReduce>::cast(a, dtype),
211        CanonicalView::Primitive(a) => <PrimitiveVTable as CastKernel>::cast(a, dtype, ctx),
212        CanonicalView::Decimal(a) => <DecimalVTable as CastKernel>::cast(a, dtype, ctx),
213        CanonicalView::VarBinView(a) => <VarBinViewVTable as CastReduce>::cast(a, dtype),
214        CanonicalView::List(a) => <ListViewVTable as CastReduce>::cast(a, dtype),
215        CanonicalView::FixedSizeList(a) => <FixedSizeListVTable as CastReduce>::cast(a, dtype),
216        CanonicalView::Struct(a) => <StructVTable as CastKernel>::cast(a, dtype, ctx),
217        CanonicalView::Extension(a) => <ExtensionVTable as CastReduce>::cast(a, dtype),
218    }
219}
220
/// Cast a constant array by dispatching to its [`CastReduce`] implementation.
///
/// The result of `<ConstantVTable as CastReduce>::cast` is returned unchanged;
/// the caller in `Cast::execute` treats `Ok(None)` as "no cast available" and
/// turns it into an error.
fn cast_constant(array: &ConstantArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
    <ConstantVTable as CastReduce>::cast(array, dtype)
}
225
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::cast;
    use crate::expr::get_item;
    use crate::expr::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let resolved = cast(root(), target.clone())
            .return_dtype(&scope)
            .unwrap();

        assert_eq!(resolved, target);
    }

    /// Replacing the single child of a cast expression succeeds.
    #[test]
    fn replace_children() {
        let original = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = original.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Applying a cast of an `i32` struct field to `i64` yields an `i64` array.
    #[test]
    fn evaluate() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let struct_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let result = struct_array.apply(&expr).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Cast expressions display as `cast(<input> as <dtype>)`.
    #[test]
    fn test_display() {
        let field_cast = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(field_cast.to_string(), "cast($.value as i64)");

        let root_cast = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(root_cast.to_string(), "cast($ as bool?)");
    }
}