vortex_array/scalar_fns/cast/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4pub(crate) mod array;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_error::vortex_panic;
12use vortex_proto::expr as pb;
13use vortex_vector::Datum;
14
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::functions::ArgName;
18use crate::expr::functions::Arity;
19use crate::expr::functions::ExecutionArgs;
20use crate::expr::functions::FunctionId;
21use crate::expr::functions::NullHandling;
22use crate::expr::functions::VTable;
23use crate::expr::stats::Stat;
24use crate::scalar_fns::ExprBuiltins;
25
/// Marker type for the `vortex.cast` scalar function.
///
/// The struct carries no state of its own: the dtype to cast to is supplied
/// through the [`VTable::Options`] associated type of its `VTable` impl.
pub struct CastFn;
27impl VTable for CastFn {
28    type Options = DType;
29
30    fn id(&self) -> FunctionId {
31        FunctionId::from("vortex.cast")
32    }
33
34    fn serialize(&self, target_dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
35        Ok(Some(
36            pb::CastOpts {
37                target: Some(target_dtype.into()),
38            }
39            .encode_to_vec(),
40        ))
41    }
42
43    fn deserialize(&self, bytes: &[u8]) -> VortexResult<DType> {
44        pb::CastOpts::decode(bytes)?
45            .target
46            .as_ref()
47            .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?
48            .try_into()
49    }
50
51    fn arity(&self, _options: &DType) -> Arity {
52        Arity::Exact(1)
53    }
54
55    fn null_handling(&self, _options: &DType) -> NullHandling {
56        NullHandling::Propagate
57    }
58
59    fn arg_name(&self, _options: &DType, arg_idx: usize) -> ArgName {
60        match arg_idx {
61            0 => ArgName::from("input"),
62            _ => vortex_panic!("Invalid argument index {}", arg_idx),
63        }
64    }
65
66    fn stat_expression(
67        &self,
68        target_dtype: &DType,
69        expr: &Expression,
70        stat: Stat,
71        catalog: &dyn StatsCatalog,
72    ) -> Option<Expression> {
73        match stat {
74            Stat::IsConstant
75            | Stat::IsSorted
76            | Stat::IsStrictSorted
77            | Stat::NaNCount
78            | Stat::Sum
79            | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
80            Stat::Max | Stat::Min => {
81                // We cast min/max to the new type
82                expr.child(0).stat_expression(stat, catalog).map(|x| {
83                    x.cast(target_dtype.clone())
84                        .vortex_expect("Failed to cast stat expression")
85                })
86            }
87            Stat::NullCount => {
88                // if !expr.data().is_nullable() {
89                // NOTE(ngates): we should decide on the semantics here. In theory, the null
90                //  count of something cast to non-nullable will be zero. But if we return
91                //  that we know this to be zero, then a pruning predicate may eliminate data
92                //  that would otherwise have caused the cast to error.
93                // return Some(lit(0u64));
94                // }
95                None
96            }
97        }
98    }
99
100    fn return_dtype(&self, target_dtype: &DType, _arg_types: &[DType]) -> VortexResult<DType> {
101        Ok(target_dtype.clone())
102    }
103
104    fn execute(&self, target_dtype: &DType, args: &ExecutionArgs) -> VortexResult<Datum> {
105        let datum = args.input_datums(0);
106        vortex_compute::cast::Cast::cast(datum, target_dtype)
107    }
108}