vortex_array/arrays/bool/compute/
cast.rs1use num_traits::One;
5use num_traits::Zero;
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::arrays::Bool;
14use crate::arrays::BoolArray;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::bool::BoolArrayExt;
17use crate::dtype::DType;
18use crate::match_each_native_ptype;
19use crate::scalar_fn::fns::cast::CastKernel;
20use crate::scalar_fn::fns::cast::CastReduce;
21
22impl CastReduce for Bool {
23 fn cast(array: ArrayView<'_, Bool>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24 if !dtype.is_boolean() {
25 return Ok(None);
26 }
27
28 let Some(new_validity) = array
29 .validity()?
30 .trivially_cast_nullability(dtype.nullability(), array.len())?
31 else {
32 return Ok(None);
33 };
34 Ok(Some(
35 BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
36 ))
37 }
38}
39
40impl CastKernel for Bool {
41 fn cast(
42 array: ArrayView<'_, Bool>,
43 dtype: &DType,
44 ctx: &mut ExecutionCtx,
45 ) -> VortexResult<Option<ArrayRef>> {
46 if dtype.is_boolean() {
47 let new_validity =
48 array
49 .validity()?
50 .cast_nullability(dtype.nullability(), array.len(), ctx)?;
51 return Ok(Some(
52 BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
53 ));
54 }
55
56 let DType::Primitive(new_ptype, new_nullability) = dtype else {
57 return Ok(None);
58 };
59
60 let new_validity =
61 array
62 .validity()?
63 .cast_nullability(*new_nullability, array.len(), ctx)?;
64
65 let bits = array.to_bit_buffer();
66 let len = bits.len();
67
68 Ok(Some(match_each_native_ptype!(*new_ptype, |T| {
69 let (one, zero) = (<T as One>::one(), <T as Zero>::zero());
70 let mut buffer = BufferMut::<T>::with_capacity(len);
71 buffer.extend(bits.iter().map(|v| if v { one } else { zero }));
72 PrimitiveArray::new(buffer.freeze(), new_validity).into_array()
73 })))
74 }
75}
76
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    // Session shared by the tests that need an execution context.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// An array of all-`Some` booleans casts to non-nullable bool and reports that dtype.
    #[test]
    fn try_cast_bool_success() {
        let input = BoolArray::from_iter(vec![Some(true), Some(false), Some(true)]).into_array();

        let outcome = input.cast(DType::Bool(Nullability::NonNullable));
        assert!(outcome.is_ok());

        let casted = outcome.unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::NonNullable));
    }

    /// An array containing a `None` must fail to become non-nullable bool, either
    /// when the cast is requested or when the result is executed to canonical form.
    #[test]
    fn try_cast_bool_fail() {
        let input = BoolArray::from_iter(vec![Some(true), Some(false), None]).into_array();
        let mut ctx = SESSION.create_execution_ctx();

        let requested = input.cast(DType::Bool(Nullability::NonNullable));
        let result = requested
            .and_then(|array| array.execute::<Canonical>(&mut ctx))
            .map(|canonical| canonical.into_array());

        assert!(result.is_err(), "Expected error, got: {result:?}");
    }

    /// Runs the shared cast conformance suite over several boolean arrays.
    #[rstest]
    #[case(BoolArray::from_iter(vec![true, false, true, true, false]))]
    #[case(BoolArray::from_iter(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(BoolArray::from_iter(vec![true]))]
    #[case(BoolArray::from_iter(vec![false, false]))]
    fn test_cast_bool_conformance(#[case] array: BoolArray) {
        let array_ref = array.into_array();
        test_cast_conformance(&array_ref);
    }

    /// Casting to each listed primitive type executes successfully and keeps the length.
    #[rstest]
    #[case(PType::I8)]
    #[case(PType::I32)]
    #[case(PType::I64)]
    #[case(PType::U8)]
    #[case(PType::U64)]
    #[case(PType::F32)]
    #[case(PType::F64)]
    fn cast_bool_to_primitive(#[case] target: PType) {
        let mut ctx = SESSION.create_execution_ctx();
        let input = BoolArray::from_iter(vec![true, false, true]).into_array();

        let target_dtype = DType::Primitive(target, Nullability::NonNullable);
        let lazy_cast = input.cast(target_dtype).unwrap();
        let materialized = lazy_cast.execute::<Canonical>(&mut ctx).unwrap().into_array();

        assert_eq!(materialized.len(), 3);
    }
}