Skip to main content

vortex_sequence/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::DynArray;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::dtype::NativePType;
11use vortex_array::dtype::Nullability;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar::PValue;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::binary::CompareKernel;
16use vortex_array::scalar_fn::fns::operators::CompareOperator;
17use vortex_buffer::BitBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22
23use crate::SequenceArray;
24use crate::array::SequenceVTable;
25
26impl CompareKernel for SequenceVTable {
27    fn compare(
28        lhs: &SequenceArray,
29        rhs: &ArrayRef,
30        operator: CompareOperator,
31        _ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        // TODO(joe): support other operators (NotEq, Lt, Lte, Gt, Gte) in encoded space.
34        if operator != CompareOperator::Eq {
35            return Ok(None);
36        }
37
38        let Some(constant) = rhs.as_constant() else {
39            return Ok(None);
40        };
41
42        // Check if there exists an integer solution to const = base + (0..len) * multiplier.
43        let set_idx = find_intersection_scalar(
44            lhs.base(),
45            lhs.multiplier(),
46            lhs.len(),
47            constant
48                .as_primitive()
49                .pvalue()
50                .vortex_expect("null constant handled in adaptor"),
51        );
52
53        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
54        let validity = match nullability {
55            Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
56            Nullability::Nullable => vortex_array::validity::Validity::AllValid,
57        };
58
59        if let Ok(set_idx) = set_idx {
60            let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
61            Ok(Some(BoolArray::new(buffer, validity).into_array()))
62        } else {
63            Ok(Some(
64                ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).into_array(),
65            ))
66        }
67    }
68}
69
70/// Find the index where `base + idx * multiplier == intercept`, if one exists.
71///
72/// # Errors
73/// Return `VortexError` if:
74/// - `len` is 0
75/// - `intercept` or `multiplier` can't be cast to `base`'s PType
76/// - `intercept` is outside the range of the sequence
77/// - `intercept` doesn't fall exactly on a sequence value
78pub(crate) fn find_intersection_scalar(
79    base: PValue,
80    multiplier: PValue,
81    len: usize,
82    intercept: PValue,
83) -> VortexResult<usize> {
84    match_each_integer_ptype!(base.ptype(), |P| {
85        let intercept = intercept.cast::<P>()?;
86        let base = base.cast::<P>()?;
87        let multiplier = multiplier.cast::<P>()?;
88        find_intersection(base, multiplier, len, intercept)
89    })
90}
91
92fn find_intersection<P: NativePType>(
93    base: P,
94    multiplier: P,
95    len: usize,
96    intercept: P,
97) -> VortexResult<usize> {
98    if len == 0 {
99        vortex_bail!("len == 0")
100    }
101
102    let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
103    let end_element = base + (multiplier * count);
104
105    // Handle ascending vs descending sequences
106    let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
107        (base, end_element)
108    } else {
109        (end_element, base)
110    };
111
112    // Check if intercept is in range
113    if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
114        vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
115    }
116
117    // Handle zero multiplier (constant sequence)
118    if multiplier == P::zero() {
119        if intercept == base {
120            return Ok(0);
121        } else {
122            vortex_bail!("{intercept} != {base} with zero multiplier")
123        }
124    }
125
126    // Check if (intercept - base) is evenly divisible by multiplier
127    let diff = intercept - base;
128    if diff % multiplier != P::zero() {
129        vortex_bail!("{diff} % {multiplier} != 0")
130    }
131
132    let idx = diff / multiplier;
133    idx.to_usize()
134        .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
135}
136
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::Nullability::Nullable;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::SequenceArray;

    /// Sequence 2, 3, 4, 5 compared with 4: only the third element matches.
    #[test]
    fn test_compare_match() {
        let sequence = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(4i64, sequence.len()).into_array();

        let result = sequence.into_array().binary(needle, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result, expected);
    }

    /// Nullable sequence 2, 5, 8, 11 compared with 8: the match lands on a scaled step.
    #[test]
    fn test_compare_match_scale() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullable, 4).unwrap();
        let needle = ConstantArray::new(8i64, sequence.len()).into_array();

        let result = sequence.into_array().binary(needle, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(result, expected);
    }

    /// Sequence 2, 3, 4, 5 compared with 1: the constant lies below the sequence.
    #[test]
    fn test_compare_no_match() {
        let sequence = SequenceArray::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(1i64, sequence.len()).into_array();

        let result = sequence.into_array().binary(needle, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }
}