Skip to main content

vortex_array/arrays/patched/compute/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBufferMut;
5use vortex_error::VortexExpect;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::array::ArrayView;
13use crate::array::child_to_validity;
14use crate::arrays::BoolArray;
15use crate::arrays::ConstantArray;
16use crate::arrays::Patched;
17use crate::arrays::PrimitiveArray;
18use crate::arrays::bool::BoolDataParts;
19use crate::arrays::patched::PatchedArrayExt;
20use crate::arrays::patched::PatchedArraySlotsExt;
21use crate::arrays::primitive::NativeValue;
22use crate::builtins::ArrayBuiltins;
23use crate::dtype::NativePType;
24use crate::match_each_native_ptype;
25use crate::scalar_fn::fns::binary::CompareKernel;
26use crate::scalar_fn::fns::operators::CompareOperator;
27
28impl CompareKernel for Patched {
29    fn compare(
30        lhs: ArrayView<'_, Self>,
31        rhs: &ArrayRef,
32        operator: CompareOperator,
33        ctx: &mut ExecutionCtx,
34    ) -> VortexResult<Option<ArrayRef>> {
35        // We only accelerate comparisons for primitives
36        if !lhs.dtype().is_primitive() {
37            return Ok(None);
38        }
39
40        // We only accelerate comparisons against constants
41        let Some(constant) = rhs.as_constant() else {
42            return Ok(None);
43        };
44
45        // NOTE: due to offset, it's possible that the inner.len != array.len.
46        //  We slice the inner before performing the comparison.
47        let result = lhs
48            .inner()
49            .binary(
50                ConstantArray::new(constant.clone(), lhs.len()).into_array(),
51                operator.into(),
52            )?
53            .execute::<Canonical>(ctx)?
54            .into_bool();
55
56        let validity = child_to_validity(result.slots()[0].as_ref(), result.dtype().nullability());
57        let len = result.len();
58        let BoolDataParts { bits, meta } = result.into_data().into_parts(len);
59
60        let mut bits =
61            BitBufferMut::from_buffer(bits.unwrap_host().into_mut(), meta.offset(), meta.len());
62
63        let lane_offsets = lhs.lane_offsets().clone().execute::<PrimitiveArray>(ctx)?;
64        let indices = lhs.patch_indices().clone().execute::<PrimitiveArray>(ctx)?;
65        let values = lhs.patch_values().clone().execute::<PrimitiveArray>(ctx)?;
66        let n_lanes = lhs.n_lanes();
67
68        match_each_native_ptype!(values.ptype(), |V| {
69            let offset = lhs.offset();
70            let indices = indices.as_slice::<u16>();
71            let values = values.as_slice::<V>();
72            let constant = constant
73                .as_primitive()
74                .as_::<V>()
75                .vortex_expect("compare constant not null");
76
77            let apply_patches = ApplyPatches {
78                bits: &mut bits,
79                offset,
80                n_lanes,
81                lane_offsets: lane_offsets.as_slice::<u32>(),
82                indices,
83                values,
84                constant,
85            };
86
87            match operator {
88                CompareOperator::Eq => {
89                    apply_patches.apply(|l, r| NativeValue(l) == NativeValue(r))?;
90                }
91                CompareOperator::NotEq => {
92                    apply_patches.apply(|l, r| NativeValue(l) != NativeValue(r))?;
93                }
94                CompareOperator::Gt => {
95                    apply_patches.apply(|l, r| NativeValue(l) > NativeValue(r))?;
96                }
97                CompareOperator::Gte => {
98                    apply_patches.apply(|l, r| NativeValue(l) >= NativeValue(r))?;
99                }
100                CompareOperator::Lt => {
101                    apply_patches.apply(|l, r| NativeValue(l) < NativeValue(r))?;
102                }
103                CompareOperator::Lte => {
104                    apply_patches.apply(|l, r| NativeValue(l) <= NativeValue(r))?;
105                }
106            }
107        });
108
109        let result = BoolArray::new(bits.freeze(), validity);
110        Ok(Some(result.into_array()))
111    }
112}
113
114struct ApplyPatches<'a, V: NativePType> {
115    bits: &'a mut BitBufferMut,
116    offset: usize,
117    n_lanes: usize,
118    lane_offsets: &'a [u32],
119    indices: &'a [u16],
120    values: &'a [V],
121    constant: V,
122}
123
124impl<V: NativePType> ApplyPatches<'_, V> {
125    fn apply<F>(self, cmp: F) -> VortexResult<()>
126    where
127        F: Fn(V, V) -> bool,
128    {
129        for index in 0..(self.lane_offsets.len() - 1) {
130            let chunk = index / self.n_lanes;
131
132            let lane_start = self.lane_offsets[index] as usize;
133            let lane_end = self.lane_offsets[index + 1] as usize;
134
135            for (&patch_index, &patch_value) in std::iter::zip(
136                &self.indices[lane_start..lane_end],
137                &self.values[lane_start..lane_end],
138            ) {
139                let bit_index = chunk * 1024 + patch_index as usize;
140                // Skip any indices < the offset.
141                if bit_index < self.offset {
142                    continue;
143                }
144                let bit_index = bit_index - self.offset;
145                if bit_index >= self.bits.len() {
146                    break;
147                }
148                if cmp(patch_value, self.constant) {
149                    self.bits.set(bit_index)
150                } else {
151                    self.bits.unset(bit_index)
152                }
153            }
154        }
155
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_err;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::optimizer::ArrayOptimizer;
    use crate::patches::Patches;
    use crate::scalar_fn::fns::binary::CompareKernel;
    use crate::scalar_fn::fns::operators::CompareOperator;
    use crate::validity::Validity;

    /// Only the three tail positions patched to `u32::MAX` should compare equal to it.
    #[test]
    fn test_basic() {
        let mut ctx = crate::array_session().create_execution_ctx();

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        let tail_patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )
        .unwrap();

        let patched = Patched::from_array_and_patches(base, &tail_patches, &mut ctx)
            .unwrap()
            .into_array()
            .try_downcast::<Patched>()
            .unwrap();

        let needle = ConstantArray::new(u32::MAX, 512).into_array();

        let actual = <Patched as CompareKernel>::compare(
            patched.as_view(),
            &needle,
            CompareOperator::Eq,
            &mut ctx,
        )
        .unwrap()
        .unwrap();

        let expected =
            BoolArray::from_indices(512, [509, 510, 511], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual, &mut ctx);
    }

    /// After slicing off the first ten elements, the patch at 5 vanishes and the tail patches
    /// shift down by ten.
    #[test]
    fn test_with_offset() {
        let mut ctx = crate::array_session().create_execution_ctx();

        let base = PrimitiveArray::from_iter(0u32..512).into_array();
        let spread_patches = Patches::new(
            512,
            0,
            buffer![5u16, 510, 511].into_array(),
            buffer![u32::MAX; 3].into_array(),
            None,
        )
        .unwrap();

        let patched = Patched::from_array_and_patches(base, &spread_patches, &mut ctx).unwrap();

        // Slice the array so that the first patch should be skipped.
        let sliced = patched
            .into_array()
            .slice(10..512)
            .unwrap()
            .optimize()
            .unwrap();
        let sliced = sliced.try_downcast::<Patched>().unwrap();
        assert_eq!(sliced.len(), 502);

        let needle = ConstantArray::new(u32::MAX, sliced.len()).into_array();

        let actual = <Patched as CompareKernel>::compare(
            sliced.as_view(),
            &needle,
            CompareOperator::Eq,
            &mut ctx,
        )
        .unwrap()
        .unwrap();

        let expected = BoolArray::from_indices(502, [500, 501], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual, &mut ctx);
    }

    /// A subnormal constant must match only the patch holding that exact subnormal value,
    /// not the neighbouring NaN or negative infinity.
    #[test]
    fn test_subnormal_f32() -> VortexResult<()> {
        // Subnormal f32 values are smaller than f32::MIN_POSITIVE but greater than 0
        let tiny: f32 = f32::MIN_POSITIVE / 2.0;
        assert!(tiny > 0.0 && tiny < f32::MIN_POSITIVE);

        let base = PrimitiveArray::from_iter((0..512).map(|i| i as f32)).into_array();
        let special_patches = Patches::new(
            512,
            0,
            buffer![509u16, 510, 511].into_array(),
            buffer![f32::NAN, tiny, f32::NEG_INFINITY].into_array(),
            None,
        )?;

        let mut ctx = crate::array_session().create_execution_ctx();
        let patched = Patched::from_array_and_patches(base, &special_patches, &mut ctx)?
            .into_array()
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;

        let needle = ConstantArray::new(tiny, 512).into_array();

        let actual = <Patched as CompareKernel>::compare(
            patched.as_view(),
            &needle,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .ok_or_else(|| vortex_err!("expected compare result"))?;

        let expected = BoolArray::from_indices(512, [510], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual, &mut ctx);

        Ok(())
    }

    /// Comparing against `+0.0` must distinguish it from `-0.0`, in both unpatched and patched
    /// positions; only the patch holding `+0.0` is a match.
    #[test]
    fn test_pos_neg_zero() -> VortexResult<()> {
        let base = PrimitiveArray::from_iter([-0.0f32; 10]).into_array();
        let zero_patches = Patches::new(
            10,
            0,
            buffer![5u16, 6, 7, 8, 9].into_array(),
            buffer![f32::NAN, f32::NEG_INFINITY, 0f32, -0.0f32, f32::INFINITY].into_array(),
            None,
        )?;

        let mut ctx = crate::array_session().create_execution_ctx();
        let patched = Patched::from_array_and_patches(base, &zero_patches, &mut ctx)?
            .into_array()
            .try_downcast::<Patched>()
            .map_err(|_| vortex_err!("expected patched array"))?;

        let needle = ConstantArray::new(0.0f32, 10).into_array();

        let actual = <Patched as CompareKernel>::compare(
            patched.as_view(),
            &needle,
            CompareOperator::Eq,
            &mut ctx,
        )?
        .ok_or_else(|| vortex_err!("expected compare result"))?;

        let expected = BoolArray::from_indices(10, [7], Validity::NonNullable).into_array();
        assert_arrays_eq!(expected, actual, &mut ctx);

        Ok(())
    }
}