Skip to main content

vortex_array/scalar_fn/fns/
is_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_session::VortexSession;
6use vortex_session::registry::CachedId;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::arrays::ConstantArray;
12use crate::builtins::ArrayBuiltins;
13use crate::dtype::DType;
14use crate::dtype::Nullability;
15use crate::scalar_fn::Arity;
16use crate::scalar_fn::ChildName;
17use crate::scalar_fn::EmptyOptions;
18use crate::scalar_fn::ExecutionArgs;
19use crate::scalar_fn::ScalarFnId;
20use crate::scalar_fn::ScalarFnVTable;
21use crate::validity::Validity;
22
/// Scalar function that reports, for every row of its single input, whether that row is null.
///
/// The result is always a non-nullable boolean array: `true` where the input is null,
/// `false` otherwise.
#[derive(Clone)]
pub struct IsNull;
26
27impl ScalarFnVTable for IsNull {
28    type Options = EmptyOptions;
29
30    fn id(&self) -> ScalarFnId {
31        static ID: CachedId = CachedId::new("vortex.is_null");
32        *ID
33    }
34
35    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
36        Ok(Some(vec![]))
37    }
38
39    fn deserialize(
40        &self,
41        _metadata: &[u8],
42        _session: &VortexSession,
43    ) -> VortexResult<Self::Options> {
44        Ok(EmptyOptions)
45    }
46
47    fn arity(&self, _options: &Self::Options) -> Arity {
48        Arity::Exact(1)
49    }
50
51    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
52        match child_idx {
53            0 => ChildName::from("input"),
54            _ => unreachable!("Invalid child index {} for IsNull expression", child_idx),
55        }
56    }
57
58    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
59        Ok(DType::Bool(Nullability::NonNullable))
60    }
61
62    fn execute(
63        &self,
64        _data: &Self::Options,
65        args: &dyn ExecutionArgs,
66        _ctx: &mut ExecutionCtx,
67    ) -> VortexResult<ArrayRef> {
68        let child = args.get(0)?;
69        if let Some(scalar) = child.as_constant() {
70            return Ok(ConstantArray::new(scalar.is_null(), args.row_count()).into_array());
71        }
72
73        match child.validity()? {
74            Validity::NonNullable | Validity::AllValid => {
75                Ok(ConstantArray::new(false, args.row_count()).into_array())
76            }
77            Validity::AllInvalid => Ok(ConstantArray::new(true, args.row_count()).into_array()),
78            Validity::Array(a) => a.not(),
79        }
80    }
81
82    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
83        true
84    }
85
86    fn is_fallible(&self, _instance: &Self::Options) -> bool {
87        false
88    }
89}
90
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::lit;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::stats::StatsSession;
    use crate::stats::all_non_null;
    use crate::stats::null_count;

    /// Session with `StatsSession` registered, used by the falsification test.
    static STATS_SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<StatsSession>());

    /// Asserts that `result` is a non-nullable boolean array whose values equal `expected`.
    ///
    /// A single execution context is created up front and reused for every index, rather
    /// than building a fresh one per element.
    fn assert_bool_array(result: &ArrayRef, expected: &[bool]) {
        assert_eq!(result.len(), expected.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        let session = array_session();
        let mut ctx = session.create_execution_ctx();
        for (i, &expected_value) in expected.iter().enumerate() {
            assert_eq!(
                result.execute_scalar(i, &mut ctx).unwrap(),
                Scalar::bool(expected_value, Nullability::NonNullable),
                "mismatch at index {i}"
            );
        }
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            is_null(root()).return_dtype(&dtype).unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = is_null(root());
        let replaced = expr
            .with_children([root()])
            .vortex_expect("operation should succeed in test");
        // Swapping in an identical child must rebuild an identical expression.
        assert_eq!(replaced, is_null(root()));
    }

    #[test]
    fn evaluate_mask() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_bool_array(&result, &[false, true, false, true, false]);
    }

    #[test]
    fn evaluate_all_false() {
        let test_array = buffer![1, 2, 3, 4, 5].into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        // Non-nullable input: no row can be null.
        assert_bool_array(&result, &[false; 5]);
    }

    #[test]
    fn evaluate_all_true() {
        let test_array =
            PrimitiveArray::from_option_iter(vec![None::<i32>, None, None, None, None])
                .into_array();

        let result = test_array.clone().apply(&is_null(root())).unwrap();

        assert_eq!(result.len(), test_array.len());
        // Every row is null.
        assert_bool_array(&result, &[true; 5]);
    }

    #[test]
    fn evaluate_struct() {
        let test_array = StructArray::from_fields(&[(
            "a",
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array(),
        )])
        .unwrap()
        .into_array();

        let result = test_array
            .clone()
            .apply(&is_null(get_item("a", root())))
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_bool_array(&result, &[false, true, false, true, false]);
    }

    #[test]
    fn test_display() {
        let expr = is_null(get_item("name", root()));
        assert_eq!(expr.to_string(), "vortex.is_null($.name)");

        let expr2 = is_null(root());
        assert_eq!(expr2.to_string(), "vortex.is_null($)");
    }

    #[test]
    fn test_is_null_falsification() -> VortexResult<()> {
        let expr = is_null(col("a"));

        assert_eq!(
            expr.falsify(&test_harness::struct_dtype(), &STATS_SESSION)?,
            Some(or(
                eq(null_count(col("a")), lit(0u64)),
                all_non_null(col("a")),
            ))
        );
        Ok(())
    }

    #[test]
    fn test_is_null_sensitive() {
        // is_null itself is null-sensitive
        assert!(is_null(col("a")).signature().is_null_sensitive());
    }
}