1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use vortex_error::{vortex_bail, VortexResult};

use crate::{Array, ArrayDType, IntoArrayVariant};

/// Logical AND capability that an array can provide.
///
/// The free function `and` in this module looks for this capability on its left operand first,
/// then on its right operand, before falling back to converting the left operand with
/// `into_bool`.
pub trait AndFn {
    /// Computes the logical AND of `self` and `array`.
    ///
    /// NOTE(review): the free function `and` checks for equal lengths and boolean dtypes before
    /// dispatching here; whether implementations must re-validate when called directly is not
    /// visible from this file — confirm.
    fn and(&self, array: &Array) -> VortexResult<Array>;
}

/// Logical OR capability that an array can provide.
///
/// The free function `or` in this module looks for this capability on its left operand first,
/// then on its right operand, before falling back to converting the left operand with
/// `into_bool`.
pub trait OrFn {
    /// Computes the logical OR of `self` and `array`.
    ///
    /// NOTE(review): the free function `or` checks for equal lengths and boolean dtypes before
    /// dispatching here; whether implementations must re-validate when called directly is not
    /// visible from this file — confirm.
    fn or(&self, array: &Array) -> VortexResult<Array>;
}

/// Computes the logical AND of two boolean arrays of equal length.
///
/// The implementation is chosen in this order:
/// 1. the [`AndFn`] exposed by `lhs`, if it has one;
/// 2. otherwise the [`AndFn`] exposed by `rhs`, called with the operands swapped;
/// 3. otherwise `lhs` is converted with `into_bool` and the AND is computed on the result.
///
/// # Errors
///
/// Fails when the arrays differ in length, when either dtype is not boolean, when converting
/// `lhs` with `into_bool` fails, or when the selected implementation itself returns an error.
pub fn and(lhs: &Array, rhs: &Array) -> VortexResult<Array> {
    if lhs.len() != rhs.len() {
        vortex_bail!("Boolean operations aren't supported on arrays of different lengths")
    }

    if !(lhs.dtype().is_boolean() && rhs.dtype().is_boolean()) {
        vortex_bail!("Boolean operations are only supported on boolean arrays")
    }

    // Ask the left operand for a specialised implementation first; the right operand is only
    // consulted (lazily, via `or_else`) when the left one has none.
    let specialised = lhs
        .with_dyn(|left| left.and().map(|and_fn| and_fn.and(rhs)))
        .or_else(|| rhs.with_dyn(|right| right.and().map(|and_fn| and_fn.and(lhs))));

    match specialised {
        Some(outcome) => outcome,
        None => {
            // Neither operand provides `AndFn`, so expand the left-hand side into a `BoolArray`,
            // which is known to implement it, and delegate to that implementation.
            let expanded_lhs = lhs.clone().into_bool()?;
            expanded_lhs.and(rhs)
        }
    }
}

/// Computes the logical OR of two boolean arrays of equal length.
///
/// The implementation is chosen in this order:
/// 1. the [`OrFn`] exposed by `lhs`, if it has one;
/// 2. otherwise the [`OrFn`] exposed by `rhs`, called with the operands swapped;
/// 3. otherwise `lhs` is converted with `into_bool` and the OR is computed on the result.
///
/// # Errors
///
/// Fails when the arrays differ in length, when either dtype is not boolean, when converting
/// `lhs` with `into_bool` fails, or when the selected implementation itself returns an error.
pub fn or(lhs: &Array, rhs: &Array) -> VortexResult<Array> {
    if lhs.len() != rhs.len() {
        vortex_bail!("Boolean operations aren't supported on arrays of different lengths")
    }

    if !(lhs.dtype().is_boolean() && rhs.dtype().is_boolean()) {
        vortex_bail!("Boolean operations are only supported on boolean arrays")
    }

    // Ask the left operand for a specialised implementation first; the right operand is only
    // consulted (lazily, via `or_else`) when the left one has none.
    let specialised = lhs
        .with_dyn(|left| left.or().map(|or_fn| or_fn.or(rhs)))
        .or_else(|| rhs.with_dyn(|right| right.or().map(|or_fn| or_fn.or(lhs))));

    match specialised {
        Some(outcome) => outcome,
        None => {
            // Neither operand provides `OrFn`, so expand the left-hand side into a `BoolArray`,
            // which is known to implement it, and delegate to that implementation.
            let expanded_lhs = lhs.clone().into_bool()?;
            expanded_lhs.or(rhs)
        }
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::array::bool::BoolArray;
    use crate::compute::unary::scalar_at;
    use crate::IntoArray;

    /// Builds a four-element boolean array in which every element is `Some(value)`.
    fn bool_array(values: [bool; 4]) -> Array {
        BoolArray::from_iter(values.into_iter().map(Some)).into_array()
    }

    /// Converts `array` with `into_bool` and reads every element back through `scalar_at`.
    ///
    /// Panics if the conversion fails, if an element cannot be read as a bool, or if an
    /// element is absent — all of which are test failures here.
    fn collect_bools(array: Array) -> Vec<bool> {
        let bools = array.into_bool().unwrap().into_array();
        (0..bools.len())
            .map(|idx| {
                scalar_at(&bools, idx)
                    .unwrap()
                    .value()
                    .as_bool()
                    .unwrap()
                    .unwrap()
            })
            .collect()
    }

    // The second case of each test swaps the operands of the first, so the expected result is
    // the same for both cases.

    #[rstest]
    #[case(
        bool_array([true, true, false, false]),
        bool_array([true, false, true, false])
    )]
    #[case(
        bool_array([true, false, true, false]),
        bool_array([true, true, false, false])
    )]
    fn test_or(#[case] lhs: Array, #[case] rhs: Array) {
        let result = or(&lhs, &rhs).unwrap();

        // Comparing whole vectors makes a failure show every actual value alongside the
        // expected truth table.
        assert_eq!(collect_bools(result), vec![true, true, true, false]);
    }

    #[rstest]
    #[case(
        bool_array([true, true, false, false]),
        bool_array([true, false, true, false])
    )]
    #[case(
        bool_array([true, false, true, false]),
        bool_array([true, true, false, false])
    )]
    fn test_and(#[case] lhs: Array, #[case] rhs: Array) {
        let result = and(&lhs, &rhs).unwrap();

        // Comparing whole vectors makes a failure show every actual value alongside the
        // expected truth table.
        assert_eq!(collect_bools(result), vec![true, false, false, false]);
    }
}