Skip to main content

vortex_array/scalar_fn/fns/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use vortex_error::VortexResult;
7use vortex_session::VortexSession;
8
9use crate::ArrayRef;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::arrays::ConstantArray;
13use crate::builtins::ArrayBuiltins;
14use crate::dtype::DType;
15use crate::dtype::Nullability;
16use crate::expr::Expression;
17use crate::expr::StatsCatalog;
18use crate::expr::eq;
19use crate::expr::lit;
20use crate::expr::stats::Stat;
21use crate::scalar_fn::Arity;
22use crate::scalar_fn::ChildName;
23use crate::scalar_fn::EmptyOptions;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ScalarFnId;
26use crate::scalar_fn::ScalarFnVTable;
27use crate::validity::Validity;
28
/// Expression that checks for null values.
///
/// This is a stateless unit struct; its behaviour is provided by its
/// `ScalarFnVTable` implementation. `Debug` is derived so the type can be
/// printed in diagnostics and embedded in other `Debug`-deriving types.
#[derive(Clone, Debug)]
pub struct IsNull;
32
33impl ScalarFnVTable for IsNull {
34    type Options = EmptyOptions;
35
36    fn id(&self) -> ScalarFnId {
37        ScalarFnId::new_ref("is_null")
38    }
39
40    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
41        Ok(Some(vec![]))
42    }
43
44    fn deserialize(
45        &self,
46        _metadata: &[u8],
47        _session: &VortexSession,
48    ) -> VortexResult<Self::Options> {
49        Ok(EmptyOptions)
50    }
51
52    fn arity(&self, _options: &Self::Options) -> Arity {
53        Arity::Exact(1)
54    }
55
56    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
57        match child_idx {
58            0 => ChildName::from("input"),
59            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
60        }
61    }
62
63    fn fmt_sql(
64        &self,
65        _options: &Self::Options,
66        expr: &Expression,
67        f: &mut Formatter<'_>,
68    ) -> std::fmt::Result {
69        write!(f, "is_null(")?;
70        expr.child(0).fmt_sql(f)?;
71        write!(f, ")")
72    }
73
74    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
75        Ok(DType::Bool(Nullability::NonNullable))
76    }
77
78    fn execute(
79        &self,
80        _data: &Self::Options,
81        args: &dyn ExecutionArgs,
82        _ctx: &mut ExecutionCtx,
83    ) -> VortexResult<ArrayRef> {
84        let child = args.get(0)?;
85        if let Some(scalar) = child.as_constant() {
86            return Ok(ConstantArray::new(scalar.is_null(), args.row_count()).into_array());
87        }
88
89        match child.validity()? {
90            Validity::NonNullable | Validity::AllValid => {
91                Ok(ConstantArray::new(false, args.row_count()).into_array())
92            }
93            Validity::AllInvalid => Ok(ConstantArray::new(true, args.row_count()).into_array()),
94            Validity::Array(a) => a.not(),
95        }
96    }
97
98    fn stat_falsification(
99        &self,
100        _options: &Self::Options,
101        expr: &Expression,
102        catalog: &dyn StatsCatalog,
103    ) -> Option<Expression> {
104        let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
105        Some(eq(null_count_expr, lit(0u64)))
106    }
107
108    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
109        true
110    }
111
112    fn is_fallible(&self, _instance: &Self::Options) -> bool {
113        false
114    }
115}
116
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Field;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::lit;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// `is_null` over the root of a struct dtype is a non-nullable bool.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child with another single child succeeds.
    #[test]
    fn replace_children() {
        is_null(root())
            .with_children([root()])
            .vortex_expect("operation should succeed in test");
    }

    /// A mix of valid and null values produces the matching boolean mask.
    #[test]
    fn evaluate_mask() {
        let values = vec![Some(1), None, Some(2), None, Some(3)];
        let test_array = PrimitiveArray::from_option_iter(values).into_array();
        let expected = [false, true, false, true, false];

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, &want) in expected.iter().enumerate() {
            let got = result.scalar_at(idx).unwrap();
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    /// A non-nullable input is never null.
    #[test]
    fn evaluate_all_false() {
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        // All values should be false (non-nullable input)
        let not_null = Scalar::bool(false, Nullability::NonNullable);
        for idx in 0..result.len() {
            assert_eq!(result.scalar_at(idx).unwrap(), not_null);
        }
    }

    /// An input made only of nulls is null everywhere.
    #[test]
    fn evaluate_all_true() {
        let values = vec![None::<i32>, None, None, None, None];
        let test_array = PrimitiveArray::from_option_iter(values).into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        // All values should be true (all nulls)
        let null = Scalar::bool(true, Nullability::NonNullable);
        for idx in 0..result.len() {
            assert_eq!(result.scalar_at(idx).unwrap(), null);
        }
    }

    /// `is_null` applied to a struct field reflects that field's nulls.
    #[test]
    fn evaluate_struct() {
        let field_a =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();
        let test_array = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();
        let expected = [false, true, false, true, false];

        let expr = is_null(get_item("a", root()));
        let result = test_array.clone().apply(&expr).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (idx, &want) in expected.iter().enumerate() {
            let got = result.scalar_at(idx).unwrap();
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    /// The display form is `is_null(<child>)`.
    #[test]
    fn test_display() {
        let on_field = is_null(get_item("name", root()));
        let on_root = is_null(root());

        assert_eq!(on_field.to_string(), "is_null($.name)");
        assert_eq!(on_root.to_string(), "is_null($)");
    }

    /// The pruning expression for `is_null(a)` is `a_null_count == 0`, and it
    /// records that the `NullCount` stat of `a` is required.
    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));
        let available_stats = FieldPathSet::from_iter([FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ])]);

        let (pruning_expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64)));
        assert_eq!(
            st.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    /// The expression's signature reports null sensitivity.
    #[test]
    fn test_is_null_sensitive() {
        // is_null itself is null-sensitive
        let signature = is_null(col("a")).signature();
        assert!(signature.is_null_sensitive());
    }
}