Skip to main content

vortex_array/scalar_fn/fns/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_session::VortexSession;
6
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::IntoArray;
10use crate::arrays::ConstantArray;
11use crate::builtins::ArrayBuiltins;
12use crate::dtype::DType;
13use crate::dtype::Nullability;
14use crate::expr::Expression;
15use crate::expr::StatsCatalog;
16use crate::expr::eq;
17use crate::expr::lit;
18use crate::expr::stats::Stat;
19use crate::scalar_fn::Arity;
20use crate::scalar_fn::ChildName;
21use crate::scalar_fn::EmptyOptions;
22use crate::scalar_fn::ExecutionArgs;
23use crate::scalar_fn::ScalarFnId;
24use crate::scalar_fn::ScalarFnVTable;
25use crate::validity::Validity;
26
/// Expression that checks for null values.
///
/// Takes exactly one child (named `input`) and produces a non-nullable boolean
/// array with one row per input row: `true` where the input row is null and
/// `false` otherwise. Because it inspects nullness directly, it is reported as
/// null-sensitive.
#[derive(Clone)]
pub struct IsNull;
30
31impl ScalarFnVTable for IsNull {
32    type Options = EmptyOptions;
33
34    fn id(&self) -> ScalarFnId {
35        ScalarFnId::from("vortex.is_null")
36    }
37
38    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
39        Ok(Some(vec![]))
40    }
41
42    fn deserialize(
43        &self,
44        _metadata: &[u8],
45        _session: &VortexSession,
46    ) -> VortexResult<Self::Options> {
47        Ok(EmptyOptions)
48    }
49
50    fn arity(&self, _options: &Self::Options) -> Arity {
51        Arity::Exact(1)
52    }
53
54    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
55        match child_idx {
56            0 => ChildName::from("input"),
57            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
58        }
59    }
60
61    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
62        Ok(DType::Bool(Nullability::NonNullable))
63    }
64
65    fn execute(
66        &self,
67        _data: &Self::Options,
68        args: &dyn ExecutionArgs,
69        _ctx: &mut ExecutionCtx,
70    ) -> VortexResult<ArrayRef> {
71        let child = args.get(0)?;
72        if let Some(scalar) = child.as_constant() {
73            return Ok(ConstantArray::new(scalar.is_null(), args.row_count()).into_array());
74        }
75
76        match child.validity()? {
77            Validity::NonNullable | Validity::AllValid => {
78                Ok(ConstantArray::new(false, args.row_count()).into_array())
79            }
80            Validity::AllInvalid => Ok(ConstantArray::new(true, args.row_count()).into_array()),
81            Validity::Array(a) => a.not(),
82        }
83    }
84
85    fn stat_falsification(
86        &self,
87        _options: &Self::Options,
88        expr: &Expression,
89        catalog: &dyn StatsCatalog,
90    ) -> Option<Expression> {
91        let null_count_expr = expr.child(0).stat_expression(Stat::NullCount, catalog)?;
92        Some(eq(null_count_expr, lit(0u64)))
93    }
94
95    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
96        true
97    }
98
99    fn is_fallible(&self, _instance: &Self::Options) -> bool {
100        false
101    }
102}
103
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Field;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::lit;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            is_null(root()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let replaced = is_null(root())
            .with_children([root()])
            .vortex_expect("operation should succeed in test");
        // Swapping in an identical child must rebuild an identical expression.
        assert_eq!(replaced, is_null(root()));
    }

    #[test]
    fn evaluate_mask() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();
        let expected = [false, true, false, true, false];

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (i, expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(*expected_value, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn evaluate_all_false() {
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        // All values should be false (non-nullable input)
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for i in 0..result.len() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(false, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn evaluate_all_true() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        // All values should be true (all nulls)
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for i in 0..result.len() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(true, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn evaluate_constant() {
        // A constant, non-null input: every row must report `false`.
        let test_array = ConstantArray::new(7i32, 4).into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for i in 0..result.len() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(false, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[(
            "a",
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array(),
        )])
        .unwrap()
        .into_array();
        let expected = [false, true, false, true, false];

        let result = test_array
            .clone()
            .apply(&is_null(get_item("a", root())))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        for (i, expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(*expected_value, Nullability::NonNullable)
            );
        }
    }

    #[test]
    fn test_display() {
        let expr = is_null(get_item("name", root()));
        assert_eq!(expr.to_string(), "vortex.is_null($.name)");

        let expr2 = is_null(root());
        assert_eq!(expr2.to_string(), "vortex.is_null($)");
    }

    #[test]
    fn test_is_null_falsification() {
        let expr = is_null(col("a"));

        let (pruning_expr, st) = checked_pruning_expr(
            &expr,
            &FieldPathSet::from_iter([FieldPath::from_iter([
                Field::Name("a".into()),
                Field::Name("null_count".into()),
            ])]),
        )
        .unwrap();

        assert_eq!(&pruning_expr, &eq(col("a_null_count"), lit(0u64)));
        assert_eq!(
            st.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))])
        );
    }

    #[test]
    fn test_is_null_sensitive() {
        // is_null itself is null-sensitive
        assert!(is_null(col("a")).signature().is_null_sensitive());
    }
}