vortex_array/pipeline/operators/
scalar_compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::marker::PhantomData;
6use std::rc::Rc;
7
8use vortex_dtype::{NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_scalar::Scalar;
11
12use crate::compute::Operator as BinaryOperator;
13use crate::match_each_compare_op;
14use crate::pipeline::bits::BitView;
15use crate::pipeline::operators::BindContext;
16use crate::pipeline::operators::compare::CompareOp;
17use crate::pipeline::types::{Element, VType};
18use crate::pipeline::vec::VectorId;
19use crate::pipeline::view::ViewMut;
20use crate::pipeline::{Kernel, KernelContext, Operator};
21
22/// Pipeline operator for comparing an array against a scalar value.
23#[derive(Debug, Hash)]
24pub struct ScalarCompareOperator {
25    children: [Rc<dyn Operator>; 1],
26    pub op: BinaryOperator,
27    pub scalar: Scalar,
28}
29
30impl ScalarCompareOperator {
31    pub fn new(child: Rc<dyn Operator>, op: BinaryOperator, scalar: Scalar) -> Self {
32        assert_eq!(child.vtype(), VType::Primitive(scalar.dtype().as_ptype()));
33        Self {
34            children: [child],
35            op,
36            scalar,
37        }
38    }
39}
40
41impl Operator for ScalarCompareOperator {
42    fn as_any(&self) -> &dyn Any {
43        self
44    }
45
46    fn children(&self) -> &[Rc<dyn Operator>] {
47        &self.children
48    }
49
50    fn vtype(&self) -> VType {
51        VType::Bool
52    }
53
54    fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
55        match self.children[0].vtype() {
56            VType::Primitive(ptype) => {
57                match_each_native_ptype!(ptype, |T| {
58                    match_each_compare_op!(self.op, |Op| {
59                        Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
60                            lhs: ctx.children()[0],
61                            rhs: self
62                                .scalar
63                                .as_primitive()
64                                .typed_value::<T>()
65                                .vortex_expect("scalar value not of type T"),
66                            _phantom: PhantomData,
67                        }) as Box<dyn Kernel>)
68                    })
69                })
70            }
71            _ => vortex_bail!(
72                "Unsupported type for comparison: {}",
73                self.children[0].vtype()
74            ),
75        }
76    }
77
78    fn with_children(&self, mut children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
79        Rc::new(ScalarCompareOperator::new(
80            children.remove(0),
81            self.op,
82            self.scalar.clone(),
83        ))
84    }
85}
86
87struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
88    lhs: VectorId,
89    rhs: T,
90    _phantom: PhantomData<Op>,
91}
92
93impl<T: Element + NativePType, Op: CompareOp<T>> Kernel for ScalarComparePrimitiveKernel<T, Op> {
94    fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
95        Ok(())
96    }
97
98    fn step(
99        &mut self,
100        ctx: &KernelContext,
101        selected: BitView,
102        out: &mut ViewMut,
103    ) -> VortexResult<()> {
104        let lhs_vec = ctx.vector(self.lhs);
105        let lhs = lhs_vec.as_slice::<T>();
106
107        let bools = out.as_slice_mut::<bool>();
108
109        debug_assert_eq!(selected.true_count(), lhs.len());
110        lhs.iter().zip(bools).for_each(|(lhs, bool)| {
111            *bool = Op::compare(lhs, &self.rhs);
112        });
113
114        Ok(())
115    }
116}
117
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use vortex_buffer::BufferMut;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::pipeline::bits::BitView;
    use crate::pipeline::query::QueryPlan;
    use crate::pipeline::view::ViewMut;
    use crate::pipeline::{N, N_WORDS};

    /// `value > 10` over the i32 values 0..16, stacked on a primitive operator.
    #[test]
    fn test_scalar_compare_stacked_on_primitive() {
        // Input values: 0, 1, ..., 15.
        let size = 16;
        let input = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
        let source_op = input.as_ref().to_operator().unwrap().unwrap();

        // Comparison under test: value > 10.
        let threshold = Scalar::primitive(10i32, Nullability::NonNullable);
        let compare_op = Rc::new(ScalarCompareOperator::new(
            source_op,
            BinaryOperator::Gt,
            threshold,
        ));

        // Plan and build the executable pipeline for the stacked operators.
        let plan = QueryPlan::new(compare_op.as_ref()).unwrap();
        let mut pipeline = plan.executable_plan().unwrap();

        // Selection mask with every bit set.
        let mask_words = [usize::MAX; N_WORDS];
        let mask = BitView::new(&mask_words);

        // Output buffer for the boolean results.
        let mut output = BufferMut::<bool>::with_capacity(N);
        // SAFETY: capacity for `N` elements was reserved on the line above.
        // NOTE(review): the elements are uninitialized until the step writes
        // them; only the first `size` are read back below — confirm the
        // pipeline writes those.
        unsafe { output.set_len(N) };
        let mut output_view = ViewMut::new(&mut output[..], None);

        // Run a single step of the pipeline.
        let result = pipeline._step(mask, &mut output_view);
        assert!(result.is_ok());

        // 0..=10 must compare false, 11..=15 must compare true.
        for (i, &actual) in output[..size].iter().enumerate() {
            let expected = i > 10;
            assert_eq!(
                actual, expected,
                "Position {}: expected {}, got {}",
                i, expected, actual
            );
        }
    }

    /// `value == 3` over the i32 values 0..8.
    #[test]
    fn test_scalar_compare_different_operators() {
        let size = 8;
        let input = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();

        let source_op = input.as_ref().to_operator().unwrap().unwrap();

        // Comparison under test: value == 3.
        let needle = Scalar::primitive(3i32, Nullability::NonNullable);
        let eq_op = Rc::new(ScalarCompareOperator::new(
            source_op,
            BinaryOperator::Eq,
            needle,
        ));

        let plan = QueryPlan::new(eq_op.as_ref()).unwrap();
        let mut pipeline = plan.executable_plan().unwrap();

        // Selection mask with every bit set.
        let mask_words = [usize::MAX; N_WORDS];
        let mask = BitView::new(&mask_words);

        let mut output = BufferMut::<bool>::with_capacity(N);
        // SAFETY: capacity for `N` elements was reserved on the line above.
        // NOTE(review): the elements are uninitialized until the step writes
        // them; only the first `size` are read back below — confirm the
        // pipeline writes those.
        unsafe { output.set_len(N) };
        let mut output_view = ViewMut::new(&mut output[..], None);

        let result = pipeline._step(mask, &mut output_view);
        assert!(result.is_ok());

        // Exactly one position (index 3) is expected to be true.
        for (i, &actual) in output[..size].iter().enumerate() {
            let expected = i == 3;
            assert_eq!(
                actual, expected,
                "Eq test - Position {}: expected {}, got {}",
                i, expected, actual
            );
        }
    }

    /// `value < 3.5` over the f32 values 0.5, 1.5, ..., 7.5.
    #[test]
    fn test_scalar_compare_with_f32() {
        let size = 8;
        let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
        let input = values.into_iter().collect::<PrimitiveArray>();

        let source_op = input.as_ref().to_operator().unwrap().unwrap();

        // Comparison under test: value < 3.5.
        let bound = Scalar::primitive(3.5f32, Nullability::NonNullable);
        let lt_op = Rc::new(ScalarCompareOperator::new(
            source_op,
            BinaryOperator::Lt,
            bound,
        ));

        let plan = QueryPlan::new(lt_op.as_ref()).unwrap();
        let mut pipeline = plan.executable_plan().unwrap();

        // Selection mask with every bit set.
        let mask_words = [usize::MAX; N_WORDS];
        let mask = BitView::new(&mask_words);

        let mut output = BufferMut::<bool>::with_capacity(N);
        // SAFETY: capacity for `N` elements was reserved on the line above.
        // NOTE(review): the elements are uninitialized until the step writes
        // them; only the first `size` are read back below — confirm the
        // pipeline writes those.
        unsafe { output.set_len(N) };
        let mut output_view = ViewMut::new(&mut output[..], None);

        let result = pipeline._step(mask, &mut output_view);
        assert!(result.is_ok());

        // 0.5, 1.5 and 2.5 are below 3.5 (true); 3.5 and above are false.
        for (i, &actual) in output[..size].iter().enumerate() {
            let value = i as f32 + 0.5;
            let expected = value < 3.5;
            assert_eq!(
                actual, expected,
                "Lt test - Position {}: value {} should be {}, got {}",
                i, value, expected, actual
            );
        }
    }
}