1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::compute::invert;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{ExprRef, VortexExpr};
12
/// Expression node that logically inverts the result of its single child expression.
#[derive(Debug, Eq, Hash)]
// `PartialEq` is implemented by hand (see the `impl PartialEq for Not` below). It compares only
// `child`, which is also the only field fed to the derived `Hash`. So equal `Not`s hash equally,
// as long as `ExprRef`'s own `PartialEq`/`Hash` agree with each other.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct Not {
    /// The expression whose result is inverted.
    child: ExprRef,
}
19
20impl Not {
21 pub fn new_expr(child: ExprRef) -> ExprRef {
22 Arc::new(Self { child })
23 }
24
25 pub fn child(&self) -> &ExprRef {
26 &self.child
27 }
28}
29
30impl Display for Not {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(f, "!")?;
33 self.child.fmt(f)
34 }
35}
36
37impl VortexExpr for Not {
38 fn as_any(&self) -> &dyn Any {
39 self
40 }
41
42 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
43 let child_result = self.child.evaluate(batch)?;
44 invert(&child_result)
45 }
46
47 fn children(&self) -> Vec<&ExprRef> {
48 vec![&self.child]
49 }
50
51 fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
52 assert_eq!(children.len(), 0);
53 Self::new_expr(children.remove(0))
54 }
55
56 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
57 self.child.return_dtype(scope_dtype)
58 }
59}
60
61impl PartialEq for Not {
62 fn eq(&self, other: &Not) -> bool {
63 other.child.eq(&self.child)
64 }
65}
66
67pub fn not(operand: ExprRef) -> ExprRef {
68 Not::new_expr(operand)
69}
70
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::BoolArray;
    use vortex_dtype::{DType, Nullability};

    use crate::{col, ident, not, test_harness};

    /// Evaluating `!x` over a boolean array flips every element.
    #[test]
    fn invert_booleans() {
        let input = BoolArray::from_iter([false, true, false, false, true, true]);
        let negated = not(ident()).evaluate(&input).unwrap().to_bool().unwrap();
        let actual: Vec<bool> = negated.boolean_buffer().iter().collect();
        assert_eq!(actual, vec![true, false, true, true, false, false]);
    }

    /// The result dtype of `!x` is the dtype of `x`, whether `x` is the scope or a column.
    #[test]
    fn dtype() {
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);

        let scope_result = not(ident()).return_dtype(&non_nullable_bool).unwrap();
        assert_eq!(scope_result, non_nullable_bool);

        let struct_scope = test_harness::struct_dtype();
        let column_result = not(col("bool1")).return_dtype(&struct_scope).unwrap();
        assert_eq!(column_result, non_nullable_bool);
    }
}