vortex_sequence/compute/
compare.rs1use vortex_array::arrays::{BoolArray, BooleanBuffer, ConstantArray};
5use vortex_array::compute::{CompareKernel, Operator};
6use vortex_array::validity::Validity;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_scalar::{PValue, Scalar};
11
12use crate::SequenceArray;
13use crate::array::SequenceVTable;
14
15impl CompareKernel for SequenceVTable {
16 fn compare(
17 &self,
18 lhs: &SequenceArray,
19 rhs: &dyn Array,
20 operator: Operator,
21 ) -> VortexResult<Option<ArrayRef>> {
22 if operator != Operator::Eq {
23 return Ok(None);
24 };
25
26 let Some(constant) = rhs.as_constant() else {
27 return Ok(None);
28 };
29
30 let set_idx = find_intersection_scalar(
32 lhs.base(),
33 lhs.multiplier(),
34 lhs.len(),
35 constant
36 .as_primitive()
37 .pvalue()
38 .vortex_expect("non-null constant"),
39 );
40
41 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
42 let validity = match nullability {
43 Nullability::NonNullable => Validity::NonNullable,
44 Nullability::Nullable => Validity::AllValid,
45 };
46
47 if let Some(set_idx) = set_idx {
48 let buffer = BooleanBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
49 Ok(Some(BoolArray::new(buffer, validity).to_array()))
50 } else {
51 Ok(Some(
52 ConstantArray::new(
53 Scalar::new(DType::Bool(nullability), false.into()),
54 lhs.len(),
55 )
56 .to_array(),
57 ))
58 }
59 }
60}
61
62pub(crate) fn find_intersection_scalar(
63 base: PValue,
64 multiplier: PValue,
65 len: usize,
66 intercept: PValue,
67) -> Option<usize> {
68 match_each_integer_ptype!(base.ptype(), |P| {
69 let intercept = intercept.as_primitive::<P>();
70
71 let base = base.as_primitive::<P>();
72 let multiplier = multiplier.as_primitive::<P>();
73
74 find_intersection(base, multiplier, len, intercept)
75 })
76}
77
78fn find_intersection<P: NativePType>(
79 base: P,
80 multiplier: P,
81 len: usize,
82 intercept: P,
83) -> Option<usize> {
84 let count = <P>::from_usize(len - 1).vortex_expect("idx must fit into type");
86
87 let end_element = base + (multiplier * count);
88
89 (intercept.is_ge(base)
90 && intercept.is_le(end_element)
91 && (intercept - base) % multiplier == P::zero())
92 .then(|| ((intercept - base) / multiplier).to_usize())
93 .flatten()
94}
95
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::{BoolArray, ConstantArray};
    use vortex_array::compute::{Operator, compare};
    use vortex_dtype::Nullability::{NonNullable, Nullable};

    use crate::SequenceArray;

    /// Compares `lhs` for equality with a constant array holding `value` and
    /// asserts that the resulting boolean buffer equals `expected`.
    fn assert_eq_mask(lhs: SequenceArray, value: i64, expected: Vec<bool>) {
        let rhs = ConstantArray::new(value, lhs.len());

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            result.to_bool().unwrap().boolean_buffer(),
            BoolArray::from_iter(expected).boolean_buffer(),
        )
    }

    #[test]
    fn test_compare_match() {
        // Base 2, multiplier 1, length 4: the value 4 sits at index 2.
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();

        assert_eq_mask(lhs, 4, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_match_scale() {
        // Base 2, multiplier 3, length 4: the value 8 sits at index 2.
        let lhs = SequenceArray::typed_new(2i64, 3, Nullable, 4).unwrap();

        assert_eq_mask(lhs, 8, vec![false, false, true, false]);
    }

    #[test]
    fn test_compare_no_match() {
        // Base 2, multiplier 1, length 4: the value 1 precedes the first element.
        let lhs = SequenceArray::typed_new(2i64, 1, NonNullable, 4).unwrap();

        assert_eq_mask(lhs, 1, vec![false, false, false, false]);
    }
}