1use std::sync::Arc;
2
3use arrow_array::ArrayRef as ArrowArrayRef;
4use arrow_array::cast::AsArray;
5use arrow_schema::DataType;
6use vortex_dtype::DType;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8
9use crate::arrow::{FromArrowArray, IntoArrowArray};
10use crate::encoding::Encoding;
11use crate::{Array, ArrayRef};
12
/// Element-wise boolean operators understood by [`binary_boolean`].
///
/// All four operators are commutative, which lets the dispatcher swap operands freely.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BinaryOperator {
    /// Logical AND. On the Arrow fallback path this is `arrow_arith::boolean::and`,
    /// where a null on either side yields null.
    And,
    /// Logical AND with Kleene (three-valued) logic. On the Arrow fallback path this is
    /// `arrow_arith::boolean::and_kleene` (`false AND null` is `false`).
    AndKleene,
    /// Logical OR. On the Arrow fallback path this is `arrow_arith::boolean::or`,
    /// where a null on either side yields null.
    Or,
    /// Logical OR with Kleene (three-valued) logic. On the Arrow fallback path this is
    /// `arrow_arith::boolean::or_kleene` (`true OR null` is `true`).
    OrKleene,
}
23
24pub trait BinaryBooleanFn<A> {
25 fn binary_boolean(
26 &self,
27 array: A,
28 other: &dyn Array,
29 op: BinaryOperator,
30 ) -> VortexResult<Option<ArrayRef>>;
31}
32
33impl<E: Encoding> BinaryBooleanFn<&dyn Array> for E
34where
35 E: for<'a> BinaryBooleanFn<&'a E::Array>,
36{
37 fn binary_boolean(
38 &self,
39 lhs: &dyn Array,
40 rhs: &dyn Array,
41 op: BinaryOperator,
42 ) -> VortexResult<Option<ArrayRef>> {
43 let array_ref = lhs
44 .as_any()
45 .downcast_ref::<E::Array>()
46 .vortex_expect("Failed to downcast array");
47
48 BinaryBooleanFn::binary_boolean(self, array_ref, rhs, op)
49 }
50}
51
52pub fn and(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
56 binary_boolean(lhs, rhs, BinaryOperator::And)
57}
58
59pub fn and_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
61 binary_boolean(lhs, rhs, BinaryOperator::AndKleene)
62}
63
64pub fn or(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
68 binary_boolean(lhs, rhs, BinaryOperator::Or)
69}
70
71pub fn or_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
73 binary_boolean(lhs, rhs, BinaryOperator::OrKleene)
74}
75
76pub fn binary_boolean(
77 lhs: &dyn Array,
78 rhs: &dyn Array,
79 op: BinaryOperator,
80) -> VortexResult<ArrayRef> {
81 if lhs.len() != rhs.len() {
82 vortex_bail!(
83 "Boolean operations aren't supported on arrays of different lengths: {} and {}",
84 lhs.len(),
85 rhs.len()
86 )
87 }
88 if !lhs.dtype().is_boolean()
89 || !rhs.dtype().is_boolean()
90 || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
91 {
92 vortex_bail!(
93 "Boolean operations are only supported on boolean arrays: {} and {}",
94 lhs.dtype(),
95 rhs.dtype()
96 )
97 }
98
99 let rhs_is_constant = rhs.is_constant();
100
101 if lhs.is_constant() && !rhs_is_constant {
103 return binary_boolean(rhs, lhs, op);
104 }
105
106 if lhs.is_arrow() && (rhs.is_arrow() || rhs_is_constant) {
108 return arrow_boolean(lhs.to_array(), rhs.to_array(), op);
109 }
110
111 if let Some(result) = lhs
113 .vtable()
114 .binary_boolean_fn()
115 .and_then(|f| f.binary_boolean(lhs, rhs, op).transpose())
116 .transpose()?
117 {
118 assert_eq!(
119 result.len(),
120 lhs.len(),
121 "Boolean operation length mismatch {}",
122 lhs.encoding()
123 );
124 assert_eq!(
125 result.dtype(),
126 &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
127 "Boolean operation dtype mismatch {}",
128 lhs.encoding()
129 );
130 return Ok(result);
131 }
132
133 if let Some(result) = rhs
134 .vtable()
135 .binary_boolean_fn()
136 .and_then(|f| f.binary_boolean(rhs, lhs, op).transpose())
137 .transpose()?
138 {
139 assert_eq!(
140 result.len(),
141 lhs.len(),
142 "Boolean operation length mismatch {}",
143 rhs.encoding()
144 );
145 assert_eq!(
146 result.dtype(),
147 &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
148 "Boolean operation dtype mismatch {}",
149 rhs.encoding()
150 );
151 return Ok(result);
152 }
153
154 log::debug!(
155 "No boolean implementation found for LHS {}, RHS {}, and operator {:?} (or inverse)",
156 rhs.encoding(),
157 lhs.encoding(),
158 op,
159 );
160
161 arrow_boolean(lhs.to_array(), rhs.to_array(), op)
163}
164
165pub(crate) fn arrow_boolean(
170 lhs: ArrayRef,
171 rhs: ArrayRef,
172 operator: BinaryOperator,
173) -> VortexResult<ArrayRef> {
174 let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
175
176 let lhs = lhs.into_arrow(&DataType::Boolean)?.as_boolean().clone();
177 let rhs = rhs.into_arrow(&DataType::Boolean)?.as_boolean().clone();
178
179 let array = match operator {
180 BinaryOperator::And => arrow_arith::boolean::and(&lhs, &rhs)?,
181 BinaryOperator::AndKleene => arrow_arith::boolean::and_kleene(&lhs, &rhs)?,
182 BinaryOperator::Or => arrow_arith::boolean::or(&lhs, &rhs)?,
183 BinaryOperator::OrKleene => arrow_arith::boolean::or_kleene(&lhs, &rhs)?,
184 };
185
186 Ok(ArrayRef::from_arrow(
187 Arc::new(array) as ArrowArrayRef,
188 nullable,
189 ))
190}
191
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::arrays::BoolArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;

    /// Builds a `BoolArray` from `Option<bool>` items (all `Some`) and erases it to `ArrayRef`.
    fn bools(values: [bool; 4]) -> ArrayRef {
        BoolArray::from_iter(values.into_iter().map(Some)).into_array()
    }

    /// Canonicalizes `array` to a bool array and checks its first elements against `expected`.
    fn assert_bool_values(array: ArrayRef, expected: [bool; 4]) {
        let canonical = array.to_bool().unwrap().into_array();
        for (index, want) in expected.into_iter().enumerate() {
            let got = scalar_at(&canonical, index)
                .unwrap()
                .as_bool()
                .value()
                .unwrap();
            assert_eq!(got, want, "unexpected value at index {index}");
        }
    }

    #[rstest]
    #[case(bools([true, true, false, false]), bools([true, false, true, false]))]
    #[case(bools([true, false, true, false]), bools([true, true, false, false]))]
    fn test_or(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) {
        let result = or(&lhs, &rhs).unwrap();
        assert_bool_values(result, [true, true, true, false]);
    }

    #[rstest]
    #[case(bools([true, true, false, false]), bools([true, false, true, false]))]
    #[case(bools([true, false, true, false]), bools([true, true, false, false]))]
    fn test_and(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) {
        let result = and(&lhs, &rhs).unwrap();
        assert_bool_values(result, [true, false, false, false]);
    }
}