Skip to main content

vortex_array/scalar_fn/fns/cast/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::AnyColumnar;
18use crate::ArrayRef;
19use crate::ArrayView;
20use crate::CanonicalView;
21use crate::ColumnarView;
22use crate::ExecutionCtx;
23use crate::arrays::Bool;
24use crate::arrays::Constant;
25use crate::arrays::Decimal;
26use crate::arrays::Extension;
27use crate::arrays::FixedSizeList;
28use crate::arrays::ListView;
29use crate::arrays::Null;
30use crate::arrays::Primitive;
31use crate::arrays::VarBinView;
32use crate::arrays::struct_::compute::cast::struct_cast;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::StatsCatalog;
36use crate::expr::cast;
37use crate::expr::expression::Expression;
38use crate::expr::lit;
39use crate::expr::stats::Stat;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ReduceCtx;
44use crate::scalar_fn::ReduceNode;
45use crate::scalar_fn::ReduceNodeRef;
46use crate::scalar_fn::ScalarFnId;
47use crate::scalar_fn::ScalarFnVTable;
48
/// Scalar function that casts its single input to a target data type.
///
/// The type itself carries no state; the target [`DType`] is supplied as the
/// options of its [`ScalarFnVTable`] implementation.
#[derive(Clone)]
pub struct Cast;
52
53impl ScalarFnVTable for Cast {
54    type Options = DType;
55
56    fn id(&self) -> ScalarFnId {
57        static ID: CachedId = CachedId::new("vortex.cast");
58        *ID
59    }
60
61    fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
62        Ok(Some(
63            pb::CastOpts {
64                target: Some(dtype.try_into()?),
65            }
66            .encode_to_vec(),
67        ))
68    }
69
70    fn deserialize(
71        &self,
72        _metadata: &[u8],
73        session: &VortexSession,
74    ) -> VortexResult<Self::Options> {
75        let proto = pb::CastOpts::decode(_metadata)?.target;
76        DType::from_proto(
77            proto
78                .as_ref()
79                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
80            session,
81        )
82    }
83
84    fn arity(&self, _options: &DType) -> Arity {
85        Arity::Exact(1)
86    }
87
88    fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
89        match child_idx {
90            0 => ChildName::from("input"),
91            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
92        }
93    }
94
95    fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
96        write!(f, "cast(")?;
97        expr.children()[0].fmt_sql(f)?;
98        write!(f, " as {}", dtype)?;
99        write!(f, ")")
100    }
101
102    fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
103        Ok(dtype.clone())
104    }
105
106    fn execute(
107        &self,
108        target_dtype: &DType,
109        args: &dyn ExecutionArgs,
110        ctx: &mut ExecutionCtx,
111    ) -> VortexResult<ArrayRef> {
112        let input = args.get(0)?;
113
114        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
115            return input.execute::<ArrayRef>(ctx)?.cast(target_dtype.clone());
116        };
117
118        match columnar {
119            ColumnarView::Canonical(canonical) => {
120                match cast_canonical(canonical, target_dtype, ctx)? {
121                    Some(result) => Ok(result),
122                    None => vortex_bail!(
123                        "No CastKernel to cast canonical array {} from {} to {}",
124                        canonical.to_array_ref().encoding_id(),
125                        canonical.to_array_ref().dtype(),
126                        target_dtype,
127                    ),
128                }
129            }
130            ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
131                Some(result) => Ok(result),
132                None => vortex_bail!(
133                    "No CastReduce to cast constant array from {} to {}",
134                    constant.dtype(),
135                    target_dtype,
136                ),
137            },
138        }
139    }
140
141    fn reduce(
142        &self,
143        target_dtype: &DType,
144        node: &dyn ReduceNode,
145        _ctx: &dyn ReduceCtx,
146    ) -> VortexResult<Option<ReduceNodeRef>> {
147        // Collapse node if child is already the target type
148        let child = node.child(0);
149        if &child.node_dtype()? == target_dtype {
150            return Ok(Some(child));
151        }
152        Ok(None)
153    }
154
155    fn stat_expression(
156        &self,
157        dtype: &DType,
158        expr: &Expression,
159        stat: Stat,
160        catalog: &dyn StatsCatalog,
161    ) -> Option<Expression> {
162        match stat {
163            Stat::IsConstant
164            | Stat::IsSorted
165            | Stat::IsStrictSorted
166            | Stat::NaNCount
167            | Stat::Sum
168            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
169            Stat::Max | Stat::Min => {
170                // We cast min/max to the new type
171                expr.child(0)
172                    .stat_expression(stat, catalog)
173                    .map(|x| cast(x, dtype.clone()))
174            }
175            Stat::NullCount => {
176                // if !expr.data().is_nullable() {
177                // NOTE(ngates): we should decide on the semantics here. In theory, the null
178                //  count of something cast to non-nullable will be zero. But if we return
179                //  that we know this to be zero, then a pruning predicate may eliminate data
180                //  that would otherwise have caused the cast to error.
181                // return Some(lit(0u64));
182                // }
183                None
184            }
185        }
186    }
187
188    fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
189        Ok(Some(if dtype.is_nullable() {
190            expression.child(0).validity()?
191        } else {
192            lit(true)
193        }))
194    }
195
196    // This might apply a nullability
197    fn is_null_sensitive(&self, _instance: &DType) -> bool {
198        true
199    }
200}
201
202/// Cast a canonical array to the target dtype by dispatching to the appropriate
203/// [`CastKernel`] for each canonical encoding.
204///
205/// Canonical encodings that can manipulate validity directly all implement [`CastKernel`] —
206/// the kernel is the execution-time complement of their [`CastReduce`] rule and can compute
207/// statistics (e.g. min of the validity array) when the reduce rule had to give up.
208/// Encodings that delegate to scalars or storage (e.g. [`Null`], [`Constant`], [`Extension`])
209/// only implement [`CastReduce`] because they never need execution-level information.
210fn cast_canonical(
211    canonical: CanonicalView<'_>,
212    dtype: &DType,
213    ctx: &mut ExecutionCtx,
214) -> VortexResult<Option<ArrayRef>> {
215    match canonical {
216        CanonicalView::Null(a) => <Null as CastReduce>::cast(a, dtype),
217        CanonicalView::Bool(a) => <Bool as CastKernel>::cast(a, dtype, ctx),
218        CanonicalView::Primitive(a) => <Primitive as CastKernel>::cast(a, dtype, ctx),
219        CanonicalView::Decimal(a) => <Decimal as CastKernel>::cast(a, dtype, ctx),
220        CanonicalView::VarBinView(a) => <VarBinView as CastKernel>::cast(a, dtype, ctx),
221        CanonicalView::List(a) => <ListView as CastKernel>::cast(a, dtype, ctx),
222        CanonicalView::FixedSizeList(a) => <FixedSizeList as CastKernel>::cast(a, dtype, ctx),
223        CanonicalView::Struct(a) => struct_cast(a, dtype, ctx),
224        CanonicalView::Extension(a) => <Extension as CastReduce>::cast(a, dtype),
225        CanonicalView::Variant(_) => {
226            vortex_bail!("Variant arrays don't support casting")
227        }
228    }
229}
230
231/// Cast a constant array by dispatching to its [`CastReduce`] implementation.
232fn cast_constant(array: ArrayView<Constant>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
233    <Constant as CastReduce>::cast(array, dtype)
234}
235
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::cast;
    use crate::expr::get_item;
    use crate::expr::root;
    use crate::expr::test_harness;

    /// The non-nullable `i64` dtype used as the cast target in several tests.
    fn i64_non_nullable() -> DType {
        DType::Primitive(PType::I64, Nullability::NonNullable)
    }

    /// The return dtype of a cast is its target dtype, whatever the input dtype is.
    #[test]
    fn dtype() {
        let input_dtype = test_harness::struct_dtype();
        let returned = cast(root(), DType::Bool(Nullability::NonNullable))
            .return_dtype(&input_dtype)
            .unwrap();
        assert_eq!(returned, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child of a cast expression succeeds.
    #[test]
    fn replace_children() {
        let cast_expr = cast(root(), DType::Bool(Nullability::Nullable));
        let _rebuilt = cast_expr
            .with_children(vec![root()])
            .vortex_expect("operation should succeed in test");
    }

    /// Applying a cast of an `i32` struct field to `i64` yields an `i64` array.
    #[test]
    fn evaluate() {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        let test_array = StructArray::from_fields(&fields).unwrap().into_array();

        let expr: Expression = cast(get_item("a", root()), i64_non_nullable());
        let result = test_array.apply(&expr).unwrap();

        assert_eq!(result.dtype(), &i64_non_nullable());
    }

    /// The display form of a cast is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let field_cast = cast(get_item("value", root()), i64_non_nullable());
        assert_eq!(field_cast.to_string(), "cast($.value as i64)");

        let root_cast = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(root_cast.to_string(), "cast($ as bool?)");
    }
}