vortex_sequence/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{PValue, Scalar};
11
12use crate::SequenceArray;
13use crate::array::SequenceVTable;
14
15impl CompareKernel for SequenceVTable {
16    fn compare(
17        &self,
18        lhs: &SequenceArray,
19        rhs: &dyn Array,
20        operator: Operator,
21    ) -> VortexResult<Option<ArrayRef>> {
22        if operator != Operator::Eq {
23            return Ok(None);
24        };
25
26        let Some(constant) = rhs.as_constant() else {
27            return Ok(None);
28        };
29
30        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
31        let set_idx = find_intersection_scalar(
32            lhs.base(),
33            lhs.multiplier(),
34            lhs.len(),
35            constant
36                .as_primitive()
37                .pvalue()
38                .vortex_expect("non-null constant"),
39        );
40
41        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
42        let validity = match nullability {
43            Nullability::NonNullable => Validity::NonNullable,
44            Nullability::Nullable => Validity::AllValid,
45        };
46
47        if let Some(set_idx) = set_idx {
48            let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
49            Ok(Some(
50                BoolArray::from_bool_buffer(buffer, validity).to_array(),
51            ))
52        } else {
53            Ok(Some(
54                ConstantArray::new(
55                    Scalar::new(DType::Bool(nullability), false.into()),
56                    lhs.len(),
57                )
58                .to_array(),
59            ))
60        }
61    }
62}
63
64pub(crate) fn find_intersection_scalar(
65    base: PValue,
66    multiplier: PValue,
67    len: usize,
68    intercept: PValue,
69) -> Option<usize> {
70    match_each_integer_ptype!(base.ptype(), |P| {
71        let intercept = intercept.as_primitive::<P>();
72
73        let base = base.as_primitive::<P>();
74        let multiplier = multiplier.as_primitive::<P>();
75
76        find_intersection(base, multiplier, len, intercept)
77    })
78}
79
80fn find_intersection<P: NativePType>(
81    base: P,
82    multiplier: P,
83    len: usize,
84    intercept: P,
85) -> Option<usize> {
86    // Array is non-empty here.
87    let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
88
89    let end_element = base + (multiplier * count);
90
91    (intercept.is_ge(base)
92        && intercept.is_le(end_element)
93        && (intercept - base) % multiplier == P::zero())
94    .then(|| ((intercept - base) / multiplier).to_usize())
95    .flatten()
96}
97
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Sequence 2, 3, 4, 5 compared with 4: only the third position matches.
    #[test]
    fn test_compare_match() {
        let sequence = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(4i64, sequence.len());

        let actual = compare(sequence.as_ref(), needle.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter(vec![false, false, true, false]);

        assert_eq!(
            actual.to_bool().boolean_buffer(),
            expected.boolean_buffer(),
        );
    }

    /// Sequence 2, 5, 8, 11 (step 3, nullable) compared with 8: third position matches.
    #[test]
    fn test_compare_match_scale() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();
        let needle = ConstantArray::new(8i64, sequence.len());

        let actual = compare(sequence.as_ref(), needle.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter(vec![false, false, true, false]);

        assert_eq!(
            actual.to_bool().boolean_buffer(),
            expected.boolean_buffer(),
        );
    }

    /// Sequence 2, 3, 4, 5 compared with 1: the constant lies below the sequence.
    #[test]
    fn test_compare_no_match() {
        let sequence = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(1i64, sequence.len());

        let actual = compare(sequence.as_ref(), needle.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter(vec![false, false, false, false]);

        assert_eq!(
            actual.to_bool().boolean_buffer(),
            expected.boolean_buffer(),
        );
    }
}