vortex_sequence/compute/
compare.rs1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::BoolArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::dtype::NativePType;
11use vortex_array::dtype::Nullability;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar::PValue;
14use vortex_array::scalar::Scalar;
15use vortex_array::scalar_fn::fns::binary::CompareKernel;
16use vortex_array::scalar_fn::fns::operators::CompareOperator;
17use vortex_buffer::BitBuffer;
18use vortex_error::VortexExpect;
19use vortex_error::VortexResult;
20use vortex_error::vortex_bail;
21use vortex_error::vortex_err;
22
23use crate::array::Sequence;
24
25impl CompareKernel for Sequence {
26 fn compare(
27 lhs: ArrayView<'_, Self>,
28 rhs: &ArrayRef,
29 operator: CompareOperator,
30 _ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 if operator != CompareOperator::Eq {
34 return Ok(None);
35 }
36
37 let Some(constant) = rhs.as_constant() else {
38 return Ok(None);
39 };
40
41 let set_idx = find_intersection_scalar(
43 lhs.base(),
44 lhs.multiplier(),
45 lhs.len(),
46 constant
47 .as_primitive()
48 .pvalue()
49 .vortex_expect("null constant handled in adaptor"),
50 );
51
52 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
53 let validity = match nullability {
54 Nullability::NonNullable => vortex_array::validity::Validity::NonNullable,
55 Nullability::Nullable => vortex_array::validity::Validity::AllValid,
56 };
57
58 if let Ok(set_idx) = set_idx {
59 let buffer = BitBuffer::from_iter((0..lhs.len()).map(|idx| idx == set_idx));
60 Ok(Some(BoolArray::new(buffer, validity).into_array()))
61 } else {
62 Ok(Some(
63 ConstantArray::new(Scalar::bool(false, nullability), lhs.len()).into_array(),
64 ))
65 }
66 }
67}
68
69pub(crate) fn find_intersection_scalar(
78 base: PValue,
79 multiplier: PValue,
80 len: usize,
81 intercept: PValue,
82) -> VortexResult<usize> {
83 match_each_integer_ptype!(base.ptype(), |P| {
84 let intercept = intercept.cast::<P>()?;
85 let base = base.cast::<P>()?;
86 let multiplier = multiplier.cast::<P>()?;
87 find_intersection(base, multiplier, len, intercept)
88 })
89}
90
91fn find_intersection<P: NativePType>(
92 base: P,
93 multiplier: P,
94 len: usize,
95 intercept: P,
96) -> VortexResult<usize> {
97 if len == 0 {
98 vortex_bail!("len == 0")
99 }
100
101 let count = P::from_usize(len - 1).vortex_expect("idx must fit into type");
102 let end_element = base + (multiplier * count);
103
104 let (min_val, max_val) = if multiplier.is_ge(P::zero()) {
106 (base, end_element)
107 } else {
108 (end_element, base)
109 };
110
111 if !intercept.is_ge(min_val) || !intercept.is_le(max_val) {
113 vortex_bail!("{intercept} is outside of ({min_val}, {max_val}) range")
114 }
115
116 if multiplier == P::zero() {
118 if intercept == base {
119 return Ok(0);
120 } else {
121 vortex_bail!("{intercept} != {base} with zero multiplier")
122 }
123 }
124
125 let diff = intercept - base;
127 if diff % multiplier != P::zero() {
128 vortex_bail!("{diff} % {multiplier} != 0")
129 }
130
131 let idx = diff / multiplier;
132 idx.to_usize()
133 .ok_or_else(|| vortex_err!("Cannot represent {idx} as usize"))
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::BoolArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::dtype::Nullability::NonNullable;
    use vortex_array::dtype::Nullability::Nullable;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use crate::Sequence;

    /// Comparing a non-nullable sequence with a constant it contains sets
    /// only the matching position (index 2 here).
    #[test]
    fn test_compare_match() {
        let sequence = Sequence::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(4i64, sequence.len()).into_array();
        let actual = sequence.into_array().binary(needle, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(actual, expected);
    }

    /// Same as above with a multiplier of 3 and a nullable sequence; the
    /// expected result is a nullable bool array with every entry valid.
    #[test]
    fn test_compare_match_scale() {
        let sequence = Sequence::try_new_typed(2i64, 3, Nullable, 4).unwrap();
        let needle = ConstantArray::new(8i64, sequence.len()).into_array();
        let actual = sequence.into_array().binary(needle, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([Some(false), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(actual, expected);
    }

    /// A constant that the sequence never reaches yields all-false.
    #[test]
    fn test_compare_no_match() {
        let sequence = Sequence::try_new_typed(2i64, 1, NonNullable, 4).unwrap();
        let needle = ConstantArray::new(1i64, sequence.len()).into_array();
        let actual = sequence.into_array().binary(needle, Operator::Eq).unwrap();
        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(actual, expected);
    }
}