1use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_dtype::Nullability;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::ConstantArray;
15use crate::builtins::ArrayBuiltins;
16use crate::expr::Arity;
17use crate::expr::ChildName;
18use crate::expr::EmptyOptions;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::VTable;
24use crate::expr::VTableExt;
25use crate::expr::exprs::binary::eq;
26use crate::expr::exprs::literal::lit;
27use crate::expr::stats::Stat;
28use crate::validity::Validity;
29
/// Expression vtable for the `is_null` predicate.
///
/// This is a stateless marker type: it carries no data, and all behavior lives in its
/// [`VTable`] implementation. Use [`is_null`] to build an expression from it.
#[derive(Debug, Clone, Copy)]
pub struct IsNull;
32
33impl VTable for IsNull {
34 type Options = EmptyOptions;
35
36 fn id(&self) -> ExprId {
37 ExprId::new_ref("is_null")
38 }
39
40 fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
41 Ok(Some(vec![]))
42 }
43
44 fn deserialize(
45 &self,
46 _metadata: &[u8],
47 _session: &VortexSession,
48 ) -> VortexResult<Self::Options> {
49 Ok(EmptyOptions)
50 }
51
52 fn arity(&self, _options: &Self::Options) -> Arity {
53 Arity::Exact(1)
54 }
55
56 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
57 match child_idx {
58 0 => ChildName::from("input"),
59 _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
60 }
61 }
62
63 fn fmt_sql(
64 &self,
65 _options: &Self::Options,
66 expr: &Expression,
67 f: &mut Formatter<'_>,
68 ) -> std::fmt::Result {
69 write!(f, "is_null(")?;
70 expr.child(0).fmt_sql(f)?;
71 write!(f, ")")
72 }
73
74 fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
75 Ok(DType::Bool(Nullability::NonNullable))
76 }
77
78 fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
79 let child = args.inputs.pop().vortex_expect("Missing input child");
80 if let Some(scalar) = child.as_constant() {
81 return Ok(ConstantArray::new(scalar.is_null(), args.row_count).into_array());
82 }
83
84 Ok(match child.validity()? {
85 Validity::NonNullable | Validity::AllValid => {
86 ConstantArray::new(false, args.row_count).into_array()
87 }
88 Validity::AllInvalid => ConstantArray::new(true, args.row_count).into_array(),
89 Validity::Array(a) => a.not()?.execute(args.ctx)?,
90 })
91 }
92
93 fn stat_falsification(
94 &self,
95 _options: &Self::Options,
96 expr: &Expression,
97 catalog: &dyn StatsCatalog,
98 ) -> Option<Expression> {
99 let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
100 Some(eq(null_count_expr, lit(0u64)))
101 }
102
103 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
104 true
105 }
106
107 fn is_fallible(&self, _instance: &Self::Options) -> bool {
108 false
109 }
110}
111
112pub fn is_null(child: Expression) -> Expression {
121 IsNull.new_expr(EmptyOptions, vec![child])
122}
123
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::is_null;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `is_null` always reports a non-nullable boolean dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child with another single child succeeds.
    #[test]
    fn replace_children() {
        let original = is_null(root());
        let rebuilt = original.with_children([root()]);
        rebuilt.vortex_expect("operation should succeed in test");
    }

    /// A mixed nullable primitive array yields a per-row null mask.
    #[test]
    fn evaluate_mask() {
        let input = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();
        let input_len = input.len();
        let expected_nulls = [false, true, false, true, false];

        let result = input.apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input_len);
        let bool_dtype = DType::Bool(Nullability::NonNullable);
        assert_eq!(result.dtype(), &bool_dtype);

        for (row, want) in expected_nulls.iter().copied().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(want, Nullability::NonNullable)
            );
        }
    }

    /// An array without nulls evaluates to `false` everywhere.
    #[test]
    fn evaluate_all_false() {
        let input = buffer![1, 2, 3, 4, 5].into_array();
        let input_len = input.len();

        let result = input.apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input_len);
        let never_null = Scalar::bool(false, Nullability::NonNullable);
        for row in 0..result.len() {
            assert_eq!(result.scalar_at(row).unwrap(), never_null);
        }
    }

    /// An all-null array evaluates to `true` everywhere.
    #[test]
    fn evaluate_all_true() {
        let input = PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
            .into_array();
        let input_len = input.len();

        let result = input.apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input_len);
        let always_null = Scalar::bool(true, Nullability::NonNullable);
        for row in 0..result.len() {
            assert_eq!(result.scalar_at(row).unwrap(), always_null);
        }
    }

    /// `is_null` composes with `get_item` to inspect a struct field.
    #[test]
    fn evaluate_struct() {
        let column = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();
        let input = StructArray::from_fields(&[("a", column)])
            .unwrap()
            .into_array();
        let input_len = input.len();
        let expected_nulls = [false, true, false, true, false];

        let field_is_null = is_null(get_item("a", root()));
        let result = input.apply(&field_is_null).unwrap();

        assert_eq!(result.len(), input_len);
        let bool_dtype = DType::Bool(Nullability::NonNullable);
        assert_eq!(result.dtype(), &bool_dtype);

        for (row, want) in expected_nulls.iter().copied().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(want, Nullability::NonNullable)
            );
        }
    }

    /// The display form is `is_null(<child>)`.
    #[test]
    fn test_display() {
        let cases = [
            (is_null(get_item("name", root())), "is_null($.name)"),
            (is_null(root()), "is_null($)"),
        ];

        for (expr, rendered) in cases {
            assert_eq!(expr.to_string(), rendered);
        }
    }

    /// Pruning `is_null(a)` produces `a_null_count == 0` and requires the null-count stat of `a`.
    #[test]
    fn test_is_null_falsification() {
        let predicate = is_null(col("a"));
        let null_count_path = FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ]);
        let available_stats = FieldPathSet::from_iter([null_count_path]);

        let (pruning_expr, st) = checked_pruning_expr(&predicate, &available_stats).unwrap();

        let expected_expr = eq(col("a_null_count"), lit(0u64));
        assert_eq!(&pruning_expr, &expected_expr);
        assert_eq!(
            st.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    /// The expression signature reports null sensitivity.
    #[test]
    fn test_is_null_sensitive() {
        let expr = is_null(col("a"));
        assert!(expr.signature().is_null_sensitive());
    }
}