vortex_sequence/compute/
compare.rs1use vortex_array::arrays::{BoolArray, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_buffer::BitBuffer;
9use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
10use vortex_error::{VortexExpect, VortexResult};
11use vortex_scalar::{PValue, Scalar};
12
13use crate::SequenceArray;
14use crate::array::SequenceVTable;
15
16impl CompareKernel for SequenceVTable {
17 fn compare(
18 &self,
19 lhs: &SequenceArray,
20 rhs: &dyn Array,
21 operator: Operator,
22 ) -> VortexResult<Option<ArrayRef>> {
23 if operator != Operator::Eq {
24 return Ok(None);
25 };
26
27 let Some(constant) = rhs.as_constant() else {
28 return Ok(None);
29 };
30
31 let set_idx = find_intersection_scalar(
33 lhs.base(),
34 lhs.multiplier(),
35 lhs.len(),
36 constant
37 .as_primitive()
38 .pvalue()
39 .vortex_expect("non-null constant"),
40 );
41
42 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
43 let validity = match nullability {
44 Nullability::NonNullable => Validity::NonNullable,
45 Nullability::Nullable => Validity::AllValid,
46 };
47
48 if let Some(set_idx) = set_idx {
49 let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
50 Ok(Some(
51 BoolArray::from_bit_buffer(buffer, validity).to_array(),
52 ))
53 } else {
54 Ok(Some(
55 ConstantArray::new(
56 Scalar::new(DType::Bool(nullability), false.into()),
57 lhs.len(),
58 )
59 .to_array(),
60 ))
61 }
62 }
63}
64
65pub(crate) fn find_intersection_scalar(
66 base: PValue,
67 multiplier: PValue,
68 len: usize,
69 intercept: PValue,
70) -> Option<usize> {
71 match_each_integer_ptype!(base.ptype(), |P| {
72 let intercept = intercept.cast::<P>();
73
74 let base = base.cast::<P>();
75 let multiplier = multiplier.cast::<P>();
76
77 find_intersection(base, multiplier, len, intercept)
78 })
79}
80
81fn find_intersection<P: NativePType>(
82 base: P,
83 multiplier: P,
84 len: usize,
85 intercept: P,
86) -> Option<usize> {
87 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
89
90 let end_element = base + (multiplier * count);
91
92 (intercept.is_ge(base)
93 && intercept.is_le(end_element)
94 && (intercept - base) % multiplier == P::zero())
95 .then(|| ((intercept - base) / multiplier).to_usize())
96 .flatten()
97}
98
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Compares `sequence` for equality against `value`, broadcast as a constant of the same
    /// length, and asserts that the resulting bits equal `expected`.
    fn check_eq(sequence: SequenceArray, value: i64, expected: &[bool]) {
        let constant = ConstantArray::new(value, sequence.len());
        let actual = compare(sequence.as_ref(), constant.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();
        let wanted = BoolArray::from_iter(expected.iter().copied());
        assert_eq!(actual.bit_buffer(), wanted.bit_buffer());
    }

    #[test]
    fn test_compare_match() {
        // Sequence [2, 3, 4, 5]; only index 2 holds 4.
        check_eq(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            4,
            &[false, false, true, false],
        );
    }

    #[test]
    fn test_compare_match_scale() {
        // Sequence [2, 5, 8, 11]; only index 2 holds 8.
        check_eq(
            SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap(),
            8,
            &[false, false, true, false],
        );
    }

    #[test]
    fn test_compare_no_match() {
        // Sequence [2, 3, 4, 5]; 1 lies below the first element.
        check_eq(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            1,
            &[false, false, false, false],
        );
    }
}