Skip to main content

vortex_sequence/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::arrays::BoolArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::compute::Operator;
10use vortex_array::expr::CompareKernel;
11use vortex_array::scalar::PValue;
12use vortex_array::scalar::Scalar;
13use vortex_buffer::BitBuffer;
14use vortex_dtype::NativePType;
15use vortex_dtype::Nullability;
16use vortex_dtype::match_each_integer_ptype;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_err;
21
22use crate::SequenceArray;
23use crate::array::SequenceVTable;
24
25impl CompareKernel for SequenceVTable {
26    fn compare(
27        lhs: &SequenceArray,
28        rhs: &dyn Array,
29        operator: Operator,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space.
33        if operator != Operator::Eq {
34            return Ok(None);
35        }
36
37        let Some(constant) = rhs.as_constant() else {
38            return Ok(None);
39        };
40
41        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
42        let set_idx = find_intersection_scalar(
43            lhs.base(),
44            lhs.multiplier(),
45            lhs.len(),
46            constant
47                .as_primitive()
48                .pvalue()
49                .vortex_expect("null constant handled in adaptor"),
50        );
51
52        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
53        let validity = match nullability {
54            Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
55            Nullability::Nullable => vortex_array::validity::Validity::AllValid,
56        };
57
58        if let Ok(set_idx) = set_idx {
59            let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
60            Ok(Some(BoolArray::new(buffer, validity).to_array()))
61        } else {
62            Ok(Some(
63                ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).to_array(),
64            ))
65        }
66    }
67}
68
69/// Find the index where `base + idx * multiplier == intercept`, if one exists.
70///
71/// # Errors
72/// Return `VortexError` if:
73/// - `len` is 0
74/// - `intercept` or `multiplier` can't be cast to `base`'s PType
75/// - `intercept` is outside the range of the sequence
76/// - `intercept` doesn't fall exactly on a sequence value
77pub(crate) fn find_intersection_scalar(
78    base: PValue,
79    multiplier: PValue,
80    len: usize,
81    intercept: PValue,
82) -> VortexResult<usize> {
83    match_each_integer_ptype!(base.ptype(), |P| {
84        let intercept = intercept.cast::<P>()?;
85        let base = base.cast::<P>()?;
86        let multiplier = multiplier.cast::<P>()?;
87        find_intersection(base, multiplier, len, intercept)
88    })
89}
90
91fn find_intersection<P: NativePType>(
92    base: P,
93    multiplier: P,
94    len: usize,
95    intercept: P,
96) -> VortexResult<usize> {
97    if len == 0 {
98        vortex_bail!("len == 0")
99    }
100
101    let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
102    let end_element = base + (multiplier * count);
103
104    // Handle ascending vs descending sequences
105    let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
106        (base, end_element)
107    } else {
108        (end_element, base)
109    };
110
111    // Check if intercept is in range
112    if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
113        vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
114    }
115
116    // Handle zero multiplier (constant sequence)
117    if multiplier == P::zero() {
118        if intercept == base {
119            return Ok(0);
120        } else {
121            vortex_bail!("{intercept} != {base} with zero multiplier")
122        }
123    }
124
125    // Check if (intercept - base) is evenly divisible by multiplier
126    let diff = intercept - base;
127    if diff % multiplier != P::zero() {
128        vortex_bail!("{diff} % {multiplier} != 0")
129    }
130
131    let idx = diff / multiplier;
132    idx.to_usize()
133        .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::Nullability::Nullable;

    use crate::SequenceArray;

    #[test]
    fn test_compare_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(4i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_match_scale() {
        let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();
        let rhs = ConstantArray::new(8i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_no_match() {
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(1i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_match_last_element() {
        // [2, 3, 4, 5]
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(5i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, false, true]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_match_descending() {
        // [10, 8, 6, 4, 2]
        let lhs = SequenceArray::typed_new(10i64, -2, NonNullable, 5).unwrap();
        let rhs = ConstantArray::new(6i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, true, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_between_elements() {
        // [2, 5, 8, 11]; 6 lies in range but is not an element.
        let lhs = SequenceArray::typed_new(2i64, 3, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(6i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_past_end() {
        // [2, 5, 8, 11]; 14 would be the next element but is past the end.
        let lhs = SequenceArray::typed_new(2i64, 3, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(14i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_compare_not_eq() {
        // [2, 3, 4, 5]
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        let rhs = ConstantArray::new(4i64, lhs.len());
        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
        let expected = BoolArray::from_iter([true, true, false, true]);
        assert_arrays_eq!(result, expected);
    }
}