vortex_sequence/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
2use vortex_array::compute::{CompareKernel, Operator};
3use vortex_array::validity::Validity;
4use vortex_array::{Array, ArrayRef};
5use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
6use vortex_error::{VortexExpect, VortexResult};
7use vortex_scalar::{PValue, Scalar};
8
9use crate::SequenceArray;
10use crate::array::SequenceVTable;
11
12impl CompareKernel for SequenceVTable {
13 fn compare(
14 &self,
15 lhs: &SequenceArray,
16 rhs: &dyn Array,
17 operator: Operator,
18 ) -> VortexResult<Option<ArrayRef>> {
19 if operator != Operator::Eq {
20 return Ok(None);
21 };
22
23 let Some(constant) = rhs.as_constant() else {
24 return Ok(None);
25 };
26
27 let set_idx = find_intersection_scalar(
29 lhs.base(),
30 lhs.multiplier(),
31 lhs.len(),
32 constant
33 .as_primitive()
34 .pvalue()
35 .vortex_expect("non-null constant"),
36 );
37
38 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
39 let validity = match nullability {
40 Nullability::NonNullable => Validity::NonNullable,
41 Nullability::Nullable => Validity::AllValid,
42 };
43
44 if let Some(set_idx) = set_idx {
45 let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
46 Ok(Some(BoolArray::new(buffer, validity).to_array()))
47 } else {
48 Ok(Some(
49 ConstantArray::new(
50 Scalar::new(DType::Bool(nullability), false.into()),
51 lhs.len(),
52 )
53 .to_array(),
54 ))
55 }
56 }
57}
58
59pub(crate) fn find_intersection_scalar(
60 base: PValue,
61 multiplier: PValue,
62 len: usize,
63 intercept: PValue,
64) -> Option<usize> {
65 match_each_integer_ptype!(base.ptype(), |P| {
66 let intercept = intercept
67 .as_primitive()
68 .vortex_expect("constant pvalue matching already validated");
69
70 let base = base
71 .as_primitive::<P>()
72 .vortex_expect("base pvalue matching already validated");
73 let multiplier = multiplier
74 .as_primitive::<P>()
75 .vortex_expect("multiplier pvalue matching already validated");
76
77 find_intersection(base, multiplier, len, intercept)
78 })
79}
80
81fn find_intersection<P: NativePType>(
82 base: P,
83 multiplier: P,
84 len: usize,
85 intercept: P,
86) -> Option<usize> {
87 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
89
90 let end_element = base + (multiplier * count);
91
92 (intercept.is_ge(base)
93 && intercept.is_le(end_element)
94 && (intercept - base) % multiplier == P::zero())
95 .then(|| ((intercept - base) / multiplier).to_usize())
96 .flatten()
97}
98
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};

    use crate::SequenceArray;

    /// Compares `sequence` for equality against a constant array filled with `value` and
    /// asserts that the resulting boolean buffer equals `expected`.
    fn assert_eq_mask(sequence: SequenceArray, value: i64, expected: &[bool]) {
        let constant = ConstantArray::new(value, sequence.len());

        let actual = compare(sequence.as_ref(), constant.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool()
            .unwrap();
        let wanted = BoolArray::from_iter(expected.iter().copied());

        assert_eq!(actual.boolean_buffer(), wanted.boolean_buffer());
    }

    #[test]
    fn test_compare_match() {
        let sequence = SequenceArray::typed_new(2i64, 1, 4).unwrap();
        assert_eq_mask(sequence, 4i64, &[false, false, true, false]);
    }

    #[test]
    fn test_compare_match_scale() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        assert_eq_mask(sequence, 8i64, &[false, false, true, false]);
    }

    #[test]
    fn test_compare_no_match() {
        let sequence = SequenceArray::typed_new(2i64, 1, 4).unwrap();
        assert_eq_mask(sequence, 1i64, &[false, false, false, false]);
    }
}