vortex_array/pipeline/operators/
compare.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::marker::PhantomData;
6use std::rc::Rc;
7
8use itertools::Itertools;
9use vortex_dtype::{NativePType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11
12use crate::arrays::ConstantOperator;
13use crate::compute::Operator as BinaryOperator;
14use crate::pipeline::bits::BitView;
15use crate::pipeline::operators::scalar_compare::ScalarCompareOperator;
16use crate::pipeline::operators::{BindContext, Operator};
17use crate::pipeline::types::{Element, VType};
18use crate::pipeline::vec::VectorId;
19use crate::pipeline::view::ViewMut;
20use crate::pipeline::{Kernel, KernelContext};
21
/// Dispatches a runtime comparison operator to a compile-time marker type.
///
/// Evaluates `$self` (a `crate::compute::Operator`) and expands `$body` once per variant, with
/// the type alias `$enc` bound to the matching marker type from
/// `crate::pipeline::operators::compare` (`Eq`, `NotEq`, `Gt`, `Gte`, `Lt`, `Lte`). This lets
/// `$body` be monomorphized per operator.
///
/// Variant paths are written via `$crate` so the macro works at any call site, without the
/// caller having to import `Operator` under a particular alias.
#[macro_export]
macro_rules! match_each_compare_op {
    ($self:expr, | $enc:ident | $body:block) => {{
        match $self {
            $crate::compute::Operator::Eq => {
                type $enc = $crate::pipeline::operators::compare::Eq;
                $body
            }
            $crate::compute::Operator::NotEq => {
                type $enc = $crate::pipeline::operators::compare::NotEq;
                $body
            }
            $crate::compute::Operator::Gt => {
                type $enc = $crate::pipeline::operators::compare::Gt;
                $body
            }
            $crate::compute::Operator::Gte => {
                type $enc = $crate::pipeline::operators::compare::Gte;
                $body
            }
            $crate::compute::Operator::Lt => {
                type $enc = $crate::pipeline::operators::compare::Lt;
                $body
            }
            $crate::compute::Operator::Lte => {
                type $enc = $crate::pipeline::operators::compare::Lte;
                $body
            }
        }
    }};
}
53
54/// Pipeline operator for comparing two arrays using various comparison operations.
55#[derive(Debug, Hash)]
56pub struct CompareOperator {
57    children: [Rc<dyn Operator>; 2],
58    op: BinaryOperator,
59}
60
61impl CompareOperator {
62    pub fn new(lhs: Rc<dyn Operator>, rhs: Rc<dyn Operator>, op: BinaryOperator) -> Self {
63        assert_eq!(lhs.vtype(), rhs.vtype(), "Operands must have the same type");
64        Self {
65            children: [lhs, rhs],
66            op,
67        }
68    }
69}
70
71impl Operator for CompareOperator {
72    fn as_any(&self) -> &dyn Any {
73        self
74    }
75
76    fn vtype(&self) -> VType {
77        VType::Bool
78    }
79
80    fn children(&self) -> &[Rc<dyn Operator>] {
81        &self.children
82    }
83
84    fn with_children(&self, children: Vec<Rc<dyn Operator>>) -> Rc<dyn Operator> {
85        let [lhs, rhs] = children
86            .try_into()
87            .ok()
88            .vortex_expect("Expected 2 children");
89        Rc::new(CompareOperator::new(lhs, rhs, self.op))
90    }
91
92    fn bind(&self, ctx: &dyn BindContext) -> VortexResult<Box<dyn Kernel>> {
93        debug_assert_eq!(self.children[0].vtype(), self.children[1].vtype());
94
95        let VType::Primitive(ptype) = self.children[0].vtype() else {
96            vortex_bail!(
97                "Unsupported type for comparison: {}",
98                self.children[0].vtype()
99            )
100        };
101
102        match_each_native_ptype!(ptype, |T| {
103            match_each_compare_op!(self.op, |Op| {
104                Ok(Box::new(ComparePrimitiveKernel::<T, Op> {
105                    lhs: ctx.children()[0],
106                    rhs: ctx.children()[1],
107                    _phantom: PhantomData,
108                }) as Box<dyn Kernel>)
109            })
110        })
111    }
112
113    fn reduce_children(&self, children: &[Rc<dyn Operator>]) -> Option<Rc<dyn Operator>> {
114        let constants = children
115            .iter()
116            .enumerate()
117            .filter_map(|(idx, c)| {
118                c.as_any()
119                    .downcast_ref::<ConstantOperator>()
120                    .map(|c| (idx, c))
121            })
122            .collect_vec();
123
124        if constants.len() != 1 {
125            return None;
126        }
127        let [(idx, lhs)] = constants
128            .try_into()
129            .ok()
130            .vortex_expect("Expected 1 constant");
131
132        if idx == 0 {
133            Some(Rc::new(ScalarCompareOperator::new(
134                children[1].clone(),
135                self.op.inverse(),
136                lhs.scalar.clone(),
137            )))
138        } else {
139            Some(Rc::new(ScalarCompareOperator::new(
140                children[0].clone(),
141                self.op,
142                lhs.scalar.clone(),
143            )))
144        }
145    }
146}
147
148/// A compare operator for primitive types that compares two vectors element-wise using a binary
149/// operation.
150/// Kernel that performs primitive type comparisons between two input vectors.
151pub struct ComparePrimitiveKernel<T, Op> {
152    lhs: VectorId,
153    rhs: VectorId,
154    _phantom: PhantomData<(T, Op)>,
155}
156
157impl<T: Element + NativePType, Op: CompareOp<T>> Kernel for ComparePrimitiveKernel<T, Op> {
158    fn step(
159        &mut self,
160        ctx: &KernelContext,
161        selected: BitView,
162        out: &mut ViewMut,
163    ) -> VortexResult<()> {
164        let lhs_vec = ctx.vector(self.lhs);
165        let lhs = lhs_vec.as_slice::<T>();
166        let rhs_vec = ctx.vector(self.rhs);
167        let rhs = rhs_vec.as_slice::<T>();
168        let bools = out.as_slice_mut::<bool>();
169
170        assert_eq!(
171            lhs.len(),
172            rhs.len(),
173            "LHS and RHS must have the same length"
174        );
175
176        lhs.iter()
177            .zip(rhs.iter())
178            .zip(bools)
179            .for_each(|((lhs, rhs), bool)| *bool = Op::compare(lhs, rhs));
180
181        Ok(())
182    }
183}
184
/// An element-wise comparison between two values of type `T`.
///
/// Implemented by the zero-sized marker types below, so generic code parameterized over a
/// marker resolves the comparison at compile time.
pub(crate) trait CompareOp<T> {
    /// Returns the outcome of comparing `left` against `right`.
    fn compare(left: &T, right: &T) -> bool;
}

/// Marker for equality (`left == right`).
pub struct Eq;
impl<T: PartialEq> CompareOp<T> for Eq {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialEq::eq(left, right)
    }
}

/// Marker for inequality (`left != right`).
pub struct NotEq;
impl<T: PartialEq> CompareOp<T> for NotEq {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialEq::ne(left, right)
    }
}

/// Marker for strictly-greater-than (`left > right`).
pub struct Gt;
impl<T: PartialOrd> CompareOp<T> for Gt {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialOrd::gt(left, right)
    }
}

/// Marker for greater-than-or-equal (`left >= right`).
pub struct Gte;
impl<T: PartialOrd> CompareOp<T> for Gte {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialOrd::ge(left, right)
    }
}

/// Marker for strictly-less-than (`left < right`).
pub struct Lt;
impl<T: PartialOrd> CompareOp<T> for Lt {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialOrd::lt(left, right)
    }
}

/// Marker for less-than-or-equal (`left <= right`).
pub struct Lte;
impl<T: PartialOrd> CompareOp<T> for Lte {
    #[inline(always)]
    fn compare(left: &T, right: &T) -> bool {
        PartialOrd::le(left, right)
    }
}