vortex_array/arrays/patched/compute/
compare.rs1use vortex_buffer::BitBufferMut;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::array::child_to_validity;
14use crate::arrays::BoolArray;
15use crate::arrays::ConstantArray;
16use crate::arrays::Patched;
17use crate::arrays::PrimitiveArray;
18use crate::arrays::bool::BoolDataParts;
19use crate::arrays::patched::PatchedArrayExt;
20use crate::arrays::patched::PatchedArraySlotsExt;
21use crate::arrays::primitive::NativeValue;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::NativePType;
24use crate::match_each_native_ptype;
25use crate::scalar_fn::fns::binary::CompareKernel;
26use crate::scalar_fn::fns::operators::CompareOperator;
27
28impl CompareKernel for Patched {
29 fn compare(
30 lhs: ArrayView<'_, Self>,
31 rhs: &ArrayRef,
32 operator: CompareOperator,
33 ctx: &mut ExecutionCtx,
34 ) -> VortexResult<Option<ArrayRef>> {
35 if !lhs.dtype().is_primitive() {
37 return Ok(None);
38 }
39
40 let Some(constant) = rhs.as_constant() else {
42 return Ok(None);
43 };
44
45 let result = lhs
48 .inner()
49 .binary(
50 ConstantArray::new(constant.clone(), lhs.len()).into_array(),
51 operator.into(),
52 )?
53 .execute::<Canonical>(ctx)?
54 .into_bool();
55
56 let validity = child_to_validity(result.slots()[0].as_ref(), result.dtype().nullability());
57 let len = result.len();
58 let BoolDataParts { bits, meta } = result.into_data().into_parts(len);
59
60 let mut bits =
61 BitBufferMut::from_buffer(bits.unwrap_host().into_mut(), meta.offset(), meta.len());
62
63 let lane_offsets = lhs.lane_offsets().clone().execute::<PrimitiveArray>(ctx)?;
64 let indices = lhs.patch_indices().clone().execute::<PrimitiveArray>(ctx)?;
65 let values = lhs.patch_values().clone().execute::<PrimitiveArray>(ctx)?;
66 let n_lanes = lhs.n_lanes();
67
68 match_each_native_ptype!(values.ptype(), |V| {
69 let offset = lhs.offset();
70 let indices = indices.as_slice::<u16>();
71 let values = values.as_slice::<V>();
72 let constant = constant
73 .as_primitive()
74 .as_::<V>()
75 .vortex_expect("compare constant not null");
76
77 let apply_patches = ApplyPatches {
78 bits: &mut bits,
79 offset,
80 n_lanes,
81 lane_offsets: lane_offsets.as_slice::<u32>(),
82 indices,
83 values,
84 constant,
85 };
86
87 match operator {
88 CompareOperator::Eq => {
89 apply_patches.apply(|l, r| NativeValue(l) == NativeValue(r))?;
90 }
91 CompareOperator::NotEq => {
92 apply_patches.apply(|l, r| NativeValue(l) != NativeValue(r))?;
93 }
94 CompareOperator::Gt => {
95 apply_patches.apply(|l, r| NativeValue(l) > NativeValue(r))?;
96 }
97 CompareOperator::Gte => {
98 apply_patches.apply(|l, r| NativeValue(l) >= NativeValue(r))?;
99 }
100 CompareOperator::Lt => {
101 apply_patches.apply(|l, r| NativeValue(l) < NativeValue(r))?;
102 }
103 CompareOperator::Lte => {
104 apply_patches.apply(|l, r| NativeValue(l) <= NativeValue(r))?;
105 }
106 }
107 });
108
109 let result = BoolArray::new(bits.freeze(), validity);
110 Ok(Some(result.into_array()))
111 }
112}
113
114struct ApplyPatches<'a, V: NativePType> {
115 bits: &'a mut BitBufferMut,
116 offset: usize,
117 n_lanes: usize,
118 lane_offsets: &'a [u32],
119 indices: &'a [u16],
120 values: &'a [V],
121 constant: V,
122}
123
124impl<V: NativePType> ApplyPatches<'_, V> {
125 fn apply<F>(self, cmp: F) -> VortexResult<()>
126 where
127 F: Fn(V, V) -> bool,
128 {
129 for index in 0..(self.lane_offsets.len() - 1) {
130 let chunk = index / self.n_lanes;
131
132 let lane_start = self.lane_offsets[index] as usize;
133 let lane_end = self.lane_offsets[index + 1] as usize;
134
135 for (&patch_index, &patch_value) in std::iter::zip(
136 &self.indices[lane_start..lane_end],
137 &self.values[lane_start..lane_end],
138 ) {
139 let bit_index = chunk * 1024 + patch_index as usize;
140 if bit_index < self.offset {
142 continue;
143 }
144 let bit_index = bit_index - self.offset;
145 if bit_index >= self.bits.len() {
146 break;
147 }
148 if cmp(patch_value, self.constant) {
149 self.bits.set(bit_index)
150 } else {
151 self.bits.unset(bit_index)
152 }
153 }
154 }
155
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_err;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::array::ArrayView;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar_fn::fns::binary::CompareKernel;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::validity::Validity;

    /// Layers `patches` over `base` and returns the `Patched` array type-erased.
    fn patched(
        base: ArrayRef,
        patches: &Patches,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        Ok(Patched::from_array_and_patches(base, patches, ctx)?.into_array())
    }

    /// Runs the `Patched` compare kernel with `Eq`, failing if it declines the inputs.
    fn kernel_eq(
        lhs: ArrayView<'_, Patched>,
        rhs: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<ArrayRef> {
        <Patched as CompareKernel>::compare(lhs, rhs, CompareOperator::Eq, ctx)?
            .ok_or_else(|| vortex_err!("expected compare result"))
    }

    #[test]
    fn test_basic() -> VortexResult<()> {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )?;

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        let lhs = patched(base, &patches, &mut ctx)?
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;
        let rhs = ConstantArray::new(u32::MAX, 512).into_array();

        let actual = kernel_eq(lhs.as_view(), &rhs, &mut ctx)?;
        let expected =
            BoolArray::from_indices(512, [509, 510, 511], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn test_with_offset() -> VortexResult<()> {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let patches = Patches::new(
            512,
            0,
            buffer![5u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )?;

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        // Slicing at 10 drops the patch at 5 and shifts 510/511 down to 500/501.
        let sliced = patched(base, &patches, &mut ctx)?
            .slice(10..512)?
            .optimize()?;
        let lhs = sliced
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;
        assert_eq!(lhs.len(), 502);

        let rhs = ConstantArray::new(u32::MAX, lhs.len()).into_array();
        let actual = kernel_eq(lhs.as_view(), &rhs, &mut ctx)?;
        let expected = BoolArray::from_indices(502, [500, 501], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn test_subnormal_f32() -> VortexResult<()> {
        let subnormal: f32 = f32::MIN_POSITIVE / 2.0;
        assert!(subnormal > 0.0 && subnormal < f32::MIN_POSITIVE);

        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![f32::NAN, subnormal, f32::NEG_INFINITY].into_array(),
            None,
        )?;

        let base = PrimitiveArray::from_iter((0..512).map(|i| i as f32)).into_array();
        let lhs = patched(base, &patches, &mut ctx)?
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;
        let rhs = ConstantArray::new(subnormal, 512).into_array();

        let actual = kernel_eq(lhs.as_view(), &rhs, &mut ctx)?;
        let expected = BoolArray::from_indices(512, [510], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
        Ok(())
    }

    #[test]
    fn test_pos_neg_zero() -> VortexResult<()> {
        let mut ctx = ExecutionCtx::new(LEGACY_SESSION.clone());
        let patches = Patches::new(
            10,
            0,
            buffer![5u16, 6, 7, 8, 9].into_array(),
            buffer![f32::NAN, f32::NEG_INFINITY, 0f32, -0.0f32, f32::INFINITY].into_array(),
            None,
        )?;

        let base = PrimitiveArray::from_iter([-0.0f32; 10]).into_array();
        let lhs = patched(base, &patches, &mut ctx)?
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;
        let rhs = ConstantArray::new(0.0f32, 10).into_array();

        // Only the +0.0 patch at index 7 matches; -0.0 (base values and patch 8) does not.
        let actual = kernel_eq(lhs.as_view(), &rhs, &mut ctx)?;
        let expected = BoolArray::from_indices(10, [7], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual);
        Ok(())
    }
}