Skip to main content

vortex_array/arrays/bool/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::One;
5use num_traits::Zero;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::arrays::Bool;
14use crate::arrays::BoolArray;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::bool::BoolArrayExt;
17use crate::dtype::DType;
18use crate::match_each_native_ptype;
19use crate::scalar_fn::fns::cast::CastKernel;
20use crate::scalar_fn::fns::cast::CastReduce;
21
22impl CastReduce for Bool {
23    fn cast(array: ArrayView<'_, Bool>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24        if !dtype.is_boolean() {
25            return Ok(None);
26        }
27
28        let Some(new_validity) = array
29            .validity()?
30            .trivially_cast_nullability(dtype.nullability(), array.len())?
31        else {
32            return Ok(None);
33        };
34        Ok(Some(
35            BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
36        ))
37    }
38}
39
40impl CastKernel for Bool {
41    fn cast(
42        array: ArrayView<'_, Bool>,
43        dtype: &DType,
44        ctx: &mut ExecutionCtx,
45    ) -> VortexResult<Option<ArrayRef>> {
46        if dtype.is_boolean() {
47            let new_validity =
48                array
49                    .validity()?
50                    .cast_nullability(dtype.nullability(), array.len(), ctx)?;
51            return Ok(Some(
52                BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
53            ));
54        }
55
56        let DType::Primitive(new_ptype, new_nullability) = dtype else {
57            return Ok(None);
58        };
59
60        let new_validity =
61            array
62                .validity()?
63                .cast_nullability(*new_nullability, array.len(), ctx)?;
64
65        let bits = array.to_bit_buffer();
66        let len = bits.len();
67
68        Ok(Some(match_each_native_ptype!(*new_ptype, |T| {
69            let (one, zero) = (<T as One>::one(), <T as Zero>::zero());
70            let mut buffer = BufferMut::<T>::with_capacity(len);
71            buffer.extend(bits.iter().map(|v| if v { one } else { zero }));
72            PrimitiveArray::new(buffer.freeze(), new_validity).into_array()
73        })))
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    #[test]
    fn try_cast_bool_success() {
        let bool = BoolArray::from_iter(vec![Some(true), Some(false), Some(true)]);

        // `expect` surfaces the underlying error on failure, unlike `assert!(res.is_ok())`.
        let res = bool
            .into_array()
            .cast(DType::Bool(Nullability::NonNullable))
            .expect("casting an all-valid bool array to non-nullable bool should succeed");
        assert_eq!(res.dtype(), &DType::Bool(Nullability::NonNullable));
    }

    #[test]
    fn try_cast_bool_fail() {
        // When the validity array's min stat is not cached, the reduce rule defers and the
        // failure surfaces during execution via the kernel (cast_nullability -> compute_min).
        let bool = BoolArray::from_iter(vec![Some(true), Some(false), None]);
        let mut ctx = SESSION.create_execution_ctx();
        let result = bool
            .into_array()
            .cast(DType::Bool(Nullability::NonNullable))
            .and_then(|a| a.execute::<Canonical>(&mut ctx).map(|c| c.into_array()));
        assert!(result.is_err(), "Expected error, got: {result:?}");
    }

    #[rstest]
    #[case(BoolArray::from_iter(vec![true, false, true, true, false]))]
    #[case(BoolArray::from_iter(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(BoolArray::from_iter(vec![true]))]
    #[case(BoolArray::from_iter(vec![false, false]))]
    fn test_cast_bool_conformance(#[case] array: BoolArray) {
        test_cast_conformance(&array.into_array());
    }

    #[rstest]
    #[case(PType::I8)]
    #[case(PType::I16)]
    #[case(PType::I32)]
    #[case(PType::I64)]
    #[case(PType::U8)]
    #[case(PType::U16)]
    #[case(PType::U32)]
    #[case(PType::U64)]
    #[case(PType::F32)]
    #[case(PType::F64)]
    fn cast_bool_to_primitive(#[case] target: PType) {
        let mut ctx = SESSION.create_execution_ctx();
        let target_dtype = DType::Primitive(target, Nullability::NonNullable);
        let arr = BoolArray::from_iter(vec![true, false, true]).into_array();
        let out = arr.cast(target_dtype.clone()).unwrap();
        let out = out.execute::<Canonical>(&mut ctx).unwrap().into_array();
        assert_eq!(out.len(), 3);
        // Checking only the length would not catch a kernel that produced the wrong type.
        assert_eq!(out.dtype(), &target_dtype);
    }

    #[rstest]
    #[case(PType::I32)]
    #[case(PType::U8)]
    #[case(PType::F64)]
    fn cast_nullable_bool_to_nullable_primitive(#[case] target: PType) {
        let mut ctx = SESSION.create_execution_ctx();
        let target_dtype = DType::Primitive(target, Nullability::Nullable);
        let arr = BoolArray::from_iter(vec![Some(true), None, Some(false)]).into_array();
        let out = arr.cast(target_dtype.clone()).unwrap();
        let out = out.execute::<Canonical>(&mut ctx).unwrap().into_array();
        assert_eq!(out.len(), 3);
        assert_eq!(out.dtype(), &target_dtype);
    }

    #[test]
    fn cast_empty_bool_to_primitive() {
        let mut ctx = SESSION.create_execution_ctx();
        let target_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let arr = BoolArray::from_iter(Vec::<bool>::new()).into_array();
        let out = arr.cast(target_dtype.clone()).unwrap();
        let out = out.execute::<Canonical>(&mut ctx).unwrap().into_array();
        assert_eq!(out.len(), 0);
        assert_eq!(out.dtype(), &target_dtype);
    }
}