1use std::any::Any;
2use std::fmt::Display;
3use std::ops::Not;
4use std::sync::Arc;
5
6use vortex_array::arrays::{BoolArray, ConstantArray};
7use vortex_array::{Array, ArrayRef, IntoArray};
8use vortex_dtype::{DType, Nullability};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_mask::Mask;
11
12use crate::{ExprRef, VortexExpr};
13
/// Expression that tests its child for nulls.
///
/// Evaluating it yields a non-nullable boolean array that is `true` at every
/// row where the child's result is null (invalid) and `false` elsewhere.
#[derive(Debug, Eq, Hash)]
// `Hash` is derived but `PartialEq` is written by hand for this type. The manual
// `eq` compares only `child`, the same single field the derived `Hash` covers,
// so equal values still hash equally and clippy's lint can be safely silenced.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct IsNull {
    /// The expression whose result is checked for nulls.
    child: ExprRef,
}
19
20impl IsNull {
21 pub fn new_expr(child: ExprRef) -> ExprRef {
22 Arc::new(Self { child })
23 }
24}
25
26impl PartialEq for IsNull {
27 fn eq(&self, other: &Self) -> bool {
28 self.child.eq(&other.child)
29 }
30}
31
32impl Display for IsNull {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "is_null({})", self.child)
35 }
36}
37
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::is_null::IsNull;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde registration for [`IsNull`], identified by the `"is_null"` id.
    pub(crate) struct IsNullSerde;

    impl Id for IsNullSerde {
        fn id(&self) -> &'static str {
            "is_null"
        }
    }

    impl ExprDeserialize for IsNullSerde {
        /// Rebuilds an [`IsNull`] expression from its proto `kind` and its
        /// already-deserialized `children`.
        ///
        /// # Errors
        ///
        /// Returns an error if `kind` is not `Kind::IsNull`, or if `children`
        /// does not contain exactly one expression.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::IsNull(kind::IsNull {}) = kind else {
                vortex_bail!("wrong kind {:?}, want is_null", kind)
            };

            // The child list comes from serialized input, so validate its arity rather
            // than indexing `children[0]`, which would panic on an empty list.
            let num_children = children.len();
            let Ok([child]) = <[ExprRef; 1]>::try_from(children) else {
                vortex_bail!("is_null expects exactly one child, got {}", num_children)
            };

            Ok(IsNull::new_expr(child))
        }
    }

    impl ExprSerializable for IsNull {
        fn id(&self) -> &'static str {
            IsNullSerde.id()
        }

        /// `is_null` carries no parameters of its own; its child is serialized
        /// separately by the caller.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::IsNull(kind::IsNull {}))
        }
    }
}
75
76impl VortexExpr for IsNull {
77 fn as_any(&self) -> &dyn Any {
78 self
79 }
80
81 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
82 let array = self.child.evaluate(batch)?;
83 match array.validity_mask()? {
84 Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
85 Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
86 Mask::Values(mask) => Ok(BoolArray::from(mask.boolean_buffer().not()).into_array()),
87 }
88 }
89
90 fn children(&self) -> Vec<&ExprRef> {
91 vec![&self.child]
92 }
93
94 fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
95 Self::new_expr(
96 children
97 .pop()
98 .vortex_expect("IsNull::replacing_children should have one child"),
99 )
100 }
101
102 fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
103 Ok(DType::Bool(Nullability::NonNullable))
104 }
105}
106
107pub fn is_null(child: ExprRef) -> ExprRef {
108 IsNull::new_expr(child)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{ArrayRef, IntoArray};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::is_null::is_null;
    use crate::{get_item, ident, test_harness};

    /// Asserts that `result` is a non-nullable boolean array whose values equal `expected`.
    fn assert_bool_values(result: &ArrayRef, expected: &[bool]) {
        assert_eq!(result.len(), expected.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));
        for (i, &expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.scalar_at(i).unwrap(),
                Scalar::bool(expected_value, Nullability::NonNullable),
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            is_null(ident()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = is_null(ident());
        let replaced = expr.clone().replacing_children(vec![ident()]);
        // Swapping in an equivalent child must yield an equivalent single-child expression.
        assert_eq!(replaced.children().len(), 1);
        assert_eq!(replaced.to_string(), expr.to_string());
    }

    #[test]
    fn evaluate_mask() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_bool_values(&result, &[false, true, false, true, false]);
    }

    #[test]
    fn evaluate_all_false() {
        let test_array = PrimitiveArray::from_iter(vec![1, 2, 3, 4, 5]).into_array();

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(
            result.as_constant().unwrap(),
            Scalar::bool(false, Nullability::NonNullable)
        );
    }

    #[test]
    fn evaluate_all_true() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(
            result.as_constant().unwrap(),
            Scalar::bool(true, Nullability::NonNullable)
        );
    }

    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[(
            "a",
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array(),
        )])
        .unwrap()
        .into_array();

        let result = is_null(get_item("a", ident()))
            .unchecked_evaluate(&test_array)
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_bool_values(&result, &[false, true, false, true, false]);
    }
}