vortex_sequence/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{PValue, Scalar};
11
12use crate::SequenceArray;
13use crate::array::SequenceVTable;
14
15impl CompareKernel for SequenceVTable {
16    fn compare(
17        &self,
18        lhs: &SequenceArray,
19        rhs: &dyn Array,
20        operator: Operator,
21    ) -> VortexResult<Option<ArrayRef>> {
22        if operator != Operator::Eq {
23            return Ok(None);
24        };
25
26        let Some(constant) = rhs.as_constant() else {
27            return Ok(None);
28        };
29
30        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
31        let set_idx = find_intersection_scalar(
32            lhs.base(),
33            lhs.multiplier(),
34            lhs.len(),
35            constant
36                .as_primitive()
37                .pvalue()
38                .vortex_expect("non-null constant"),
39        );
40
41        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
42        let validity = match nullability {
43            Nullability::NonNullable => Validity::NonNullable,
44            Nullability::Nullable => Validity::AllValid,
45        };
46
47        if let Some(set_idx) = set_idx {
48            let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
49            Ok(Some(BoolArray::new(buffer, validity).to_array()))
50        } else {
51            Ok(Some(
52                ConstantArray::new(
53                    Scalar::new(DType::Bool(nullability), false.into()),
54                    lhs.len(),
55                )
56                .to_array(),
57            ))
58        }
59    }
60}
61
62pub(crate) fn find_intersection_scalar(
63    base: PValue,
64    multiplier: PValue,
65    len: usize,
66    intercept: PValue,
67) -> Option<usize> {
68    match_each_integer_ptype!(base.ptype(), |P| {
69        let intercept = intercept.as_primitive::<P>();
70
71        let base = base.as_primitive::<P>();
72        let multiplier = multiplier.as_primitive::<P>();
73
74        find_intersection(base, multiplier, len, intercept)
75    })
76}
77
78fn find_intersection<P: NativePType>(
79    base: P,
80    multiplier: P,
81    len: usize,
82    intercept: P,
83) -> Option<usize> {
84    // Array is non-empty here.
85    let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
86
87    let end_element = base + (multiplier * count);
88
89    (intercept.is_ge(base)
90        && intercept.is_le(end_element)
91        && (intercept - base) % multiplier == P::zero())
92    .then(|| ((intercept - base) / multiplier).to_usize())
93    .flatten()
94}
95
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Evaluates `sequence == needle` (with `needle` broadcast to the sequence's length)
    /// and asserts that the resulting boolean buffer equals `expected`.
    fn assert_eq_mask(sequence: SequenceArray, needle: i64, expected: Vec<bool>) {
        let constant = ConstantArray::new(needle, sequence.len());

        let outcome = compare(sequence.as_ref(), constant.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            outcome.to_bool().unwrap().boolean_buffer(),
            BoolArray::from_iter(expected).boolean_buffer(),
        );
    }

    /// Unit step: the sequence 2, 3, 4, 5 equals 4 only at index 2.
    #[test]
    fn test_compare_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();

        assert_eq_mask(lhs, 4i64, vec![false, false, true, false]);
    }

    /// Step of three on a nullable sequence: 2, 5, 8, 11 equals 8 only at index 2.
    #[test]
    fn test_compare_match_scale() {
        let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();

        assert_eq_mask(lhs, 8i64, vec![false, false, true, false]);
    }

    /// A constant below the first element matches nowhere.
    #[test]
    fn test_compare_no_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();

        assert_eq_mask(lhs, 1i64, vec![false, false, false, false]);
    }
}