vortex_array/pipeline/operators/
scalar_compare.rs1use std::any::Any;
5use std::marker::PhantomData;
6use std::rc::Rc;
7
8use vortex_dtype::{NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail};
10use vortex_scalar::Scalar;
11
12use crate::compute::Operator as BinaryOperator;
13use crate::match_each_compare_op;
14use crate::pipeline::bits::BitView;
15use crate::pipeline::operators::BindContext;
16use crate::pipeline::operators::compare::CompareOp;
17use crate::pipeline::types::{Element, VType};
18use crate::pipeline::vec::VectorId;
19use crate::pipeline::view::ViewMut;
20use crate::pipeline::{Kernel, KernelContext, Operator};
21
22#[derive(Debug, Hash)]
24pub struct ScalarCompareOperator {
25 children: [Rc<dyn Operator>; 1],
26 pub op: BinaryOperator,
27 pub scalar: Scalar,
28}
29
30impl ScalarCompareOperator {
31 pub fn new(child: Rc<dyn Operator>, op: BinaryOperator, scalar: Scalar) -> Self {
32 assert_eq!(child.vtype(), VType::Primitive(scalar.dtype().as_ptype()));
33 Self {
34 children: [child],
35 op,
36 scalar,
37 }
38 }
39}
40
41impl Operator for ScalarCompareOperator {
42 fn as_any(&self) -> &dyn Any {
43 self
44 }
45
46 fn children(&self) -> &[Rc<dyn Operator>] {
47 &self.children
48 }
49
50 fn vtype(&self) -> VType {
51 VType::Bool
52 }
53
54 fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
55 match self.children[0].vtype() {
56 VType::Primitive(ptype) => {
57 match_each_native_ptype!(ptype, |T| {
58 match_each_compare_op!(self.op, |Op| {
59 Ok(Box::new(ScalarComparePrimitiveKernel::<T, Op> {
60 lhs: ctx.children()[0],
61 rhs: self
62 .scalar
63 .as_primitive()
64 .typed_value::<T>()
65 .vortex_expect("scalar value not of type T"),
66 _phantom: PhantomData,
67 }) as Box<dyn Kernel>)
68 })
69 })
70 }
71 _ => vortex_bail!(
72 "Unsupported type for comparison: {}",
73 self.children[0].vtype()
74 ),
75 }
76 }
77
78 fn with_children(&self, mut children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
79 Rc::new(ScalarCompareOperator::new(
80 children.remove(0),
81 self.op,
82 self.scalar.clone(),
83 ))
84 }
85}
86
87struct ScalarComparePrimitiveKernel<T: Element + NativePType, Op: CompareOp<T>> {
88 lhs: VectorId,
89 rhs: T,
90 _phantom: PhantomData<Op>,
91}
92
93impl<T: Element + NativePType, Op: CompareOp<T>> Kernel for ScalarComparePrimitiveKernel<T, Op> {
94 fn seek(&mut self, chunk_idx: usize) -> VortexResult<()> {
95 Ok(())
96 }
97
98 fn step(
99 &mut self,
100 ctx: &KernelContext,
101 selected: BitView,
102 out: &mut ViewMut,
103 ) -> VortexResult<()> {
104 let lhs_vec = ctx.vector(self.lhs);
105 let lhs = lhs_vec.as_slice::<T>();
106
107 let bools = out.as_slice_mut::<bool>();
108
109 debug_assert_eq!(selected.true_count(), lhs.len());
110 lhs.iter().zip(bools).for_each(|(lhs, bool)| {
111 *bool = Op::compare(lhs, &self.rhs);
112 });
113
114 Ok(())
115 }
116}
117
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use vortex_buffer::BufferMut;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::arrays::PrimitiveArray;
    use crate::pipeline::bits::BitView;
    use crate::pipeline::query::QueryPlan;
    use crate::pipeline::view::ViewMut;
    use crate::pipeline::{N, N_WORDS};

    /// Builds `array <op> scalar` as a pipeline, runs a single step with every position
    /// selected, and returns the full `N`-slot boolean output buffer.
    fn run_scalar_compare(
        array: PrimitiveArray,
        op: BinaryOperator,
        scalar: Scalar,
    ) -> BufferMut<bool> {
        let child = array.as_ref().to_operator().unwrap().unwrap();
        let compare_op = Rc::new(ScalarCompareOperator::new(child, op, scalar));

        let plan = QueryPlan::new(compare_op.as_ref()).unwrap();
        let mut pipeline = plan.executable_plan().unwrap();

        let mask_data = [usize::MAX; N_WORDS];
        let mask_view = BitView::new(&mask_data);

        // Initialise every slot: `set_len` over uninitialised memory followed by exposing it
        // as `&mut [bool]` is undefined behaviour (a `bool` must be 0 or 1).
        let mut output = BufferMut::<bool>::full(false, N);
        {
            let mut output_view = ViewMut::new(&mut output[..], None);
            let result = pipeline._step(mask_view, &mut output_view);
            assert!(result.is_ok());
        }
        output
    }

    #[test]
    fn test_scalar_compare_stacked_on_primitive() {
        let size = 16;
        let array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
        let compare_value = Scalar::primitive(10i32, Nullability::NonNullable);

        let output = run_scalar_compare(array, BinaryOperator::Gt, compare_value);

        for i in 0..size {
            let expected = i > 10;
            assert_eq!(
                output[i], expected,
                "Position {}: expected {}, got {}",
                i, expected, output[i]
            );
        }
    }

    #[test]
    fn test_scalar_compare_different_operators() {
        let size = 8;
        let array = (0..i32::try_from(size).unwrap()).collect::<PrimitiveArray>();
        let compare_value = Scalar::primitive(3i32, Nullability::NonNullable);

        let output = run_scalar_compare(array, BinaryOperator::Eq, compare_value);

        for i in 0..size {
            let expected = i == 3;
            assert_eq!(
                output[i], expected,
                "Eq test - Position {}: expected {}, got {}",
                i, expected, output[i]
            );
        }
    }

    #[test]
    fn test_scalar_compare_with_f32() {
        let size = 8;
        let values: Vec<f32> = (0..size).map(|i| i as f32 + 0.5).collect();
        let array = values.into_iter().collect::<PrimitiveArray>();
        let compare_value = Scalar::primitive(3.5f32, Nullability::NonNullable);

        let output = run_scalar_compare(array, BinaryOperator::Lt, compare_value);

        for i in 0..size {
            let value = i as f32 + 0.5;
            let expected = value < 3.5;
            assert_eq!(
                output[i], expected,
                "Lt test - Position {}: value {} should be {}, got {}",
                i, value, expected, output[i]
            );
        }
    }
}