vortex_sequence/compute/compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::arrays::ConstantArray;
8use vortex_array::compute::CompareKernel;
9use vortex_array::compute::Operator;
10use vortex_array::validity::Validity;
11use vortex_buffer::BitBuffer;
12use vortex_dtype::DType;
13use vortex_dtype::NativePType;
14use vortex_dtype::Nullability;
15use vortex_dtype::match_each_integer_ptype;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_scalar::PValue;
19use vortex_scalar::Scalar;
20
21use crate::SequenceArray;
22use crate::array::SequenceVTable;
23
24impl CompareKernel for SequenceVTable {
25    fn compare(
26        &self,
27        lhs: &SequenceArray,
28        rhs: &dyn Array,
29        operator: Operator,
30    ) -> VortexResult<Option<ArrayRef>> {
31        if operator != Operator::Eq {
32            return Ok(None);
33        };
34
35        let Some(constant) = rhs.as_constant() else {
36            return Ok(None);
37        };
38
39        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
40        let set_idx = find_intersection_scalar(
41            lhs.base(),
42            lhs.multiplier(),
43            lhs.len(),
44            constant
45                .as_primitive()
46                .pvalue()
47                .vortex_expect("non-null constant"),
48        );
49
50        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
51        let validity = match nullability {
52            Nullability::NonNullable => Validity::NonNullable,
53            Nullability::Nullable => Validity::AllValid,
54        };
55
56        if let Some(set_idx) = set_idx {
57            let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
58            Ok(Some(
59                BoolArray::from_bit_buffer(buffer, validity).to_array(),
60            ))
61        } else {
62            Ok(Some(
63                ConstantArray::new(
64                    Scalar::new(DType::Bool(nullability), false.into()),
65                    lhs.len(),
66                )
67                .to_array(),
68            ))
69        }
70    }
71}
72
73pub(crate) fn find_intersection_scalar(
74    base: PValue,
75    multiplier: PValue,
76    len: usize,
77    intercept: PValue,
78) -> Option<usize> {
79    match_each_integer_ptype!(base.ptype(), |P| {
80        let intercept = intercept.cast::<P>();
81
82        let base = base.cast::<P>();
83        let multiplier = multiplier.cast::<P>();
84
85        find_intersection(base, multiplier, len, intercept)
86    })
87}
88
89fn find_intersection<P: NativePType>(
90    base: P,
91    multiplier: P,
92    len: usize,
93    intercept: P,
94) -> Option<usize> {
95    // Array is non-empty here.
96    let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
97
98    let end_element = base + (multiplier * count);
99
100    (intercept.is_ge(base)
101        && intercept.is_le(end_element)
102        && (intercept - base) % multiplier == P::zero())
103    .then(|| ((intercept - base) / multiplier).to_usize())
104    .flatten()
105}
106
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::Nullability::Nullable;

    use crate::SequenceArray;

    /// Compares `seq` for equality against a constant `value` of the same length and
    /// asserts that the resulting bits match `expected`.
    fn assert_eq_mask(seq: SequenceArray, value: i64, expected: Vec<bool>) {
        let constant = ConstantArray::new(value, seq.len());
        let mask = compare(seq.as_ref(), constant.as_ref(), Operator::Eq).unwrap();
        assert_eq!(
            mask.to_bool().bit_buffer(),
            BoolArray::from_iter(expected).bit_buffer(),
        )
    }

    #[test]
    fn test_compare_match() {
        // Sequence 2, 3, 4, 5: only index 2 holds 4.
        let seq = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_mask(seq, 4, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_match_scale() {
        // Sequence 2, 5, 8, 11: only index 2 holds 8.
        let seq = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();
        assert_eq_mask(seq, 8, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_no_match() {
        // Sequence 2, 3, 4, 5 never holds 1.
        let seq = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        assert_eq_mask(seq, 1, vec![false, false, false, false]);
    }
}