Skip to main content

vortex_array/scalar_fn/fns/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_session::VortexSession;
6use vortex_session::registry::CachedId;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::builtins::ArrayBuiltins;
13use crate::dtype::DType;
14use crate::dtype::Nullability;
15use crate::expr::Expression;
16use crate::expr::StatsCatalog;
17use crate::expr::eq;
18use crate::expr::lit;
19use crate::expr::stats::Stat;
20use crate::scalar_fn::Arity;
21use crate::scalar_fn::ChildName;
22use crate::scalar_fn::EmptyOptions;
23use crate::scalar_fn::ExecutionArgs;
24use crate::scalar_fn::ScalarFnId;
25use crate::scalar_fn::ScalarFnVTable;
26use crate::validity::Validity;
27
/// Scalar function that reports, row by row, whether its single input is null.
///
/// The result is always a non-nullable boolean array: `true` where the input row is
/// null and `false` where it holds a value.
#[derive(Clone)]
pub struct IsNull;
31
32impl ScalarFnVTable for IsNull {
33    type Options = EmptyOptions;
34
35    fn id(&self) -> ScalarFnId {
36        static ID: CachedId = CachedId::new("vortex.is_null");
37        *ID
38    }
39
40    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
41        Ok(Some(vec![]))
42    }
43
44    fn deserialize(
45        &self,
46        _metadata: &[u8],
47        _session: &VortexSession,
48    ) -> VortexResult<Self::Options> {
49        Ok(EmptyOptions)
50    }
51
52    fn arity(&self, _options: &Self::Options) -> Arity {
53        Arity::Exact(1)
54    }
55
56    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
57        match child_idx {
58            0 => ChildName::from("input"),
59            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
60        }
61    }
62
63    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
64        Ok(DType::Bool(Nullability::NonNullable))
65    }
66
67    fn execute(
68        &self,
69        _data: &Self::Options,
70        args: &dyn ExecutionArgs,
71        _ctx: &mut ExecutionCtx,
72    ) -> VortexResult<ArrayRef> {
73        let child = args.get(0)?;
74        if let Some(scalar) = child.as_constant() {
75            return Ok(ConstantArray::new(scalar.is_null(), args.row_count()).into_array());
76        }
77
78        match child.validity()? {
79            Validity::NonNullable | Validity::AllValid => {
80                Ok(ConstantArray::new(false, args.row_count()).into_array())
81            }
82            Validity::AllInvalid => Ok(ConstantArray::new(true, args.row_count()).into_array()),
83            Validity::Array(a) => a.not(),
84        }
85    }
86
87    fn stat_falsification(
88        &self,
89        _options: &Self::Options,
90        expr: &Expression,
91        catalog: &dyn StatsCatalog,
92    ) -> Option<Expression> {
93        let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
94        Some(eq(null_count_expr, lit(0u64)))
95    }
96
97    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
98        true
99    }
100
101    fn is_fallible(&self, _instance: &Self::Options) -> bool {
102        false
103    }
104}
105
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Field;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::lit;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    /// Asserts that row `idx` of `array` executes to the non-nullable boolean `expected`.
    ///
    /// Each call builds its own execution context.
    fn assert_bool_at(array: &ArrayRef, idx: usize, expected: bool) {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let actual = array.execute_scalar(idx, &mut ctx).unwrap();
        assert_eq!(actual, Scalar::bool(expected, Nullability::NonNullable));
    }

    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    #[test]
    fn replace_children() {
        let original = is_null(root());
        original
            .with_children([root()])
            .vortex_expect("operation should succeed in test");
    }

    #[test]
    fn evaluate_mask() {
        let input =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();

        let result = input.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));
        for (idx, expected) in [false, true, false, true, false].into_iter().enumerate() {
            assert_bool_at(&result, idx, expected);
        }
    }

    #[test]
    fn evaluate_all_false() {
        let input = buffer![1, 2, 3, 4, 5].into_array();

        let result = input.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        // A non-nullable input can never yield a null row.
        (0..result.len()).for_each(|idx| assert_bool_at(&result, idx, false));
    }

    #[test]
    fn evaluate_all_true() {
        let input = PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
            .into_array();

        let result = input.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        // Every input row is null, so every output row must be `true`.
        (0..result.len()).for_each(|idx| assert_bool_at(&result, idx, true));
    }

    #[test]
    fn evaluate_struct() {
        let field_a =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();
        let input = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();

        let result = input
            .clone()
            .apply(&is_null(get_item("a", root())))
            .unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));
        for (idx, expected) in [false, true, false, true, false].into_iter().enumerate() {
            assert_bool_at(&result, idx, expected);
        }
    }

    #[test]
    fn test_display() {
        let on_field = is_null(get_item("name", root()));
        assert_eq!(on_field.to_string(), "vortex.is_null($.name)");

        let on_root = is_null(root());
        assert_eq!(on_root.to_string(), "vortex.is_null($)");
    }

    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));
        let available = FieldPathSet::from_iter([FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ])]);

        let (pruning_expr, stats) = checked_pruning_expr(&expr, &available).unwrap();

        assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64)));
        assert_eq!(
            stats.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    #[test]
    fn test_is_null_sensitive() {
        // `is_null` must observe nulls, so its signature reports null sensitivity.
        let signature = is_null(col("a")).signature();
        assert!(signature.is_null_sensitive());
    }
}