vortex_sequence/compute/
compare.rs

1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{CompareKernel, Operator};
3use vortex_array::validity::Validity;
4use vortex_array::{Array, ArrayRef};
5use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult};
7use vortex_scalar::{PValue, Scalar};
8
9use crate::SequenceArray;
10use crate::array::SequenceVTable;
11
12impl CompareKernel for SequenceVTable {
13    fn compare(
14        &self,
15        lhs: &SequenceArray,
16        rhs: &dyn Array,
17        operator: Operator,
18    ) -> VortexResult<Option<ArrayRef>> {
19        if operator != Operator::Eq {
20            return Ok(None);
21        };
22
23        let Some(constant) = rhs.as_constant() else {
24            return Ok(None);
25        };
26
27        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
28        let set_idx = find_intersection_scalar(
29            lhs.base(),
30            lhs.multiplier(),
31            lhs.len(),
32            constant
33                .as_primitive()
34                .pvalue()
35                .vortex_expect("non-null constant"),
36        );
37
38        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
39        let validity = match nullability {
40            Nullability::NonNullable => Validity::NonNullable,
41            Nullability::Nullable => Validity::AllValid,
42        };
43
44        if let Some(set_idx) = set_idx {
45            let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
46            Ok(Some(BoolArray::new(buffer, validity).to_array()))
47        } else {
48            Ok(Some(
49                ConstantArray::new(
50                    Scalar::new(DType::Bool(nullability), false.into()),
51                    lhs.len(),
52                )
53                .to_array(),
54            ))
55        }
56    }
57}
58
59pub(crate) fn find_intersection_scalar(
60    base: PValue,
61    multiplier: PValue,
62    len: usize,
63    intercept: PValue,
64) -> Option<usize> {
65    match_each_integer_ptype!(base.ptype(), |P| {
66        let intercept = intercept
67            .as_primitive()
68            .vortex_expect("constant pvalue matching already validated");
69
70        let base = base
71            .as_primitive::<P>()
72            .vortex_expect("base pvalue matching already validated");
73        let multiplier = multiplier
74            .as_primitive::<P>()
75            .vortex_expect("multiplier pvalue matching already validated");
76
77        find_intersection(base, multiplier, len, intercept)
78    })
79}
80
81fn find_intersection<P: NativePType>(
82    base: P,
83    multiplier: P,
84    len: usize,
85    intercept: P,
86) -> Option<usize> {
87    // Array is non-empty here.
88    let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
89
90    let end_element = base + (multiplier * count);
91
92    (intercept.is_ge(base)
93        && intercept.is_le(end_element)
94        && (intercept - base) % multiplier == P::zero())
95    .then(|| ((intercept - base) / multiplier).to_usize())
96    .flatten()
97}
98
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};

    use crate::SequenceArray;

    /// Compares `seq` for equality against a constant `value` broadcast to the
    /// sequence's length, and asserts the resulting boolean bitmap equals `expected`.
    fn assert_eq_mask(seq: &SequenceArray, value: i64, expected: Vec<bool>) {
        let constant = ConstantArray::new(value, seq.len());
        let mask = compare(seq.as_ref(), constant.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            mask.to_bool().unwrap().boolean_buffer(),
            BoolArray::from_iter(expected).boolean_buffer(),
        )
    }

    #[test]
    fn test_compare_match() {
        // Sequence: 2, 3, 4, 5
        let seq = SequenceArray::typed_new(2i64, 1, 4).unwrap();
        assert_eq_mask(&seq, 4i64, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_match_scale() {
        // Sequence: 2, 5, 8, 11
        let seq = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        assert_eq_mask(&seq, 8i64, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_no_match() {
        // Sequence: 2, 3, 4, 5 — the value 1 precedes the first element.
        let seq = SequenceArray::typed_new(2i64, 1, 4).unwrap();
        assert_eq_mask(&seq, 1i64, vec![false, false, false, false]);
    }
}