vortex_expr/exprs/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Not;
5
6use vortex_array::arrays::{BoolArray, ConstantArray};
7use vortex_array::stats::Stat;
8use vortex_array::{Array, ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray};
9use vortex_dtype::{DType, Nullability};
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_mask::Mask;
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{
15    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, eq, lit,
16    vtable,
17};
18
// Presumably generates the `IsNullVTable` type (and related glue) that the
// `impl VTable for IsNullVTable` below is written against.
// NOTE(review): confirm against the `vtable!` macro definition.
vtable!(IsNull);
20
21#[allow(clippy::derived_hash_with_manual_eq)]
22#[derive(Clone, Debug, Hash, Eq)]
23pub struct IsNullExpr {
24    child: ExprRef,
25}
26
27impl PartialEq for IsNullExpr {
28    fn eq(&self, other: &Self) -> bool {
29        self.child.eq(&other.child)
30    }
31}
32
/// Encoding marker for [`IsNullExpr`]: used as `Self::Encoding` by the `VTable`
/// impl and wrapped into an `ExprEncodingRef` by its `encoding` function.
pub struct IsNullExprEncoding;
34
35impl VTable for IsNullVTable {
36    type Expr = IsNullExpr;
37    type Encoding = IsNullExprEncoding;
38    type Metadata = EmptyMetadata;
39
40    fn id(_encoding: &Self::Encoding) -> ExprId {
41        ExprId::new_ref("is_null")
42    }
43
44    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
45        ExprEncodingRef::new_ref(IsNullExprEncoding.as_ref())
46    }
47
48    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
49        Some(EmptyMetadata)
50    }
51
52    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
53        vec![&expr.child]
54    }
55
56    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
57        Ok(IsNullExpr::new(children[0].clone()))
58    }
59
60    fn build(
61        _encoding: &Self::Encoding,
62        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
63        children: Vec<ExprRef>,
64    ) -> VortexResult<Self::Expr> {
65        if children.len() != 1 {
66            vortex_bail!("IsNull expects exactly one child, got {}", children.len());
67        }
68        Ok(IsNullExpr::new(children[0].clone()))
69    }
70
71    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
72        let array = expr.child.unchecked_evaluate(scope)?;
73        match array.validity_mask() {
74            Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
75            Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
76            Mask::Values(mask) => Ok(BoolArray::from(mask.boolean_buffer().not()).into_array()),
77        }
78    }
79
80    fn return_dtype(_expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
81        Ok(DType::Bool(Nullability::NonNullable))
82    }
83}
84
85impl IsNullExpr {
86    pub fn new(child: ExprRef) -> Self {
87        Self { child }
88    }
89
90    pub fn new_expr(child: ExprRef) -> ExprRef {
91        Self::new(child).into_expr()
92    }
93}
94
95impl DisplayAs for IsNullExpr {
96    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
97        match df {
98            DisplayFormat::Compact => {
99                write!(f, "is_null({})", self.child)
100            }
101            DisplayFormat::Tree => {
102                write!(f, "IsNull")
103            }
104        }
105    }
106}
107
108impl AnalysisExpr for IsNullExpr {
109    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
110        let field_path = self.child.field_path()?;
111        let null_count_expr = catalog.stats_ref(&field_path, Stat::NullCount)?;
112        Some(eq(null_count_expr, lit(0u64)))
113    }
114}
115
116/// Creates an expression that checks for null values.
117///
118/// Returns a boolean array indicating which positions contain null values.
119///
120/// ```rust
121/// # use vortex_expr::{is_null, root};
122/// let expr = is_null(root());
123/// ```
124pub fn is_null(child: ExprRef) -> ExprRef {
125    IsNullExpr::new(child).into_expr()
126}
127
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::stats::Stat;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::is_null::is_null;
    use crate::pruning::checked_pruning_expr;
    use crate::{HashSet, Scope, col, eq, get_item, lit, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            is_null(root()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        // Previously the result was discarded, so the test asserted nothing.
        let expr = is_null(root());
        let replaced = expr.with_children(vec![col("a")]).unwrap();
        assert_eq!(&replaced, &is_null(col("a")));
    }

    #[test]
    fn evaluate_mask() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(root())
            .evaluate(&Scope::new(test_array.clone()))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (i, expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.scalar_at(i),
                Scalar::bool(*expected_value, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn evaluate_all_false() {
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = is_null(root())
            .evaluate(&Scope::new(test_array.clone()))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(
            result.as_constant().unwrap(),
            Scalar::bool(false, Nullability::NonNullable)
        );
    }

    #[test]
    fn evaluate_all_true() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = is_null(root())
            .evaluate(&Scope::new(test_array.clone()))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(
            result.as_constant().unwrap(),
            Scalar::bool(true, Nullability::NonNullable)
        );
    }

    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[(
            "a",
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array(),
        )])
        .unwrap()
        .into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(get_item("a", root()))
            .evaluate(&Scope::new(test_array.clone()))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (i, expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.scalar_at(i),
                Scalar::bool(*expected_value, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn test_display() {
        let expr = is_null(get_item("name", root()));
        assert_eq!(expr.to_string(), "is_null($.name)");

        let expr2 = is_null(root());
        assert_eq!(expr2.to_string(), "is_null($)");
    }

    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));

        let (pruning_expr, st) = checked_pruning_expr(
            &expr,
            &FieldPathSet::from_iter([FieldPath::from_iter([
                Field::Name("a".into()),
                Field::Name("null_count".into()),
            ])]),
        )
        .unwrap();

        assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64)));
        assert_eq!(
            st.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }
}