vortex_sequence/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{PValue, Scalar};
11
12use crate::SequenceArray;
13use crate::array::SequenceVTable;
14
15impl CompareKernel for SequenceVTable {
16 fn compare(
17 &self,
18 lhs: &SequenceArray,
19 rhs: &dyn Array,
20 operator: Operator,
21 ) -> VortexResult<Option<ArrayRef>> {
22 if operator != Operator::Eq {
23 return Ok(None);
24 };
25
26 let Some(constant) = rhs.as_constant() else {
27 return Ok(None);
28 };
29
30 let set_idx = find_intersection_scalar(
32 lhs.base(),
33 lhs.multiplier(),
34 lhs.len(),
35 constant
36 .as_primitive()
37 .pvalue()
38 .vortex_expect("non-null constant"),
39 );
40
41 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
42 let validity = match nullability {
43 Nullability::NonNullable => Validity::NonNullable,
44 Nullability::Nullable => Validity::AllValid,
45 };
46
47 if let Some(set_idx) = set_idx {
48 let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
49 Ok(Some(
50 BoolArray::from_bool_buffer(buffer, validity).to_array(),
51 ))
52 } else {
53 Ok(Some(
54 ConstantArray::new(
55 Scalar::new(DType::Bool(nullability), false.into()),
56 lhs.len(),
57 )
58 .to_array(),
59 ))
60 }
61 }
62}
63
64pub(crate) fn find_intersection_scalar(
65 base: PValue,
66 multiplier: PValue,
67 len: usize,
68 intercept: PValue,
69) -> Option<usize> {
70 match_each_integer_ptype!(base.ptype(), |P| {
71 let intercept = intercept.as_primitive::<P>();
72
73 let base = base.as_primitive::<P>();
74 let multiplier = multiplier.as_primitive::<P>();
75
76 find_intersection(base, multiplier, len, intercept)
77 })
78}
79
80fn find_intersection<P: NativePType>(
81 base: P,
82 multiplier: P,
83 len: usize,
84 intercept: P,
85) -> Option<usize> {
86 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
88
89 let end_element = base + (multiplier * count);
90
91 (intercept.is_ge(base)
92 && intercept.is_le(end_element)
93 && (intercept - base) % multiplier == P::zero())
94 .then(|| ((intercept - base) / multiplier).to_usize())
95 .flatten()
96}
97
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Runs `lhs == value` through the public `compare` entry point and asserts that
    /// the resulting bits equal `expected`.
    fn check_eq(lhs: SequenceArray, value: i64, expected: &[bool]) {
        let rhs = ConstantArray::new(value, lhs.len());
        let actual = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq)
            .unwrap()
            .to_bool();
        let wanted = BoolArray::from_iter(expected.iter().copied());
        assert_eq!(actual.boolean_buffer(), wanted.boolean_buffer());
    }

    #[test]
    fn test_compare_match() {
        // Sequence 2, 3, 4, 5: the constant 4 sits at index 2.
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        check_eq(lhs, 4i64, &[false, false, true, false]);
    }

    #[test]
    fn test_compare_match_scale() {
        // Sequence 2, 5, 8, 11: the constant 8 sits at index 2.
        let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();
        check_eq(lhs, 8i64, &[false, false, true, false]);
    }

    #[test]
    fn test_compare_no_match() {
        // Sequence 2, 3, 4, 5: the constant 1 lies below the first element.
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();
        check_eq(lhs, 1i64, &[false, false, false, false]);
    }
}