vortex_sequence/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{PValue, Scalar};
11
12use crate::SequenceArray;
13use crate::array::SequenceVTable;
14
15impl CompareKernel for SequenceVTable {
16 fn compare(
17 &self,
18 lhs: &SequenceArray,
19 rhs: &dyn Array,
20 operator: Operator,
21 ) -> VortexResult<Option<ArrayRef>> {
22 if operator != Operator::Eq {
23 return Ok(None);
24 };
25
26 let Some(constant) = rhs.as_constant() else {
27 return Ok(None);
28 };
29
30 let set_idx = find_intersection_scalar(
32 lhs.base(),
33 lhs.multiplier(),
34 lhs.len(),
35 constant
36 .as_primitive()
37 .pvalue()
38 .vortex_expect("non-null constant"),
39 );
40
41 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
42 let validity = match nullability {
43 Nullability::NonNullable => Validity::NonNullable,
44 Nullability::Nullable => Validity::AllValid,
45 };
46
47 if let Some(set_idx) = set_idx {
48 let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
49 Ok(Some(BoolArray::new(buffer, validity).to_array()))
50 } else {
51 Ok(Some(
52 ConstantArray::new(
53 Scalar::new(DType::Bool(nullability), false.into()),
54 lhs.len(),
55 )
56 .to_array(),
57 ))
58 }
59 }
60}
61
62pub(crate) fn find_intersection_scalar(
63 base: PValue,
64 multiplier: PValue,
65 len: usize,
66 intercept: PValue,
67) -> Option<usize> {
68 match_each_integer_ptype!(base.ptype(), |P| {
69 let intercept = intercept
70 .as_primitive()
71 .vortex_expect("constant pvalue matching already validated");
72
73 let base = base
74 .as_primitive::<P>()
75 .vortex_expect("base pvalue matching already validated");
76 let multiplier = multiplier
77 .as_primitive::<P>()
78 .vortex_expect("multiplier pvalue matching already validated");
79
80 find_intersection(base, multiplier, len, intercept)
81 })
82}
83
84fn find_intersection<P: NativePType>(
85 base: P,
86 multiplier: P,
87 len: usize,
88 intercept: P,
89) -> Option<usize> {
90 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
92
93 let end_element = base + (multiplier * count);
94
95 (intercept.is_ge(base)
96 && intercept.is_le(end_element)
97 && (intercept - base) % multiplier == P::zero())
98 .then(|| ((intercept - base) / multiplier).to_usize())
99 .flatten()
100}
101
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Compares `lhs` for equality against a constant array filled with `needle` and asserts
    /// that the resulting boolean buffer equals `expected`.
    fn assert_eq_compare(lhs: SequenceArray, needle: i64, expected: Vec<bool>) {
        let rhs = ConstantArray::new(needle, lhs.len());

        let outcome = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            outcome.to_bool().unwrap().boolean_buffer(),
            BoolArray::from_iter(expected).boolean_buffer(),
        )
    }

    // Sequence 2, 3, 4, 5: the constant 4 sits at index 2.
    #[test]
    fn test_compare_match() {
        assert_eq_compare(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            4i64,
            vec![false, false, true, false],
        )
    }

    // Sequence 2, 5, 8, 11 (step 3, nullable dtype): the constant 8 sits at index 2.
    #[test]
    fn test_compare_match_scale() {
        assert_eq_compare(
            SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap(),
            8i64,
            vec![false, false, true, false],
        )
    }

    // Sequence 2, 3, 4, 5: the constant 1 lies below the first element, so nothing matches.
    #[test]
    fn test_compare_no_match() {
        assert_eq_compare(
            SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap(),
            1i64,
            vec![false, false, false, false],
        )
    }
}