vortex_sequence/compute/compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{BoolArray, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_buffer::BitBuffer;
9use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
10use vortex_error::{VortexExpect, VortexResult};
11use vortex_scalar::{PValue, Scalar};
12
13use crate::SequenceArray;
14use crate::array::SequenceVTable;
15
16impl CompareKernel for SequenceVTable {
17    fn compare(
18        &self,
19        lhs: &SequenceArray,
20        rhs: &dyn Array,
21        operator: Operator,
22    ) -> VortexResult<Option<ArrayRef>> {
23        if operator != Operator::Eq {
24            return Ok(None);
25        };
26
27        let Some(constant) = rhs.as_constant() else {
28            return Ok(None);
29        };
30
31        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
32        let set_idx = find_intersection_scalar(
33            lhs.base(),
34            lhs.multiplier(),
35            lhs.len(),
36            constant
37                .as_primitive()
38                .pvalue()
39                .vortex_expect("non-null constant"),
40        );
41
42        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
43        let validity = match nullability {
44            Nullability::NonNullable => Validity::NonNullable,
45            Nullability::Nullable => Validity::AllValid,
46        };
47
48        if let Some(set_idx) = set_idx {
49            let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
50            Ok(Some(
51                BoolArray::from_bit_buffer(buffer, validity).to_array(),
52            ))
53        } else {
54            Ok(Some(
55                ConstantArray::new(
56                    Scalar::new(DType::Bool(nullability), false.into()),
57                    lhs.len(),
58                )
59                .to_array(),
60            ))
61        }
62    }
63}
64
65pub(crate) fn find_intersection_scalar(
66    base: PValue,
67    multiplier: PValue,
68    len: usize,
69    intercept: PValue,
70) -> Option<usize> {
71    match_each_integer_ptype!(base.ptype(), |P| {
72        let intercept = intercept.cast::<P>();
73
74        let base = base.cast::<P>();
75        let multiplier = multiplier.cast::<P>();
76
77        find_intersection(base, multiplier, len, intercept)
78    })
79}
80
81fn find_intersection<P: NativePType>(
82    base: P,
83    multiplier: P,
84    len: usize,
85    intercept: P,
86) -> Option<usize> {
87    // Array is non-empty here.
88    let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
89
90    let end_element = base + (multiplier * count);
91
92    (intercept.is_ge(base)
93        && intercept.is_le(end_element)
94        && (intercept - base) % multiplier == P::zero())
95    .then(|| ((intercept - base) / multiplier).to_usize())
96    .flatten()
97}
98
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Compares `lhs` for equality against the constant `value` and asserts the result bits.
    fn assert_eq_bits(lhs: SequenceArray, value: i64, expected: &[bool]) {
        let rhs = ConstantArray::new(value, lhs.len());

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            result.to_bool().bit_buffer(),
            BoolArray::from_iter(expected.iter().copied()).bit_buffer(),
        )
    }

    #[test]
    fn test_compare_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 4, &[false, false, true, false]);
    }

    #[test]
    fn test_compare_match_first() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 2, &[true, false, false, false]);
    }

    #[test]
    fn test_compare_match_last() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 5, &[false, false, false, true]);
    }

    #[test]
    fn test_compare_match_scale() {
        let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();

        let rhs = ConstantArray::new(8i64, lhs.len());

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        // A nullable sequence must produce a nullable boolean result.
        assert!(result.dtype().is_nullable());
        assert_eq!(
            result.to_bool().bit_buffer(),
            BoolArray::from_iter(vec![false, false, true, false]).bit_buffer(),
        )
    }

    #[test]
    fn test_compare_no_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 1, &[false, false, false, false]);
    }

    #[test]
    fn test_compare_above_range() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 6, &[false, false, false, false]);
    }

    #[test]
    fn test_compare_between_steps() {
        // Sequence is [2, 5, 8, 11]; 7 lies in range but is not on a step.
        let lhs = SequenceArray::typed_new(2i64, 3, NonNullable, 4).unwrap();
        assert_eq_bits(lhs, 7, &[false, false, false, false]);
    }
}