Skip to main content

vortex_array/arrays/bool/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::Bool;
11use crate::arrays::BoolArray;
12use crate::arrays::bool::BoolArrayExt;
13use crate::dtype::DType;
14use crate::scalar_fn::fns::cast::CastKernel;
15use crate::scalar_fn::fns::cast::CastReduce;
16
17impl CastReduce for Bool {
18    fn cast(array: ArrayView<'_, Bool>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        if !dtype.is_boolean() {
20            return Ok(None);
21        }
22
23        let Some(new_validity) = array
24            .validity()?
25            .trivial_cast_nullability(dtype.nullability(), array.len())?
26        else {
27            return Ok(None);
28        };
29        Ok(Some(
30            BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
31        ))
32    }
33}
34
35impl CastKernel for Bool {
36    fn cast(
37        array: ArrayView<'_, Bool>,
38        dtype: &DType,
39        ctx: &mut ExecutionCtx,
40    ) -> VortexResult<Option<ArrayRef>> {
41        if !dtype.is_boolean() {
42            return Ok(None);
43        }
44
45        let new_validity =
46            array
47                .validity()?
48                .cast_nullability(dtype.nullability(), array.len(), ctx)?;
49        Ok(Some(
50            BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
51        ))
52    }
53}
54
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::session::ArraySession;

    /// Shared session used to create execution contexts for tests that force a cast to run.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    #[test]
    fn try_cast_bool_success() {
        // A nullable array with no nulls may be cast to non-nullable. Execute the cast (as the
        // failure test does) so the validity conversion is exercised, not just the lazily
        // computed result dtype.
        let bool = BoolArray::from_iter(vec![Some(true), Some(false), Some(true)]);
        let mut ctx = SESSION.create_execution_ctx();

        let casted = bool
            .into_array()
            .cast(DType::Bool(Nullability::NonNullable))
            .expect("casting an all-valid nullable bool array to non-nullable must succeed");
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::NonNullable));

        let executed = casted
            .execute::<Canonical>(&mut ctx)
            .expect("executing a cast of an all-valid array to non-nullable must succeed")
            .into_array();
        assert_eq!(executed.dtype(), &DType::Bool(Nullability::NonNullable));
        assert_eq!(executed.len(), 3);
    }

    #[test]
    fn try_cast_bool_fail() {
        // When the validity array's min stat is not cached, the reduce rule defers and the
        // failure surfaces during execution via the kernel (cast_nullability -> compute_min).
        let bool = BoolArray::from_iter(vec![Some(true), Some(false), None]);
        let mut ctx = SESSION.create_execution_ctx();
        let result = bool
            .into_array()
            .cast(DType::Bool(Nullability::NonNullable))
            .and_then(|a| a.execute::<Canonical>(&mut ctx).map(|c| c.into_array()));
        assert!(result.is_err(), "Expected error, got: {result:?}");
    }

    #[rstest]
    #[case(BoolArray::from_iter(vec![true, false, true, true, false]))]
    #[case(BoolArray::from_iter(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(BoolArray::from_iter(vec![true]))]
    #[case(BoolArray::from_iter(vec![false, false]))]
    fn test_cast_bool_conformance(#[case] array: BoolArray) {
        test_cast_conformance(&array.into_array());
    }
}