1use std::fmt::Formatter;
5
6#[expect(deprecated)]
7pub use boolean::and_kleene;
8#[expect(deprecated)]
9pub use boolean::or_kleene;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15
16use crate::ArrayRef;
17use crate::dtype::DType;
18use crate::expr::StatsCatalog;
19use crate::expr::and;
20use crate::expr::and_collect;
21use crate::expr::eq;
22use crate::expr::expression::Expression;
23use crate::expr::gt;
24use crate::expr::gt_eq;
25use crate::expr::lit;
26use crate::expr::lt;
27use crate::expr::lt_eq;
28use crate::expr::or_collect;
29use crate::expr::stats::Stat;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::fns::operators::CompareOperator;
36use crate::scalar_fn::fns::operators::Operator;
37
38pub(crate) mod boolean;
39pub(crate) use boolean::*;
40mod compare;
41pub use compare::*;
42mod numeric;
43pub(crate) use numeric::*;
44
/// Scalar function for two-operand expressions.
///
/// Its [`ScalarFnVTable`] implementation registers it as `vortex.binary` and
/// carries an [`Operator`] as its options, which selects a comparison
/// (`Eq`, `NotEq`, `Lt`, `Lte`, `Gt`, `Gte`), a boolean connective
/// (`And`, `Or`) or an arithmetic operation (`Add`, `Sub`, `Mul`, `Div`).
#[derive(Clone)]
pub struct Binary;
47
48impl ScalarFnVTable for Binary {
49 type Options = Operator;
50
51 fn id(&self) -> ScalarFnId {
52 ScalarFnId::from("vortex.binary")
53 }
54
55 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56 Ok(Some(
57 pb::BinaryOpts {
58 op: (*instance).into(),
59 }
60 .encode_to_vec(),
61 ))
62 }
63
64 fn deserialize(
65 &self,
66 _metadata: &[u8],
67 _session: &VortexSession,
68 ) -> VortexResult<Self::Options> {
69 let opts = pb::BinaryOpts::decode(_metadata)?;
70 Operator::try_from(opts.op)
71 }
72
73 fn arity(&self, _options: &Self::Options) -> Arity {
74 Arity::Exact(2)
75 }
76
77 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
78 match child_idx {
79 0 => ChildName::from("lhs"),
80 1 => ChildName::from("rhs"),
81 _ => unreachable!("Binary has only two children"),
82 }
83 }
84
85 fn fmt_sql(
86 &self,
87 operator: &Operator,
88 expr: &Expression,
89 f: &mut Formatter<'_>,
90 ) -> std::fmt::Result {
91 write!(f, "(")?;
92 expr.child(0).fmt_sql(f)?;
93 write!(f, " {} ", operator)?;
94 expr.child(1).fmt_sql(f)?;
95 write!(f, ")")
96 }
97
98 fn return_dtype(&self, operator: &Operator, arg_dtypes: &[DType]) -> VortexResult<DType> {
99 let lhs = &arg_dtypes[0];
100 let rhs = &arg_dtypes[1];
101
102 if operator.is_arithmetic() {
103 if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
104 return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
105 }
106 vortex_bail!(
107 "incompatible types for arithmetic operation: {} {}",
108 lhs,
109 rhs
110 );
111 }
112
113 if operator.is_comparison()
114 && !lhs.eq_ignore_nullability(rhs)
115 && !lhs.is_extension()
116 && !rhs.is_extension()
117 {
118 vortex_bail!("Cannot compare different DTypes {} and {}", lhs, rhs);
119 }
120
121 Ok(DType::Bool((lhs.is_nullable() || rhs.is_nullable()).into()))
122 }
123
124 fn execute(&self, op: &Operator, args: ExecutionArgs) -> VortexResult<ArrayRef> {
125 let [lhs, rhs] = &args.inputs[..] else {
126 vortex_bail!("Wrong arg count")
127 };
128
129 match op {
130 Operator::Eq => execute_compare(lhs, rhs, CompareOperator::Eq),
131 Operator::NotEq => execute_compare(lhs, rhs, CompareOperator::NotEq),
132 Operator::Lt => execute_compare(lhs, rhs, CompareOperator::Lt),
133 Operator::Lte => execute_compare(lhs, rhs, CompareOperator::Lte),
134 Operator::Gt => execute_compare(lhs, rhs, CompareOperator::Gt),
135 Operator::Gte => execute_compare(lhs, rhs, CompareOperator::Gte),
136 Operator::And => execute_boolean(lhs, rhs, Operator::And),
137 Operator::Or => execute_boolean(lhs, rhs, Operator::Or),
138 Operator::Add => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Add),
139 Operator::Sub => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Sub),
140 Operator::Mul => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Mul),
141 Operator::Div => execute_numeric(lhs, rhs, crate::scalar::NumericOperator::Div),
142 }
143 }
144
145 fn stat_falsification(
146 &self,
147 operator: &Operator,
148 expr: &Expression,
149 catalog: &dyn StatsCatalog,
150 ) -> Option<Expression> {
151 #[inline]
165 fn with_nan_predicate(
166 lhs: &Expression,
167 rhs: &Expression,
168 value_predicate: Expression,
169 catalog: &dyn StatsCatalog,
170 ) -> Expression {
171 let nan_predicate = and_collect(
172 lhs.stat_expression(Stat::NaNCount, catalog)
173 .into_iter()
174 .chain(rhs.stat_expression(Stat::NaNCount, catalog))
175 .map(|nans| eq(nans, lit(0u64))),
176 );
177
178 if let Some(nan_check) = nan_predicate {
179 and(nan_check, value_predicate)
180 } else {
181 value_predicate
182 }
183 }
184
185 let lhs = expr.child(0);
186 let rhs = expr.child(1);
187 match operator {
188 Operator::Eq => {
189 let min_lhs = lhs.stat_min(catalog);
190 let max_lhs = lhs.stat_max(catalog);
191
192 let min_rhs = rhs.stat_min(catalog);
193 let max_rhs = rhs.stat_max(catalog);
194
195 let left = min_lhs.zip(max_rhs).map(|(a, b)| gt(a, b));
196 let right = min_rhs.zip(max_lhs).map(|(a, b)| gt(a, b));
197
198 let min_max_check = or_collect(left.into_iter().chain(right))?;
199
200 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
202 }
203 Operator::NotEq => {
204 let min_lhs = lhs.stat_min(catalog)?;
205 let max_lhs = lhs.stat_max(catalog)?;
206
207 let min_rhs = rhs.stat_min(catalog)?;
208 let max_rhs = rhs.stat_max(catalog)?;
209
210 let min_max_check = and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs));
211
212 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
213 }
214 Operator::Gt => {
215 let min_max_check = lt_eq(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
216
217 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
218 }
219 Operator::Gte => {
220 let min_max_check = lt(lhs.stat_max(catalog)?, rhs.stat_min(catalog)?);
222
223 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
224 }
225 Operator::Lt => {
226 let min_max_check = gt_eq(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
228
229 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
230 }
231 Operator::Lte => {
232 let min_max_check = gt(lhs.stat_min(catalog)?, rhs.stat_max(catalog)?);
234
235 Some(with_nan_predicate(lhs, rhs, min_max_check, catalog))
236 }
237 Operator::And => or_collect(
238 lhs.stat_falsification(catalog)
239 .into_iter()
240 .chain(rhs.stat_falsification(catalog)),
241 ),
242 Operator::Or => Some(and(
243 lhs.stat_falsification(catalog)?,
244 rhs.stat_falsification(catalog)?,
245 )),
246 Operator::Add | Operator::Sub | Operator::Mul | Operator::Div => None,
247 }
248 }
249
250 fn validity(
251 &self,
252 operator: &Operator,
253 expression: &Expression,
254 ) -> VortexResult<Option<Expression>> {
255 let lhs = expression.child(0).validity()?;
256 let rhs = expression.child(1).validity()?;
257
258 Ok(match operator {
259 Operator::And => None,
261 Operator::Or => None,
262 _ => {
263 Some(and(lhs, rhs))
265 }
266 })
267 }
268
269 fn is_null_sensitive(&self, _operator: &Operator) -> bool {
270 false
271 }
272
273 fn is_fallible(&self, operator: &Operator) -> bool {
274 let infallible = matches!(
277 operator,
278 Operator::Eq
279 | Operator::NotEq
280 | Operator::Gt
281 | Operator::Gte
282 | Operator::Lt
283 | Operator::Lte
284 | Operator::And
285 | Operator::Or
286 );
287
288 !infallible
289 }
290}
291
#[cfg(test)]
mod tests {
    use vortex_error::VortexExpect;

    use super::*;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::Expression;
    use crate::expr::and_collect;
    use crate::expr::col;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::or_collect;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `and_collect` should fold its operands into a balanced tree of `and` nodes.
    #[test]
    fn and_collect_balanced() {
        // Odd operand count.
        let five = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
        let tree = and_collect(five.into_iter()).unwrap();
        insta::assert_snapshot!(tree.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.binary(and)
            │   ├── lhs: vortex.literal(3i32)
            │   └── rhs: vortex.literal(4i32)
            └── rhs: vortex.literal(5i32)
        ");

        // Even operand count.
        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        let tree = and_collect(four.into_iter()).unwrap();
        insta::assert_snapshot!(tree.display_tree(), @r"
        vortex.binary(and)
        ├── lhs: vortex.binary(and)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(and)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");

        // A single operand is returned as-is.
        let one = vec![lit(1)];
        let tree = and_collect(one.into_iter()).unwrap();
        insta::assert_snapshot!(tree.display_tree(), @"vortex.literal(1i32)");

        // No operands yields nothing.
        let none: Vec<Expression> = Vec::new();
        assert!(and_collect(none.into_iter()).is_none());
    }

    /// `or_collect` should fold its operands into a balanced tree of `or` nodes.
    #[test]
    fn or_collect_balanced() {
        let four = vec![lit(1), lit(2), lit(3), lit(4)];
        let tree = or_collect(four.into_iter()).unwrap();
        insta::assert_snapshot!(tree.display_tree(), @r"
        vortex.binary(or)
        ├── lhs: vortex.binary(or)
        │   ├── lhs: vortex.literal(1i32)
        │   └── rhs: vortex.literal(2i32)
        └── rhs: vortex.binary(or)
            ├── lhs: vortex.literal(3i32)
            └── rhs: vortex.literal(4i32)
        ");
    }

    /// Result dtypes of boolean and comparison expressions over the harness struct.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let dtype_of = |expr: Expression| expr.return_dtype(&scope).unwrap();

        let non_nullable_bool = DType::Bool(Nullability::NonNullable);
        let nullable_bool = DType::Bool(Nullability::Nullable);

        // Boolean connectives over the two bool columns.
        let bool1: Expression = col("bool1");
        let bool2: Expression = col("bool2");
        assert_eq!(
            dtype_of(and(bool1.clone(), bool2.clone())),
            non_nullable_bool
        );
        assert_eq!(dtype_of(or(bool1, bool2)), non_nullable_bool);

        // Every comparison between col1 and col2.
        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        assert_eq!(dtype_of(eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(not_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(gt_eq(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt(col1.clone(), col2.clone())), nullable_bool);
        assert_eq!(dtype_of(lt_eq(col1.clone(), col2.clone())), nullable_bool);

        // A connective over two comparisons.
        let combined = or(lt(col1.clone(), col2.clone()), not_eq(col1, col2));
        assert_eq!(dtype_of(combined), nullable_bool);
    }

    /// The `Display` output of a comparison is the parenthesised infix form.
    #[test]
    fn test_display_print() {
        let expr = gt(lit(1), lit(2));
        assert_eq!(format!("{expr}"), "(1i32 > 2i32)");
    }

    /// Struct arrays compare equal only when their fields match.
    #[test]
    fn test_struct_comparison() {
        use crate::IntoArray;
        use crate::arrays::PrimitiveArray;
        use crate::arrays::StructArray;

        // Builds a one-row struct array with i32 fields `a` and `b`.
        let single_row = |a: i32, b: i32| {
            StructArray::from_fields(&[
                ("a", PrimitiveArray::from_iter([a]).into_array()),
                ("b", PrimitiveArray::from_iter([b]).into_array()),
            ])
            .unwrap()
            .into_array()
        };

        let lhs_struct = single_row(1, 3);
        let rhs_struct_equal = single_row(1, 3);
        let rhs_struct_different = single_row(1, 4);

        let result_equal = lhs_struct.binary(rhs_struct_equal, Operator::Eq).unwrap();
        assert_eq!(
            result_equal.scalar_at(0).vortex_expect("value"),
            Scalar::bool(true, Nullability::NonNullable),
            "Equal structs should be equal"
        );

        let result_different = lhs_struct
            .binary(rhs_struct_different, Operator::Eq)
            .unwrap();
        assert_eq!(
            result_different.scalar_at(0).vortex_expect("value"),
            Scalar::bool(false, Nullability::NonNullable),
            "Different structs should not be equal"
        );
    }

    /// `true OR null` evaluates to a valid `true`.
    #[test]
    fn test_or_kleene_validity() {
        use crate::IntoArray;
        use crate::arrays::BoolArray;
        use crate::arrays::StructArray;

        let column_a = BoolArray::from_iter([Some(true)]).into_array();
        let column_b = BoolArray::from_iter([Option::<bool>::None]).into_array();
        let struct_arr = StructArray::from_fields(&[("a", column_a), ("b", column_b)])
            .unwrap()
            .into_array();

        let expr = or(col("a"), col("b"));
        let result = struct_arr.apply(&expr).unwrap();

        assert_arrays_eq!(result, BoolArray::from_iter([Some(true)]).into_array());
    }
}