vortex_array/expr/exprs/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Not;
6
7use is_null::IsNullFn;
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_mask::Mask;
14use vortex_vector::Vector;
15use vortex_vector::VectorOps;
16use vortex_vector::bool::BoolVector;
17
18use crate::Array;
19use crate::ArrayRef;
20use crate::IntoArray;
21use crate::arrays::BoolArray;
22use crate::arrays::ConstantArray;
23use crate::expr::ChildName;
24use crate::expr::ExecutionArgs;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::ExpressionView;
28use crate::expr::ScalarFnExprExt;
29use crate::expr::StatsCatalog;
30use crate::expr::VTable;
31use crate::expr::VTableExt;
32use crate::expr::exprs::binary::eq;
33use crate::expr::exprs::literal::lit;
34use crate::expr::functions::EmptyOptions;
35use crate::expr::stats::Stat;
36use crate::scalar_fns::is_null;
37
/// Expression vtable for the `is_null` check.
///
/// This is a stateless marker type: it has no fields, and its [`VTable`] implementation uses
/// `()` as the instance type. Use [`is_null`] to build an expression from it.
pub struct IsNull;
40
41impl VTable for IsNull {
42    type Instance = ();
43
44    fn id(&self) -> ExprId {
45        ExprId::new_ref("is_null")
46    }
47
48    fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
49        Ok(Some(vec![]))
50    }
51
52    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
53        Ok(Some(()))
54    }
55
56    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
57        if expr.children().len() != 1 {
58            vortex_bail!(
59                "IsNull expression expects exactly one child, got {}",
60                expr.children().len()
61            );
62        }
63        Ok(())
64    }
65
66    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
67        match child_idx {
68            0 => ChildName::from("input"),
69            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
70        }
71    }
72
73    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
74        write!(f, "is_null(")?;
75        expr.child(0).fmt_sql(f)?;
76        write!(f, ")")
77    }
78
79    fn return_dtype(&self, _expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
80        Ok(DType::Bool(Nullability::NonNullable))
81    }
82
83    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
84        let array = expr.child(0).evaluate(scope)?;
85        match array.validity_mask() {
86            Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
87            Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
88            Mask::Values(mask) => Ok(BoolArray::from(mask.bit_buffer().not()).into_array()),
89        }
90    }
91
92    fn stat_falsification(
93        &self,
94        expr: &ExpressionView<Self>,
95        catalog: &dyn StatsCatalog,
96    ) -> Option<Expression> {
97        let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
98        Some(eq(null_count_expr, lit(0u64)))
99    }
100
101    fn execute(&self, _data: &Self::Instance, mut args: ExecutionArgs) -> VortexResult<Vector> {
102        let child = args.vectors.pop().vortex_expect("Missing input child");
103        Ok(BoolVector::new(
104            child.validity().to_bit_buffer().not(),
105            Mask::new_true(child.len()),
106        )
107        .into())
108    }
109
110    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
111        true
112    }
113
114    fn is_fallible(&self, _instance: &Self::Instance) -> bool {
115        false
116    }
117
118    fn expr_v2(&self, view: &ExpressionView<Self>) -> VortexResult<Expression> {
119        ScalarFnExprExt::try_new_expr(&IsNullFn, EmptyOptions, view.children().clone())
120    }
121}
122
123/// Creates an expression that checks for null values.
124///
125/// Returns a boolean array indicating which positions contain null values.
126///
127/// ```rust
128/// # use vortex_array::expr::{is_null, root};
129/// let expr = is_null(root());
130/// ```
131pub fn is_null(child: Expression) -> Expression {
132    IsNull.new_expr((), vec![child])
133}
134
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_error::VortexUnwrap as _;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::is_null;
    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;

    /// The return dtype is a non-nullable bool regardless of the scope dtype.
    #[test]
    fn dtype() {
        let scope_dtype = test_harness::struct_dtype();
        let returned = is_null(root()).return_dtype(&scope_dtype).unwrap();
        assert_eq!(returned, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child with another single child succeeds.
    #[test]
    fn replace_children() {
        let original = is_null(root());
        original.with_children([root()]).vortex_unwrap();
    }

    /// Mixed validity: nulls map to `true`, valid values to `false`.
    #[test]
    fn evaluate_mask() {
        let values = vec![Some(1), None, Some(2), None, Some(3)];
        let input = PrimitiveArray::from_option_iter(values).into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, &want) in expected.iter().enumerate() {
            let got = result.scalar_at(idx);
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    /// An input without nulls evaluates to a constant `false`.
    #[test]
    fn evaluate_all_false() {
        let input = buffer![1, 2, 3, 4, 5].into_array();

        let result = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(result.len(), input.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(false, Nullability::NonNullable));
    }

    /// An all-null input evaluates to a constant `true`.
    #[test]
    fn evaluate_all_true() {
        let values = vec![None::<i32>, None, None, None, None];
        let input = PrimitiveArray::from_option_iter(values).into_array();

        let result = is_null(root()).evaluate(&input).unwrap();

        assert_eq!(result.len(), input.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(true, Nullability::NonNullable));
    }

    /// `is_null` composes with `get_item` to inspect a struct field.
    #[test]
    fn evaluate_struct() {
        let field_a = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();
        let input = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(get_item("a", root())).evaluate(&input).unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, &want) in expected.iter().enumerate() {
            let got = result.scalar_at(idx);
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    /// The display form is `is_null(<child>)`.
    #[test]
    fn test_display() {
        let on_field = is_null(get_item("name", root()));
        assert_eq!(on_field.to_string(), "is_null($.name)");

        let on_root = is_null(root());
        assert_eq!(on_root.to_string(), "is_null($)");
    }

    /// Pruning rewrites `is_null(a)` to `a_null_count == 0` and records the stat it needs.
    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));
        let available_stats = FieldPathSet::from_iter([FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ])]);

        let (pruning_expr, required_stats) =
            checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_pruning = eq(col("a_null_count"), lit(0u64));
        assert_eq!(&pruning_expr, &expected_pruning);
        assert_eq!(
            required_stats.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    /// `is_null` itself reports as null-sensitive.
    #[test]
    fn test_is_null_sensitive() {
        let expr = is_null(col("a"));
        assert!(expr.is_null_sensitive());
    }
}