vortex_sequence/compute/
compare.rs1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::arrays::BoolArray;
7use vortex_array::arrays::ConstantArray;
8use vortex_array::compute::CompareKernel;
9use vortex_array::compute::Operator;
10use vortex_array::validity::Validity;
11use vortex_buffer::BitBuffer;
12use vortex_dtype::DType;
13use vortex_dtype::NativePType;
14use vortex_dtype::Nullability;
15use vortex_dtype::match_each_integer_ptype;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_scalar::PValue;
19use vortex_scalar::Scalar;
20
21use crate::SequenceArray;
22use crate::array::SequenceVTable;
23
24impl CompareKernel for SequenceVTable {
25 fn compare(
26 &self,
27 lhs: &SequenceArray,
28 rhs: &dyn Array,
29 operator: Operator,
30 ) -> VortexResult<Option<ArrayRef>> {
31 if operator != Operator::Eq {
32 return Ok(None);
33 };
34
35 let Some(constant) = rhs.as_constant() else {
36 return Ok(None);
37 };
38
39 let set_idx = find_intersection_scalar(
41 lhs.base(),
42 lhs.multiplier(),
43 lhs.len(),
44 constant
45 .as_primitive()
46 .pvalue()
47 .vortex_expect("non-null constant"),
48 );
49
50 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
51 let validity = match nullability {
52 Nullability::NonNullable => Validity::NonNullable,
53 Nullability::Nullable => Validity::AllValid,
54 };
55
56 if let Some(set_idx) = set_idx {
57 let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
58 Ok(Some(
59 BoolArray::from_bit_buffer(buffer, validity).to_array(),
60 ))
61 } else {
62 Ok(Some(
63 ConstantArray::new(
64 Scalar::new(DType::Bool(nullability), false.into()),
65 lhs.len(),
66 )
67 .to_array(),
68 ))
69 }
70 }
71}
72
73pub(crate) fn find_intersection_scalar(
74 base: PValue,
75 multiplier: PValue,
76 len: usize,
77 intercept: PValue,
78) -> Option<usize> {
79 match_each_integer_ptype!(base.ptype(), |P| {
80 let intercept = intercept.cast::<P>();
81
82 let base = base.cast::<P>();
83 let multiplier = multiplier.cast::<P>();
84
85 find_intersection(base, multiplier, len, intercept)
86 })
87}
88
89fn find_intersection<P: NativePType>(
90 base: P,
91 multiplier: P,
92 len: usize,
93 intercept: P,
94) -> Option<usize> {
95 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
97
98 let end_element = base + (multiplier * count);
99
100 (intercept.is_ge(base)
101 && intercept.is_le(end_element)
102 && (intercept - base) % multiplier == P::zero())
103 .then(|| ((intercept - base) / multiplier).to_usize())
104 .flatten()
105}
106
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::compute::Operator;
    use vortex_array::compute::compare;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::Nullability::Nullable;

    use crate::SequenceArray;

    /// Compares `sequence` for equality against a constant array holding `needle` and asserts
    /// that the resulting bits equal `expected`.
    #[track_caller]
    fn assert_eq_mask(sequence: SequenceArray, needle: i64, expected: Vec<bool>) {
        let constant = ConstantArray::new(needle, sequence.len());

        let outcome = compare(sequence.as_ref(), constant.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            outcome.to_bool().bit_buffer(),
            BoolArray::from_iter(expected).bit_buffer(),
        )
    }

    #[test]
    fn test_compare_match() {
        // 2, 3, 4, 5 equals 4 at index 2.
        assert_eq_mask(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            4i64,
            vec![false, false, true, false],
        )
    }

    #[test]
    fn test_compare_match_scale() {
        // 2, 5, 8, 11 equals 8 at index 2.
        assert_eq_mask(
            SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap(),
            8i64,
            vec![false, false, true, false],
        )
    }

    #[test]
    fn test_compare_no_match() {
        // 2, 3, 4, 5 never equals 1.
        assert_eq_mask(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            1i64,
            vec![false, false, false, false],
        )
    }
}