vortex_array/arrays/patched/compute/
compare.rs1use vortex_buffer::BitBufferMut;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::array::child_to_validity;
14use crate::arrays::BoolArray;
15use crate::arrays::ConstantArray;
16use crate::arrays::Patched;
17use crate::arrays::PrimitiveArray;
18use crate::arrays::bool::BoolDataParts;
19use crate::arrays::patched::PatchedArrayExt;
20use crate::arrays::patched::PatchedArraySlotsExt;
21use crate::arrays::primitive::NativeValue;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::NativePType;
24use crate::match_each_native_ptype;
25use crate::scalar_fn::fns::binary::CompareKernel;
26use crate::scalar_fn::fns::operators::CompareOperator;
27
28impl CompareKernel for Patched {
29 fn compare(
30 lhs: ArrayView<'_, Self>,
31 rhs: &ArrayRef,
32 operator: CompareOperator,
33 ctx: &mut ExecutionCtx,
34 ) -> VortexResult<Option<ArrayRef>> {
35 if !lhs.dtype().is_primitive() {
37 return Ok(None);
38 }
39
40 let Some(constant) = rhs.as_constant() else {
42 return Ok(None);
43 };
44
45 let result = lhs
48 .inner()
49 .binary(
50 ConstantArray::new(constant.clone(), lhs.len()).into_array(),
51 operator.into(),
52 )?
53 .execute::<Canonical>(ctx)?
54 .into_bool();
55
56 let validity = child_to_validity(&result.slots()[0], result.dtype().nullability());
57 let len = result.len();
58 let BoolDataParts { bits, offset, len } = result.into_data().into_parts(len);
59
60 let mut bits = BitBufferMut::from_buffer(bits.unwrap_host().into_mut(), offset, len);
61
62 let lane_offsets = lhs.lane_offsets().clone().execute::<PrimitiveArray>(ctx)?;
63 let indices = lhs.patch_indices().clone().execute::<PrimitiveArray>(ctx)?;
64 let values = lhs.patch_values().clone().execute::<PrimitiveArray>(ctx)?;
65 let n_lanes = lhs.n_lanes();
66
67 match_each_native_ptype!(values.ptype(), |V| {
68 let offset = lhs.offset();
69 let indices = indices.as_slice::<u16>();
70 let values = values.as_slice::<V>();
71 let constant = constant
72 .as_primitive()
73 .as_::<V>()
74 .vortex_expect("compare constant not null");
75
76 let apply_patches = ApplyPatches {
77 bits: &mut bits,
78 offset,
79 n_lanes,
80 lane_offsets: lane_offsets.as_slice::<u32>(),
81 indices,
82 values,
83 constant,
84 };
85
86 match operator {
87 CompareOperator::Eq => {
88 apply_patches.apply(|l, r| NativeValue(l) == NativeValue(r))?;
89 }
90 CompareOperator::NotEq => {
91 apply_patches.apply(|l, r| NativeValue(l) != NativeValue(r))?;
92 }
93 CompareOperator::Gt => {
94 apply_patches.apply(|l, r| NativeValue(l) > NativeValue(r))?;
95 }
96 CompareOperator::Gte => {
97 apply_patches.apply(|l, r| NativeValue(l) >= NativeValue(r))?;
98 }
99 CompareOperator::Lt => {
100 apply_patches.apply(|l, r| NativeValue(l) < NativeValue(r))?;
101 }
102 CompareOperator::Lte => {
103 apply_patches.apply(|l, r| NativeValue(l) <= NativeValue(r))?;
104 }
105 }
106 });
107
108 let result = BoolArray::new(bits.freeze(), validity);
109 Ok(Some(result.into_array()))
110 }
111}
112
113struct ApplyPatches<'a, V: NativePType> {
114 bits: &'a mut BitBufferMut,
115 offset: usize,
116 n_lanes: usize,
117 lane_offsets: &'a [u32],
118 indices: &'a [u16],
119 values: &'a [V],
120 constant: V,
121}
122
123impl<V: NativePType> ApplyPatches<'_, V> {
124 fn apply<F>(self, cmp: F) -> VortexResult<()>
125 where
126 F: Fn(V, V) -> bool,
127 {
128 for index in 0..(self.lane_offsets.len() - 1) {
129 let chunk = index / self.n_lanes;
130
131 let lane_start = self.lane_offsets[index] as usize;
132 let lane_end = self.lane_offsets[index + 1] as usize;
133
134 for (&patch_index, &patch_value) in std::iter::zip(
135 &self.indices[lane_start..lane_end],
136 &self.values[lane_start..lane_end],
137 ) {
138 let bit_index = chunk * 1024 + patch_index as usize;
139 if bit_index < self.offset {
141 continue;
142 }
143 let bit_index = bit_index - self.offset;
144 if bit_index >= self.bits.len() {
145 break;
146 }
147 if cmp(patch_value, self.constant) {
148 self.bits.set(bit_index)
149 } else {
150 self.bits.unset(bit_index)
151 }
152 }
153 }
154
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    //! Tests for the `Patched` compare kernel. Every test calls the kernel directly with
    //! `CompareOperator::Eq` against a constant right-hand side.

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar_fn::fns::binary::CompareKernel;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::validity::Validity;

    /// Patches placed on the last three positions are the only elements equal to the constant.
    #[test]
    fn test_basic() -> VortexResult<()> {
        let lhs = PrimitiveArray::from_iter(0u32..512).into_array();
        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )?;

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let lhs = Patched::from_array_and_patches(lhs, &patches, &mut ctx)?
            .into_array()
            .try_downcast::<Patched>()
            .unwrap();

        let rhs = ConstantArray::new(u32::MAX, 512).into_array();

        let result = <Patched as CompareKernel>::compare(
            lhs.as_view(),
            &rhs,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .unwrap();

        let expected =
            BoolArray::from_indices(512, [509, 510, 511], Validity::NonNullable).into_array();

        assert_arrays_eq!(expected, result);
        Ok(())
    }

    /// After slicing off the first 10 elements, the patch at index 5 falls outside the array
    /// and the remaining patches shift down to 500 and 501.
    #[test]
    fn test_with_offset() -> VortexResult<()> {
        let lhs = PrimitiveArray::from_iter(0u32..512).into_array();
        let patches = Patches::new(
            512,
            0,
            buffer![5u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )?;

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());

        let lhs = Patched::from_array_and_patches(lhs, &patches, &mut ctx)?;
        // The test requires the sliced and optimized array to still be `Patched`.
        let lhs_ref = lhs.into_array().slice(10..512)?.optimize()?;
        let lhs = lhs_ref.try_downcast::<Patched>().unwrap();

        assert_eq!(lhs.len(), 502);

        let rhs = ConstantArray::new(u32::MAX, lhs.len()).into_array();

        let result = <Patched as CompareKernel>::compare(
            lhs.as_view(),
            &rhs,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .unwrap();

        let expected =
            BoolArray::from_indices(502, [500, 501], Validity::NonNullable).into_array();

        assert_arrays_eq!(expected, result);
        Ok(())
    }

    /// A subnormal patch value equals the same subnormal constant, while the NaN and negative
    /// infinity patches do not.
    #[test]
    fn test_subnormal_f32() -> VortexResult<()> {
        let subnormal: f32 = f32::MIN_POSITIVE / 2.0;
        // Sanity check: the value really lies strictly between zero and the smallest normal.
        assert!(subnormal > 0.0 && subnormal < f32::MIN_POSITIVE);

        let lhs = PrimitiveArray::from_iter((0..512).map(|i| i as f32)).into_array();

        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![f32::NAN, subnormal, f32::NEG_INFINITY].into_array(),
            None,
        )?;

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let lhs = Patched::from_array_and_patches(lhs, &patches, &mut ctx)?
            .into_array()
            .try_downcast::<Patched>()
            .unwrap();

        let rhs = ConstantArray::new(subnormal, 512).into_array();

        let result = <Patched as CompareKernel>::compare(
            lhs.as_view(),
            &rhs,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .unwrap();

        let expected = BoolArray::from_indices(512, [510], Validity::NonNullable).into_array();

        assert_arrays_eq!(expected, result);
        Ok(())
    }

    /// Only the `+0.0` patch equals the `0.0` constant: neither the `-0.0` base values nor the
    /// `-0.0` patch are expected to compare equal.
    #[test]
    fn test_pos_neg_zero() -> VortexResult<()> {
        let lhs = PrimitiveArray::from_iter([-0.0f32; 10]).into_array();

        let patches = Patches::new(
            10,
            0,
            buffer![5u16, 6, 7, 8, 9].into_array(),
            buffer![f32::NAN, f32::NEG_INFINITY, 0f32, -0.0f32, f32::INFINITY].into_array(),
            None,
        )?;

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let lhs = Patched::from_array_and_patches(lhs, &patches, &mut ctx)?
            .into_array()
            .try_downcast::<Patched>()
            .unwrap();

        let rhs = ConstantArray::new(0.0f32, 10).into_array();

        let result = <Patched as CompareKernel>::compare(
            lhs.as_view(),
            &rhs,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .unwrap();

        let expected = BoolArray::from_indices(10, [7], Validity::NonNullable).into_array();

        assert_arrays_eq!(expected, result);

        Ok(())
    }
}