vortex_expr/
binary.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::compute::{Operator as ArrayOperator, and_kleene, compare, or_kleene};
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{ExprRef, Operator, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct BinaryExpr {
16    lhs: ExprRef,
17    operator: Operator,
18    rhs: ExprRef,
19}
20
21impl BinaryExpr {
22    pub fn new_expr(lhs: ExprRef, operator: Operator, rhs: ExprRef) -> ExprRef {
23        Arc::new(Self { lhs, operator, rhs })
24    }
25
26    pub fn lhs(&self) -> &ExprRef {
27        &self.lhs
28    }
29
30    pub fn rhs(&self) -> &ExprRef {
31        &self.rhs
32    }
33
34    pub fn op(&self) -> Operator {
35        self.operator
36    }
37}
38
39impl Display for BinaryExpr {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "({} {} {})", self.lhs, self.operator, self.rhs)
42    }
43}
44
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{BinaryExpr, ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serialization entry point for [`BinaryExpr`], identified by the id `"binary"`.
    pub(crate) struct BinarySerde;

    impl Id for BinarySerde {
        fn id(&self) -> &'static str {
            "binary"
        }
    }

    impl ExprDeserialize for BinarySerde {
        /// Rebuilds a [`BinaryExpr`] from a `Kind::BinaryOp` and its child expressions.
        ///
        /// `children` must be exactly `[lhs, rhs]`.
        ///
        /// # Errors
        ///
        /// Returns an error if `kind` is not `Kind::BinaryOp`, if `children` does not hold
        /// exactly two expressions, or if the serialized operator cannot be converted into an
        /// [`Operator`](crate::Operator).
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::BinaryOp(op) = kind else {
                vortex_bail!("wrong kind {:?}, binary", kind)
            };

            // Serialized input may be malformed: report a wrong child count as an error rather
            // than panicking on an out-of-bounds index. Destructuring also moves the children
            // out of the vector instead of cloning them.
            let [lhs, rhs] = match <[ExprRef; 2]>::try_from(children) {
                Ok(pair) => pair,
                Err(children) => {
                    vortex_bail!(
                        "binary expression expects exactly 2 children, got {}",
                        children.len()
                    )
                }
            };

            Ok(BinaryExpr::new_expr(lhs, (*op).try_into()?, rhs))
        }
    }

    impl ExprSerializable for BinaryExpr {
        fn id(&self) -> &'static str {
            BinarySerde.id()
        }

        /// Encodes only the operator; the operands are serialized separately as children.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::BinaryOp(self.operator.into()))
        }
    }
}
84
85impl VortexExpr for BinaryExpr {
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89
90    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
91        let lhs = self.lhs.evaluate(batch)?;
92        let rhs = self.rhs.evaluate(batch)?;
93
94        match self.operator {
95            Operator::Eq => compare(&lhs, &rhs, ArrayOperator::Eq),
96            Operator::NotEq => compare(&lhs, &rhs, ArrayOperator::NotEq),
97            Operator::Lt => compare(&lhs, &rhs, ArrayOperator::Lt),
98            Operator::Lte => compare(&lhs, &rhs, ArrayOperator::Lte),
99            Operator::Gt => compare(&lhs, &rhs, ArrayOperator::Gt),
100            Operator::Gte => compare(&lhs, &rhs, ArrayOperator::Gte),
101            Operator::And => and_kleene(&lhs, &rhs),
102            Operator::Or => or_kleene(&lhs, &rhs),
103        }
104    }
105
106    fn children(&self) -> Vec<&ExprRef> {
107        vec![&self.lhs, &self.rhs]
108    }
109
110    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
111        assert_eq!(children.len(), 2);
112        BinaryExpr::new_expr(children[0].clone(), self.operator, children[1].clone())
113    }
114
115    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
116        let lhs = self.lhs.return_dtype(scope_dtype)?;
117        let rhs = self.rhs.return_dtype(scope_dtype)?;
118        Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
119    }
120}
121
122impl PartialEq for BinaryExpr {
123    fn eq(&self, other: &BinaryExpr) -> bool {
124        other.operator == self.operator && other.lhs.eq(&self.lhs) && other.rhs.eq(&self.rhs)
125    }
126}
127
128/// Create a new `BinaryExpr` using the `Eq` operator.
129///
130/// ## Example usage
131///
132/// ```
133/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
134/// use vortex_array::{Array, IntoArray, ToCanonical};
135/// use vortex_array::validity::Validity;
136/// use vortex_buffer::buffer;
137/// use vortex_expr::{eq, ident, lit};
138///
139/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
140/// let result = eq(ident(), lit(3)).evaluate(&xs).unwrap();
141///
142/// assert_eq!(
143///     result.to_bool().unwrap().boolean_buffer(),
144///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
145/// );
146/// ```
147pub fn eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
148    BinaryExpr::new_expr(lhs, Operator::Eq, rhs)
149}
150
151/// Create a new `BinaryExpr` using the `NotEq` operator.
152///
153/// ## Example usage
154///
155/// ```
156/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
157/// use vortex_array::{IntoArray, ToCanonical};
158/// use vortex_array::validity::Validity;
159/// use vortex_buffer::buffer;
160/// use vortex_expr::{ident, lit, not_eq};
161///
162/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
163/// let result = not_eq(ident(), lit(3)).evaluate(&xs).unwrap();
164///
165/// assert_eq!(
166///     result.to_bool().unwrap().boolean_buffer(),
167///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
168/// );
169/// ```
170pub fn not_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
171    BinaryExpr::new_expr(lhs, Operator::NotEq, rhs)
172}
173
174/// Create a new `BinaryExpr` using the `Gte` operator.
175///
176/// ## Example usage
177///
178/// ```
179/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
180/// use vortex_array::{IntoArray, ToCanonical};
181/// use vortex_array::validity::Validity;
182/// use vortex_buffer::buffer;
183/// use vortex_expr::{gt_eq, ident, lit};
184///
185/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
186/// let result = gt_eq(ident(), lit(3)).evaluate(&xs).unwrap();
187///
188/// assert_eq!(
189///     result.to_bool().unwrap().boolean_buffer(),
190///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
191/// );
192/// ```
193pub fn gt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
194    BinaryExpr::new_expr(lhs, Operator::Gte, rhs)
195}
196
197/// Create a new `BinaryExpr` using the `Gt` operator.
198///
199/// ## Example usage
200///
201/// ```
202/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
203/// use vortex_array::{IntoArray, ToCanonical};
204/// use vortex_array::validity::Validity;
205/// use vortex_buffer::buffer;
206/// use vortex_expr::{gt, ident, lit};
207///
208/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
209/// let result = gt(ident(), lit(2)).evaluate(&xs).unwrap();
210///
211/// assert_eq!(
212///     result.to_bool().unwrap().boolean_buffer(),
213///     BoolArray::from_iter(vec![false, false, true]).boolean_buffer(),
214/// );
215/// ```
216pub fn gt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
217    BinaryExpr::new_expr(lhs, Operator::Gt, rhs)
218}
219
220/// Create a new `BinaryExpr` using the `Lte` operator.
221///
222/// ## Example usage
223///
224/// ```
225/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
226/// use vortex_array::{IntoArray, ToCanonical};
227/// use vortex_array::validity::Validity;
228/// use vortex_buffer::buffer;
229/// use vortex_expr::{ident, lit, lt_eq};
230///
231/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
232/// let result = lt_eq(ident(), lit(2)).evaluate(&xs).unwrap();
233///
234/// assert_eq!(
235///     result.to_bool().unwrap().boolean_buffer(),
236///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
237/// );
238/// ```
239pub fn lt_eq(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
240    BinaryExpr::new_expr(lhs, Operator::Lte, rhs)
241}
242
243/// Create a new `BinaryExpr` using the `Lt` operator.
244///
245/// ## Example usage
246///
247/// ```
248/// use vortex_array::arrays::{BoolArray, PrimitiveArray };
249/// use vortex_array::{IntoArray, ToCanonical};
250/// use vortex_array::validity::Validity;
251/// use vortex_buffer::buffer;
252/// use vortex_expr::{ident, lit, lt};
253///
254/// let xs = PrimitiveArray::new(buffer![1i32, 2i32, 3i32], Validity::NonNullable);
255/// let result = lt(ident(), lit(3)).evaluate(&xs).unwrap();
256///
257/// assert_eq!(
258///     result.to_bool().unwrap().boolean_buffer(),
259///     BoolArray::from_iter(vec![true, true, false]).boolean_buffer(),
260/// );
261/// ```
262pub fn lt(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
263    BinaryExpr::new_expr(lhs, Operator::Lt, rhs)
264}
265
266/// Create a new `BinaryExpr` using the `Or` operator.
267///
268/// ## Example usage
269///
270/// ```
271/// use vortex_array::arrays::BoolArray;
272/// use vortex_array::{IntoArray, ToCanonical};
273/// use vortex_expr::{ ident, lit, or};
274///
275/// let xs = BoolArray::from_iter(vec![true, false, true]);
276/// let result = or(ident(), lit(false)).evaluate(&xs).unwrap();
277///
278/// assert_eq!(
279///     result.to_bool().unwrap().boolean_buffer(),
280///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
281/// );
282/// ```
283pub fn or(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
284    BinaryExpr::new_expr(lhs, Operator::Or, rhs)
285}
286
287/// Create a new `BinaryExpr` using the `And` operator.
288///
289/// ## Example usage
290///
291/// ```
292/// use vortex_array::arrays::BoolArray;
293/// use vortex_array::{IntoArray, ToCanonical};
294/// use vortex_expr::{and, ident, lit};
295///
296/// let xs = BoolArray::from_iter(vec![true, false, true]);
297/// let result = and(ident(), lit(true)).evaluate(&xs).unwrap();
298///
299/// assert_eq!(
300///     result.to_bool().unwrap().boolean_buffer(),
301///     BoolArray::from_iter(vec![true, false, true]).boolean_buffer(),
302/// );
303/// ```
304pub fn and(lhs: ExprRef, rhs: ExprRef) -> ExprRef {
305    BinaryExpr::new_expr(lhs, Operator::And, rhs)
306}
307
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability};

    use crate::{VortexExpr, and, col, eq, gt, gt_eq, lt, lt_eq, not_eq, or, test_harness};

    /// Signature shared by every binary-expression constructor exercised below.
    type BinaryCtor = fn(Arc<dyn VortexExpr>, Arc<dyn VortexExpr>) -> Arc<dyn VortexExpr>;

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // The harness's `bool1` / `bool2` columns are expected to be non-nullable, so the
        // boolean connectives over them must yield a non-nullable bool.
        let bool1: Arc<dyn VortexExpr> = col("bool1");
        let bool2: Arc<dyn VortexExpr> = col("bool2");
        let connectives: [BinaryCtor; 2] = [and, or];
        for build in connectives {
            let expr = build(bool1.clone(), bool2.clone());
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                DType::Bool(Nullability::NonNullable)
            );
        }

        // At least one of `col1` / `col2` is expected to be nullable, so every comparison
        // between them yields a nullable bool.
        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        let comparisons: [BinaryCtor; 6] = [eq, not_eq, gt, gt_eq, lt, lt_eq];
        for build in comparisons {
            let expr = build(col1.clone(), col2.clone());
            assert_eq!(
                expr.return_dtype(&dtype).unwrap(),
                DType::Bool(Nullability::Nullable)
            );
        }

        // Nullability propagates through nested binary expressions.
        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_eq!(
            nested.return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}