Skip to main content

vortex_array/expr/exprs/is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use vortex_dtype::DType;
7use vortex_dtype::Nullability;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::ConstantArray;
15use crate::builtins::ArrayBuiltins;
16use crate::expr::Arity;
17use crate::expr::ChildName;
18use crate::expr::EmptyOptions;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::VTable;
24use crate::expr::VTableExt;
25use crate::expr::exprs::binary::eq;
26use crate::expr::exprs::literal::lit;
27use crate::expr::stats::Stat;
28use crate::validity::Validity;
29
/// Expression that checks for null values.
///
/// This is a stateless unit marker: it takes no options ([`EmptyOptions`]) and all of its
/// behavior — serialization, typing, evaluation and stats-based pruning — comes from its
/// [`VTable`] implementation. Use [`is_null`] to build an expression node that applies it.
pub struct IsNull;
32
33impl VTable for IsNull {
34    type Options = EmptyOptions;
35
36    fn id(&self) -> ExprId {
37        ExprId::new_ref("is_null")
38    }
39
40    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
41        Ok(Some(vec![]))
42    }
43
44    fn deserialize(
45        &self,
46        _metadata: &[u8],
47        _session: &VortexSession,
48    ) -> VortexResult<Self::Options> {
49        Ok(EmptyOptions)
50    }
51
52    fn arity(&self, _options: &Self::Options) -> Arity {
53        Arity::Exact(1)
54    }
55
56    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
57        match child_idx {
58            0 => ChildName::from("input"),
59            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
60        }
61    }
62
63    fn fmt_sql(
64        &self,
65        _options: &Self::Options,
66        expr: &Expression,
67        f: &mut Formatter<'_>,
68    ) -> std::fmt::Result {
69        write!(f, "is_null(")?;
70        expr.child(0).fmt_sql(f)?;
71        write!(f, ")")
72    }
73
74    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
75        Ok(DType::Bool(Nullability::NonNullable))
76    }
77
78    fn execute(&self, _data: &Self::Options, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
79        let child = args.inputs.pop().vortex_expect("Missing input child");
80        if let Some(scalar) = child.as_constant() {
81            return Ok(ConstantArray::new(scalar.is_null(), args.row_count).into_array());
82        }
83
84        Ok(match child.validity()? {
85            Validity::NonNullable | Validity::AllValid => {
86                ConstantArray::new(false, args.row_count).into_array()
87            }
88            Validity::AllInvalid => ConstantArray::new(true, args.row_count).into_array(),
89            Validity::Array(a) => a.not()?.execute(args.ctx)?,
90        })
91    }
92
93    fn stat_falsification(
94        &self,
95        _options: &Self::Options,
96        expr: &Expression,
97        catalog: &dyn StatsCatalog,
98    ) -> Option<Expression> {
99        let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
100        Some(eq(null_count_expr, lit(0u64)))
101    }
102
103    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
104        true
105    }
106
107    fn is_fallible(&self, _instance: &Self::Options) -> bool {
108        false
109    }
110}
111
112/// Creates an expression that checks for null values.
113///
114/// Returns a boolean array indicating which positions contain null values.
115///
116/// ```rust
117/// # use vortex_array::expr::{is_null, root};
118/// let expr = is_null(root());
119/// ```
120pub fn is_null(child: Expression) -> Expression {
121    IsNull.new_expr(EmptyOptions, vec![child])
122}
123
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::is_null;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Asserts that `result` is a non-nullable boolean array whose values equal `expected`.
    fn assert_is_null_output(result: &ArrayRef, expected: &[bool]) {
        assert_eq!(result.len(), expected.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));
        for (i, &expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.scalar_at(i).unwrap(),
                Scalar::bool(expected_value, Nullability::NonNullable),
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            is_null(root()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = is_null(root());
        expr.with_children([root()])
            .vortex_expect("operation should succeed in test");
    }

    #[test]
    fn evaluate_mask() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();

        let result = test_array.apply(&is_null(root())).unwrap();

        assert_is_null_output(&result, &[false, true, false, true, false]);
    }

    #[test]
    fn evaluate_all_false() {
        // Non-nullable input: every position is valid, so nothing is null.
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = test_array.apply(&is_null(root())).unwrap();

        assert_is_null_output(&result, &[false; 5]);
    }

    #[test]
    fn evaluate_all_true() {
        // Every position is null.
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = test_array.apply(&is_null(root())).unwrap();

        assert_is_null_output(&result, &[true; 5]);
    }

    #[test]
    fn evaluate_constant_non_null() {
        // Exercises the `as_constant` fast path with a non-null scalar.
        let test_array = ConstantArray::new(42i32, 4).into_array();

        let result = test_array.apply(&is_null(root())).unwrap();

        assert_is_null_output(&result, &[false; 4]);
    }

    #[test]
    fn evaluate_constant_null() {
        // Exercises the `as_constant` fast path with a null scalar.
        let null_scalar = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        let test_array = ConstantArray::new(null_scalar, 4).into_array();

        let result = test_array.apply(&is_null(root())).unwrap();

        assert_is_null_output(&result, &[true; 4]);
    }

    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[(
            "a",
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array(),
        )])
        .unwrap()
        .into_array();

        let result = test_array
            .apply(&is_null(get_item("a", root())))
            .unwrap();

        assert_is_null_output(&result, &[false, true, false, true, false]);
    }

    #[test]
    fn test_display() {
        let expr = is_null(get_item("name", root()));
        assert_eq!(expr.to_string(), "is_null($.name)");

        let expr2 = is_null(root());
        assert_eq!(expr2.to_string(), "is_null($)");
    }

    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));

        let (pruning_expr, st) = checked_pruning_expr(
            &expr,
            &FieldPathSet::from_iter([FieldPath::from_iter([
                Field::Name("a".into()),
                Field::Name("null_count".into()),
            ])]),
        )
        .unwrap();

        assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64)));
        assert_eq!(
            st.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    #[test]
    fn test_is_null_sensitive() {
        // is_null itself is null-sensitive
        assert!(is_null(col("a")).signature().is_null_sensitive());
    }
}