vortex_array/compute/
boolean.rs

1use std::sync::Arc;
2
3use arrow_array::ArrayRef as ArrowArrayRef;
4use arrow_array::cast::AsArray;
5use arrow_schema::DataType;
6use vortex_dtype::DType;
7use vortex_error::{VortexExpect, VortexResult, vortex_bail};
8
9use crate::arrow::{FromArrowArray, IntoArrowArray};
10use crate::encoding::Encoding;
11use crate::{Array, ArrayRef};
12
/// The binary logical operators supported by [`binary_boolean`].
///
/// Each operator comes in two flavors that differ only in how nulls are treated: the plain
/// variants propagate nulls Arrow-style (a null in either operand yields null), while the
/// `*Kleene` variants follow three-valued (Kleene) logic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// Logical AND; the result is null wherever either operand is null.
    And,
    /// Kleene logical AND; `false AND null` is `false`, `true AND null` is null.
    AndKleene,
    /// Logical OR; the result is null wherever either operand is null.
    Or,
    /// Kleene logical OR; `true OR null` is `true`, `false OR null` is null.
    OrKleene,
    // Not yet supported: AndNot, AndNotKleene, Xor.
}
23
24pub trait BinaryBooleanFn<A> {
25    fn binary_boolean(
26        &self,
27        array: A,
28        other: &dyn Array,
29        op: BinaryOperator,
30    ) -> VortexResult<Option<ArrayRef>>;
31}
32
33impl<E: Encoding> BinaryBooleanFn<&dyn Array> for E
34where
35    E: for<'a> BinaryBooleanFn<&'a E::Array>,
36{
37    fn binary_boolean(
38        &self,
39        lhs: &dyn Array,
40        rhs: &dyn Array,
41        op: BinaryOperator,
42    ) -> VortexResult<Option<ArrayRef>> {
43        let array_ref = lhs
44            .as_any()
45            .downcast_ref::<E::Array>()
46            .vortex_expect("Failed to downcast array");
47
48        BinaryBooleanFn::binary_boolean(self, array_ref, rhs, op)
49    }
50}
51
52/// Point-wise logical _and_ between two Boolean arrays.
53///
54/// This method uses Arrow-style null propagation rather than the Kleene logic semantics.
55pub fn and(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
56    binary_boolean(lhs, rhs, BinaryOperator::And)
57}
58
59/// Point-wise Kleene logical _and_ between two Boolean arrays.
60pub fn and_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
61    binary_boolean(lhs, rhs, BinaryOperator::AndKleene)
62}
63
64/// Point-wise logical _or_ between two Boolean arrays.
65///
66/// This method uses Arrow-style null propagation rather than the Kleene logic semantics.
67pub fn or(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
68    binary_boolean(lhs, rhs, BinaryOperator::Or)
69}
70
71/// Point-wise Kleene logical _or_ between two Boolean arrays.
72pub fn or_kleene(lhs: &dyn Array, rhs: &dyn Array) -> VortexResult<ArrayRef> {
73    binary_boolean(lhs, rhs, BinaryOperator::OrKleene)
74}
75
76pub fn binary_boolean(
77    lhs: &dyn Array,
78    rhs: &dyn Array,
79    op: BinaryOperator,
80) -> VortexResult<ArrayRef> {
81    if lhs.len() != rhs.len() {
82        vortex_bail!(
83            "Boolean operations aren't supported on arrays of different lengths: {} and {}",
84            lhs.len(),
85            rhs.len()
86        )
87    }
88    if !lhs.dtype().is_boolean()
89        || !rhs.dtype().is_boolean()
90        || !lhs.dtype().eq_ignore_nullability(rhs.dtype())
91    {
92        vortex_bail!(
93            "Boolean operations are only supported on boolean arrays: {} and {}",
94            lhs.dtype(),
95            rhs.dtype()
96        )
97    }
98
99    let rhs_is_constant = rhs.is_constant();
100
101    // If LHS is constant, then we make sure it's on the RHS.
102    if lhs.is_constant() && !rhs_is_constant {
103        return binary_boolean(rhs, lhs, op);
104    }
105
106    // If the RHS is constant and the LHS is Arrow, we can't do any better than arrow_compare.
107    if lhs.is_arrow() && (rhs.is_arrow() || rhs_is_constant) {
108        return arrow_boolean(lhs.to_array(), rhs.to_array(), op);
109    }
110
111    // Check if either LHS or RHS supports the operation directly.
112    if let Some(result) = lhs
113        .vtable()
114        .binary_boolean_fn()
115        .and_then(|f| f.binary_boolean(lhs, rhs, op).transpose())
116        .transpose()?
117    {
118        assert_eq!(
119            result.len(),
120            lhs.len(),
121            "Boolean operation length mismatch {}",
122            lhs.encoding()
123        );
124        assert_eq!(
125            result.dtype(),
126            &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
127            "Boolean operation dtype mismatch {}",
128            lhs.encoding()
129        );
130        return Ok(result);
131    }
132
133    if let Some(result) = rhs
134        .vtable()
135        .binary_boolean_fn()
136        .and_then(|f| f.binary_boolean(rhs, lhs, op).transpose())
137        .transpose()?
138    {
139        assert_eq!(
140            result.len(),
141            lhs.len(),
142            "Boolean operation length mismatch {}",
143            rhs.encoding()
144        );
145        assert_eq!(
146            result.dtype(),
147            &DType::Bool((lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into()),
148            "Boolean operation dtype mismatch {}",
149            rhs.encoding()
150        );
151        return Ok(result);
152    }
153
154    log::debug!(
155        "No boolean implementation found for LHS {}, RHS {}, and operator {:?} (or inverse)",
156        rhs.encoding(),
157        lhs.encoding(),
158        op,
159    );
160
161    // If neither side implements the trait, then we delegate to Arrow compute.
162    arrow_boolean(lhs.to_array(), rhs.to_array(), op)
163}
164
165/// Implementation of `BinaryBooleanFn` using the Arrow crate.
166///
167/// Note that other encodings should handle a constant RHS value, so we can assume here that
168/// the RHS is not constant and expand to a full array.
169pub(crate) fn arrow_boolean(
170    lhs: ArrayRef,
171    rhs: ArrayRef,
172    operator: BinaryOperator,
173) -> VortexResult<ArrayRef> {
174    let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable();
175
176    let lhs = lhs.into_arrow(&DataType::Boolean)?.as_boolean().clone();
177    let rhs = rhs.into_arrow(&DataType::Boolean)?.as_boolean().clone();
178
179    let array = match operator {
180        BinaryOperator::And => arrow_arith::boolean::and(&lhs, &rhs)?,
181        BinaryOperator::AndKleene => arrow_arith::boolean::and_kleene(&lhs, &rhs)?,
182        BinaryOperator::Or => arrow_arith::boolean::or(&lhs, &rhs)?,
183        BinaryOperator::OrKleene => arrow_arith::boolean::or_kleene(&lhs, &rhs)?,
184    };
185
186    Ok(ArrayRef::from_arrow(
187        Arc::new(array) as ArrowArrayRef,
188        nullable,
189    ))
190}
191
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::arrays::BoolArray;
    use crate::canonical::ToCanonical;
    use crate::compute::scalar_at;

    /// Builds a (nullable) Boolean array from optional values.
    fn bools(values: &[Option<bool>]) -> ArrayRef {
        BoolArray::from_iter(values.iter().copied()).into_array()
    }

    /// Canonicalizes `array` and reads every element back as an optional bool.
    fn read_bools(array: ArrayRef) -> Vec<Option<bool>> {
        let array = array.to_bool().unwrap().into_array();
        (0..array.len())
            .map(|idx| scalar_at(&array, idx).unwrap().as_bool().value())
            .collect()
    }

    #[rstest]
    #[case(
        bools(&[Some(true), Some(true), Some(false), Some(false)]),
        bools(&[Some(true), Some(false), Some(true), Some(false)])
    )]
    #[case(
        bools(&[Some(true), Some(false), Some(true), Some(false)]),
        bools(&[Some(true), Some(true), Some(false), Some(false)])
    )]
    fn test_or(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) {
        let r = or(&lhs, &rhs).unwrap();
        assert_eq!(
            read_bools(r),
            vec![Some(true), Some(true), Some(true), Some(false)]
        );
    }

    #[rstest]
    #[case(
        bools(&[Some(true), Some(true), Some(false), Some(false)]),
        bools(&[Some(true), Some(false), Some(true), Some(false)])
    )]
    #[case(
        bools(&[Some(true), Some(false), Some(true), Some(false)]),
        bools(&[Some(true), Some(true), Some(false), Some(false)])
    )]
    fn test_and(#[case] lhs: ArrayRef, #[case] rhs: ArrayRef) {
        let r = and(&lhs, &rhs).unwrap();
        assert_eq!(
            read_bools(r),
            vec![Some(true), Some(false), Some(false), Some(false)]
        );
    }

    #[test]
    fn test_and_or_propagate_nulls() {
        let lhs = bools(&[Some(true), Some(false), None, Some(true)]);
        let rhs = bools(&[None, None, Some(false), Some(true)]);
        assert_eq!(
            read_bools(and(&lhs, &rhs).unwrap()),
            vec![None, None, None, Some(true)]
        );
        assert_eq!(
            read_bools(or(&lhs, &rhs).unwrap()),
            vec![None, None, None, Some(true)]
        );
    }

    #[test]
    fn test_and_kleene_nulls() {
        let lhs = bools(&[Some(true), Some(false), None, None]);
        let rhs = bools(&[None, None, Some(false), None]);
        assert_eq!(
            read_bools(and_kleene(&lhs, &rhs).unwrap()),
            vec![None, Some(false), Some(false), None]
        );
    }

    #[test]
    fn test_or_kleene_nulls() {
        let lhs = bools(&[Some(true), Some(false), None, None]);
        let rhs = bools(&[None, None, Some(true), None]);
        assert_eq!(
            read_bools(or_kleene(&lhs, &rhs).unwrap()),
            vec![Some(true), None, Some(true), None]
        );
    }

    #[test]
    fn test_length_mismatch_is_error() {
        let lhs = bools(&[Some(true), Some(false)]);
        let rhs = bools(&[Some(true)]);
        assert!(and(&lhs, &rhs).is_err());
        assert!(or_kleene(&lhs, &rhs).is_err());
    }
}