1use std::fmt::Formatter;
5use std::ops::Not;
6
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10
11use crate::arrays::{BoolArray, ConstantArray};
12use crate::expr::exprs::binary::eq;
13use crate::expr::exprs::literal::lit;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
15use crate::stats::Stat;
16use crate::{Array, ArrayRef, IntoArray};
17
18pub struct IsNull;
20
21impl VTable for IsNull {
22 type Instance = ();
23
24 fn id(&self) -> ExprId {
25 ExprId::new_ref("is_null")
26 }
27
28 fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
29 Ok(Some(vec![]))
30 }
31
32 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
33 Ok(Some(()))
34 }
35
36 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
37 if expr.children().len() != 1 {
38 vortex_bail!(
39 "IsNull expression expects exactly one child, got {}",
40 expr.children().len()
41 );
42 }
43 Ok(())
44 }
45
46 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
47 match child_idx {
48 0 => ChildName::from("input"),
49 _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
50 }
51 }
52
53 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
54 write!(f, "is_null(")?;
55 expr.child(0).fmt_sql(f)?;
56 write!(f, ")")
57 }
58
59 fn return_dtype(&self, _expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
60 Ok(DType::Bool(Nullability::NonNullable))
61 }
62
63 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
64 let array = expr.child(0).evaluate(scope)?;
65 match array.validity_mask() {
66 Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
67 Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
68 Mask::Values(mask) => Ok(BoolArray::from(mask.bit_buffer().not()).into_array()),
69 }
70 }
71
72 fn stat_falsification(
73 &self,
74 expr: &ExpressionView<Self>,
75 catalog: &mut dyn StatsCatalog,
76 ) -> Option<Expression> {
77 let field_path = expr.children()[0].stat_field_path()?;
78 let null_count_expr = catalog.stats_ref(&field_path, Stat::NullCount)?;
79 Some(eq(null_count_expr, lit(0u64)))
80 }
81}
82
83pub fn is_null(child: Expression) -> Expression {
92 IsNull.new_expr((), vec![child])
93}
94
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability};
    use vortex_error::VortexUnwrap as _;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::is_null;
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::test_harness;
    use crate::stats::Stat;
    use crate::{ArrayRef, IntoArray};

    /// Asserts that `actual` has a non-nullable boolean dtype and that its leading
    /// elements equal `expected`, position by position.
    fn assert_non_nullable_bools(actual: &ArrayRef, expected: &[bool]) {
        assert_eq!(actual.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, want) in expected.iter().copied().enumerate() {
            assert_eq!(
                actual.scalar_at(idx),
                Scalar::bool(want, Nullability::NonNullable)
            );
        }
    }

    /// The return dtype is a non-nullable boolean.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child with another single child succeeds.
    #[test]
    fn replace_children() {
        let original = is_null(root());
        original.with_children([root()]).vortex_unwrap();
    }

    /// Mixed validity produces a per-row null indicator.
    #[test]
    fn evaluate_mask() {
        let input = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();

        let output = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert_non_nullable_bools(&output, &[false, true, false, true, false]);
    }

    /// An input without nulls evaluates to a constant `false`.
    #[test]
    fn evaluate_all_false() {
        let input = buffer![1, 2, 3, 4, 5].into_array();

        let output = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(output.len(), input.len());
        let constant = output.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(false, Nullability::NonNullable));
    }

    /// An input made only of nulls evaluates to a constant `true`.
    #[test]
    fn evaluate_all_true() {
        let input = PrimitiveArray::from_option_iter(vec![None::<i32>; 5]).into_array();

        let output = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(output.len(), input.len());
        let constant = output.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(true, Nullability::NonNullable));
    }

    /// `is_null` applied to a struct field reports that field's nulls.
    #[test]
    fn evaluate_struct() {
        let field_a = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();
        let input = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();

        let output = is_null(get_item("a", root())).evaluate(&input).unwrap();

        assert_eq!(output.len(), input.len());
        assert_non_nullable_bools(&output, &[false, true, false, true, false]);
    }

    /// The display form is `is_null(<child>)`.
    #[test]
    fn test_display() {
        let cases = [
            (is_null(get_item("name", root())), "is_null($.name)"),
            (is_null(root()), "is_null($)"),
        ];

        for (expr, rendered) in cases {
            assert_eq!(expr.to_string(), rendered);
        }
    }

    /// The pruning expression for `is_null(a)` is `a_null_count == 0` and it
    /// requires the null-count stat of `a`.
    #[test]
    fn test_is_null_falsification() {
        let predicate = is_null(col("a"));
        let available_stats = FieldPathSet::from_iter([FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ])]);

        let (pruning_expr, required) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        let expected_expr = eq(col("a_null_count"), lit(0u64));
        assert_eq!(&pruning_expr, &expected_expr);

        let expected_required =
            HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]);
        assert_eq!(required.map(), &expected_required);
    }
}