vortex_array/expr/exprs/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Not;
6
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10
11use crate::arrays::{BoolArray, ConstantArray};
12use crate::expr::exprs::binary::eq;
13use crate::expr::exprs::literal::lit;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
15use crate::stats::Stat;
16use crate::{Array, ArrayRef, IntoArray};
17
18/// Expression that checks for null values.
19pub struct IsNull;
20
21impl VTable for IsNull {
22    type Instance = ();
23
24    fn id(&self) -> ExprId {
25        ExprId::new_ref("is_null")
26    }
27
28    fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
29        Ok(Some(vec![]))
30    }
31
32    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
33        Ok(Some(()))
34    }
35
36    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
37        if expr.children().len() != 1 {
38            vortex_bail!(
39                "IsNull expression expects exactly one child, got {}",
40                expr.children().len()
41            );
42        }
43        Ok(())
44    }
45
46    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
47        match child_idx {
48            0 => ChildName::from("input"),
49            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
50        }
51    }
52
53    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
54        write!(f, "is_null(")?;
55        expr.child(0).fmt_sql(f)?;
56        write!(f, ")")
57    }
58
59    fn return_dtype(&self, _expr: &ExpressionView<Self>, _scope: &DType) -> VortexResult<DType> {
60        Ok(DType::Bool(Nullability::NonNullable))
61    }
62
63    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
64        let array = expr.child(0).evaluate(scope)?;
65        match array.validity_mask() {
66            Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
67            Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
68            Mask::Values(mask) => Ok(BoolArray::from(mask.bit_buffer().not()).into_array()),
69        }
70    }
71
72    fn stat_falsification(
73        &self,
74        expr: &ExpressionView<Self>,
75        catalog: &mut dyn StatsCatalog,
76    ) -> Option<Expression> {
77        let field_path = expr.children()[0].stat_field_path()?;
78        let null_count_expr = catalog.stats_ref(&field_path, Stat::NullCount)?;
79        Some(eq(null_count_expr, lit(0u64)))
80    }
81}
82
83/// Creates an expression that checks for null values.
84///
85/// Returns a boolean array indicating which positions contain null values.
86///
87/// ```rust
88/// # use vortex_array::expr::{is_null, root};
89/// let expr = is_null(root());
90/// ```
91pub fn is_null(child: Expression) -> Expression {
92    IsNull.new_expr((), vec![child])
93}
94
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability};
    use vortex_error::VortexUnwrap as _;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::is_null;
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::test_harness;
    use crate::stats::Stat;
    use crate::{ArrayRef, IntoArray};

    /// Null pattern expected for the array built by [`alternating_nulls`].
    const ALTERNATING_NULLS: [bool; 5] = [false, true, false, true, false];

    /// Builds a five-element integer array with nulls at positions 1 and 3.
    fn alternating_nulls() -> ArrayRef {
        PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)]).into_array()
    }

    /// Asserts that `actual` is a non-nullable boolean array holding `expected` at each index.
    fn assert_non_nullable_bools(actual: &ArrayRef, expected: &[bool]) {
        assert_eq!(actual.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, &want) in expected.iter().enumerate() {
            assert_eq!(
                actual.scalar_at(idx),
                Scalar::bool(want, Nullability::NonNullable)
            );
        }
    }

    /// The return dtype is a non-nullable boolean.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let returned = is_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(returned, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child with another single child succeeds.
    #[test]
    fn replace_children() {
        is_null(root()).with_children([root()]).vortex_unwrap();
    }

    /// A partially-null input produces a per-element null flag.
    #[test]
    fn evaluate_mask() {
        let test_array = alternating_nulls();

        let result = is_null(root()).evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_non_nullable_bools(&result, &ALTERNATING_NULLS);
    }

    /// An input without nulls produces a constant `false`.
    #[test]
    fn evaluate_all_false() {
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = is_null(root()).evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(false, Nullability::NonNullable));
    }

    /// An all-null input produces a constant `true`.
    #[test]
    fn evaluate_all_true() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = is_null(root()).evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(true, Nullability::NonNullable));
    }

    /// `is_null` composes with `get_item` on a struct field.
    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[("a", alternating_nulls())])
            .unwrap()
            .into_array();

        let field_is_null = is_null(get_item("a", root()));
        let result = field_is_null.evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_non_nullable_bools(&result, &ALTERNATING_NULLS);
    }

    /// The textual form is `is_null(<child>)`.
    #[test]
    fn test_display() {
        assert_eq!(
            is_null(get_item("name", root())).to_string(),
            "is_null($.name)"
        );
        assert_eq!(is_null(root()).to_string(), "is_null($)");
    }

    /// Pruning rewrites `is_null(a)` to `a_null_count == 0` and records the stat it used.
    #[test]
    fn test_is_null_falsification() {
        let null_count_path = FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ]);
        let available_stats = FieldPathSet::from_iter([null_count_path]);

        let expr = is_null(col("a"));
        let (pruning_expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_expr = eq(col("a_null_count"), lit(0u64));
        assert_eq!(&pruning_expr, &expected_expr);

        let expected_stats =
            HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]);
        assert_eq!(st.map(), &expected_stats);
    }
}