Skip to main content

vortex_sequence/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::dtype::NativePType;
11use vortex_array::dtype::Nullability;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar::PValue;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::binary::CompareKernel;
16use vortex_array::scalar_fn::fns::operators::CompareOperator;
17use vortex_buffer::BitBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22
23use crate::array::Sequence;
24
25impl CompareKernel for Sequence {
26    fn compare(
27        lhs: ArrayView<'_, Self>,
28        rhs: &ArrayRef,
29        operator: CompareOperator,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space.
33        if operator != CompareOperator::Eq {
34            return Ok(None);
35        }
36
37        let Some(constant) = rhs.as_constant() else {
38            return Ok(None);
39        };
40
41        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
42        let set_idx = find_intersection_scalar(
43            lhs.base(),
44            lhs.multiplier(),
45            lhs.len(),
46            constant
47                .as_primitive()
48                .pvalue()
49                .vortex_expect("null constant handled in adaptor"),
50        );
51
52        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
53        let validity = match nullability {
54            Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
55            Nullability::Nullable => vortex_array::validity::Validity::AllValid,
56        };
57
58        if let Ok(set_idx) = set_idx {
59            let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
60            Ok(Some(BoolArray::new(buffer, validity).into_array()))
61        } else {
62            Ok(Some(
63                ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).into_array(),
64            ))
65        }
66    }
67}
68
69/// Find the index where `base + idx * multiplier == intercept`, if one exists.
70///
71/// # Errors
72/// Return `VortexError` if:
73/// - `len` is 0
74/// - `intercept` or `multiplier` can't be cast to `base`'s PType
75/// - `intercept` is outside the range of the sequence
76/// - `intercept` doesn't fall exactly on a sequence value
77pub(crate) fn find_intersection_scalar(
78    base: PValue,
79    multiplier: PValue,
80    len: usize,
81    intercept: PValue,
82) -> VortexResult<usize> {
83    match_each_integer_ptype!(base.ptype(), |P| {
84        let intercept = intercept.cast::<P>()?;
85        let base = base.cast::<P>()?;
86        let multiplier = multiplier.cast::<P>()?;
87        find_intersection(base, multiplier, len, intercept)
88    })
89}
90
91fn find_intersection<P: NativePType>(
92    base: P,
93    multiplier: P,
94    len: usize,
95    intercept: P,
96) -> VortexResult<usize> {
97    if len == 0 {
98        vortex_bail!("len == 0")
99    }
100
101    let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
102    let end_element = base + (multiplier * count);
103
104    // Handle ascending vs descending sequences
105    let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
106        (base, end_element)
107    } else {
108        (end_element, base)
109    };
110
111    // Check if intercept is in range
112    if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
113        vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
114    }
115
116    // Handle zero multiplier (constant sequence)
117    if multiplier == P::zero() {
118        if intercept == base {
119            return Ok(0);
120        } else {
121            vortex_bail!("{intercept} != {base} with zero multiplier")
122        }
123    }
124
125    // Check if (intercept - base) is evenly divisible by multiplier
126    let diff = intercept - base;
127    if diff % multiplier != P::zero() {
128        vortex_bail!("{diff} % {multiplier} != 0")
129    }
130
131    let idx = diff / multiplier;
132    idx.to_usize()
133        .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::Nullability::Nullable;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::Sequence;

    /// Unit-step sequence 2, 3, 4, 5: the constant 4 matches only at index 2.
    #[test]
    fn test_compare_match() {
        let sequence = Sequence::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let constant = ConstantArray::new(4i64, sequence.len()).into_array();

        let result = sequence
            .into_array()
            .binary(constant, Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Nullable sequence with step 3 (2, 5, 8, 11): the constant 8 matches at index 2
    /// and the result is a nullable, all-valid boolean array.
    #[test]
    fn test_compare_match_scale() {
        let sequence = Sequence::try_new_typed(2i64, 3, Nullable, 4).unwrap();
        let constant = ConstantArray::new(8i64, sequence.len()).into_array();

        let result = sequence
            .into_array()
            .binary(constant, Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(result, expected);
    }

    /// The constant 1 lies below the first element of 2, 3, 4, 5, so nothing matches.
    #[test]
    fn test_compare_no_match() {
        let sequence = Sequence::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let constant = ConstantArray::new(1i64, sequence.len()).into_array();

        let result = sequence
            .into_array()
            .binary(constant, Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }
}