Skip to main content

// vortex_array/scalar_fn/fns/cast/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::AnyColumnar;
18use crate::ArrayRef;
19use crate::ArrayView;
20use crate::CanonicalView;
21use crate::ColumnarView;
22use crate::ExecutionCtx;
23use crate::arrays::Bool;
24use crate::arrays::Constant;
25use crate::arrays::Decimal;
26use crate::arrays::Extension;
27use crate::arrays::FixedSizeList;
28use crate::arrays::ListView;
29use crate::arrays::Null;
30use crate::arrays::Primitive;
31use crate::arrays::VarBinView;
32use crate::arrays::struct_::compute::cast::struct_cast;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::expression::Expression;
36use crate::expr::lit;
37use crate::scalar_fn::Arity;
38use crate::scalar_fn::ChildName;
39use crate::scalar_fn::ExecutionArgs;
40use crate::scalar_fn::ReduceCtx;
41use crate::scalar_fn::ReduceNode;
42use crate::scalar_fn::ReduceNodeRef;
43use crate::scalar_fn::ScalarFnId;
44use crate::scalar_fn::ScalarFnVTable;
45
/// A cast expression that converts values to a target data type.
///
/// `Cast` is a stateless unit type: the target [`DType`] is not stored on the struct but is
/// supplied as the function's options (the `Options` associated type of its
/// [`ScalarFnVTable`] implementation). The function is registered under the id `vortex.cast`.
#[derive(Clone)]
pub struct Cast;
49
50impl ScalarFnVTable for Cast {
51    type Options = DType;
52
53    fn id(&self) -> ScalarFnId {
54        static ID: CachedId = CachedId::new("vortex.cast");
55        *ID
56    }
57
58    fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
59        Ok(Some(
60            pb::CastOpts {
61                target: Some(dtype.try_into()?),
62            }
63            .encode_to_vec(),
64        ))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        session: &VortexSession,
71    ) -> VortexResult<Self::Options> {
72        let proto = pb::CastOpts::decode(_metadata)?.target;
73        DType::from_proto(
74            proto
75                .as_ref()
76                .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
77            session,
78        )
79    }
80
81    fn arity(&self, _options: &DType) -> Arity {
82        Arity::Exact(1)
83    }
84
85    fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
86        match child_idx {
87            0 => ChildName::from("input"),
88            _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
89        }
90    }
91
92    fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
93        write!(f, "cast(")?;
94        expr.children()[0].fmt_sql(f)?;
95        write!(f, " as {}", dtype)?;
96        write!(f, ")")
97    }
98
99    fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
100        Ok(dtype.clone())
101    }
102
103    fn execute(
104        &self,
105        target_dtype: &DType,
106        args: &dyn ExecutionArgs,
107        ctx: &mut ExecutionCtx,
108    ) -> VortexResult<ArrayRef> {
109        let input = args.get(0)?;
110
111        let Some(columnar) = input.as_opt::<AnyColumnar>() else {
112            return input.execute::<ArrayRef>(ctx)?.cast(target_dtype.clone());
113        };
114
115        match columnar {
116            ColumnarView::Canonical(canonical) => {
117                match cast_canonical(canonical, target_dtype, ctx)? {
118                    Some(result) => Ok(result),
119                    None => vortex_bail!(
120                        "No CastKernel to cast canonical array {} from {} to {}",
121                        canonical.to_array_ref().encoding_id(),
122                        canonical.to_array_ref().dtype(),
123                        target_dtype,
124                    ),
125                }
126            }
127            ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
128                Some(result) => Ok(result),
129                None => vortex_bail!(
130                    "No CastReduce to cast constant array from {} to {}",
131                    constant.dtype(),
132                    target_dtype,
133                ),
134            },
135        }
136    }
137
138    fn reduce(
139        &self,
140        target_dtype: &DType,
141        node: &dyn ReduceNode,
142        _ctx: &dyn ReduceCtx,
143    ) -> VortexResult<Option<ReduceNodeRef>> {
144        // Collapse node if child is already the target type
145        let child = node.child(0);
146        if &child.node_dtype()? == target_dtype {
147            return Ok(Some(child));
148        }
149        Ok(None)
150    }
151
152    fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
153        Ok(Some(if dtype.is_nullable() {
154            expression.child(0).validity()?
155        } else {
156            lit(true)
157        }))
158    }
159
160    // This might apply a nullability
161    fn is_null_sensitive(&self, _instance: &DType) -> bool {
162        true
163    }
164}
165
166/// Cast a canonical array to the target dtype by dispatching to the appropriate
167/// [`CastKernel`] for each canonical encoding.
168///
169/// Canonical encodings that can manipulate validity directly all implement [`CastKernel`] —
170/// the kernel is the execution-time complement of their [`CastReduce`] rule and can compute
171/// statistics (e.g. min of the validity array) when the reduce rule had to give up.
172/// Encodings that delegate to scalars or storage (e.g. [`Null`], [`Constant`], [`Extension`])
173/// only implement [`CastReduce`] because they never need execution-level information.
174fn cast_canonical(
175    canonical: CanonicalView<'_>,
176    dtype: &DType,
177    ctx: &mut ExecutionCtx,
178) -> VortexResult<Option<ArrayRef>> {
179    match canonical {
180        CanonicalView::Null(a) => <Null as CastReduce>::cast(a, dtype),
181        CanonicalView::Bool(a) => <Bool as CastKernel>::cast(a, dtype, ctx),
182        CanonicalView::Primitive(a) => <Primitive as CastKernel>::cast(a, dtype, ctx),
183        CanonicalView::Decimal(a) => <Decimal as CastKernel>::cast(a, dtype, ctx),
184        CanonicalView::VarBinView(a) => <VarBinView as CastKernel>::cast(a, dtype, ctx),
185        CanonicalView::List(a) => <ListView as CastKernel>::cast(a, dtype, ctx),
186        CanonicalView::FixedSizeList(a) => <FixedSizeList as CastKernel>::cast(a, dtype, ctx),
187        CanonicalView::Struct(a) => struct_cast(a, dtype, ctx),
188        CanonicalView::Extension(a) => <Extension as CastReduce>::cast(a, dtype),
189        CanonicalView::Variant(_) => {
190            vortex_bail!("Variant arrays don't support casting")
191        }
192    }
193}
194
/// Cast a constant array by dispatching to its [`CastReduce`] implementation.
///
/// No execution context is involved: the call goes straight to `<Constant as CastReduce>::cast`
/// and its result is returned unchanged. An `Ok(None)` result is reported by `Cast::execute` as
/// a "No CastReduce to cast constant array" error.
fn cast_constant(array: ArrayView<Constant>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
    <Constant as CastReduce>::cast(array, dtype)
}
199
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::cast;
    use crate::expr::get_item;
    use crate::expr::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype, independent of the scope dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let target = DType::Bool(Nullability::NonNullable);

        let resolved = cast(root(), target.clone()).return_dtype(&scope).unwrap();

        assert_eq!(resolved, target);
    }

    /// Replacing the single child of a cast with another expression succeeds.
    #[test]
    fn replace_children() {
        let cast_expr = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = cast_expr.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Applying a cast of an `i32` struct field to `i64` yields an array of the target dtype.
    #[test]
    fn evaluate() {
        let field_a = buffer![0i32, 1, 2].into_array();
        let field_b = buffer![4i64, 5, 6].into_array();
        let input = StructArray::from_fields(&[("a", field_a), ("b", field_b)])
            .unwrap()
            .into_array();

        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let expr: Expression = cast(get_item("a", root()), target.clone());
        let result = input.apply(&expr).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// The display form is `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let i64_non_null = DType::Primitive(PType::I64, Nullability::NonNullable);
        let field_cast = cast(get_item("value", root()), i64_non_null);
        assert_eq!(field_cast.to_string(), "cast($.value as i64)");

        let root_cast = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(root_cast.to_string(), "cast($ as bool?)");
    }
}