Skip to main content

vortex_array/scalar_fn/fns/
is_not_null.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use vortex_array::scalar_fn::internal::row_count::RowCount;
7use vortex_error::VortexResult;
8use vortex_session::VortexSession;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::ConstantArray;
14use crate::dtype::DType;
15use crate::dtype::Nullability;
16use crate::expr::Expression;
17use crate::expr::StatsCatalog;
18use crate::expr::eq;
19use crate::expr::stats::Stat;
20use crate::scalar_fn::Arity;
21use crate::scalar_fn::ChildName;
22use crate::scalar_fn::EmptyOptions;
23use crate::scalar_fn::ExecutionArgs;
24use crate::scalar_fn::ScalarFnId;
25use crate::scalar_fn::ScalarFnVTable;
26use crate::scalar_fn::ScalarFnVTableExt;
27use crate::validity::Validity;
28
/// Scalar function that tests its single input for non-null values.
///
/// Registered under the id `vortex.is_not_null`. It carries no options and always reports a
/// non-nullable boolean result dtype.
#[derive(Clone)]
pub struct IsNotNull;
32
33impl ScalarFnVTable for IsNotNull {
34    type Options = EmptyOptions;
35
36    fn id(&self) -> ScalarFnId {
37        ScalarFnId::new("vortex.is_not_null")
38    }
39
40    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
41        Ok(Some(vec![]))
42    }
43
44    fn deserialize(
45        &self,
46        _metadata: &[u8],
47        _session: &VortexSession,
48    ) -> VortexResult<Self::Options> {
49        Ok(EmptyOptions)
50    }
51
52    fn arity(&self, _options: &Self::Options) -> Arity {
53        Arity::Exact(1)
54    }
55
56    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
57        match child_idx {
58            0 => ChildName::from("input"),
59            _ => unreachable!("Invalid child index {} for IsNotNull expression", child_idx),
60        }
61    }
62
63    fn fmt_sql(
64        &self,
65        _options: &Self::Options,
66        expr: &Expression,
67        f: &mut Formatter<'_>,
68    ) -> std::fmt::Result {
69        write!(f, "is_not_null(")?;
70        expr.child(0).fmt_sql(f)?;
71        write!(f, ")")
72    }
73
74    fn return_dtype(&self, _options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
75        Ok(DType::Bool(Nullability::NonNullable))
76    }
77
78    fn execute(
79        &self,
80        _data: &Self::Options,
81        args: &dyn ExecutionArgs,
82        _ctx: &mut ExecutionCtx,
83    ) -> VortexResult<ArrayRef> {
84        let child = args.get(0)?;
85        match child.validity()? {
86            Validity::NonNullable | Validity::AllValid => {
87                Ok(ConstantArray::new(true, args.row_count()).into_array())
88            }
89            Validity::AllInvalid => Ok(ConstantArray::new(false, args.row_count()).into_array()),
90            Validity::Array(a) => Ok(a),
91        }
92    }
93
94    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
95        true
96    }
97
98    fn is_fallible(&self, _instance: &Self::Options) -> bool {
99        false
100    }
101
102    fn stat_falsification(
103        &self,
104        _options: &Self::Options,
105        expr: &Expression,
106        catalog: &dyn StatsCatalog,
107    ) -> Option<Expression> {
108        // is_not_null is falsified when ALL values are null, i.e. null_count == row_count.
109        let child = expr.child(0);
110        let null_count_expr = child.stat_expression(Stat::NullCount, catalog)?;
111        Some(eq(null_count_expr, RowCount.new_expr(EmptyOptions, [])))
112    }
113}
114
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Field;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::is_not_null;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::scalar_fn::EmptyOptions;
    use crate::scalar_fn::internal::row_count::RowCount;
    use crate::scalar_fn::vtable::ScalarFnVTableExt;

    // The result dtype is a non-nullable boolean.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_not_null(root()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    // Swapping in a new child must succeed.
    #[test]
    fn replace_children() {
        is_not_null(root())
            .with_children([root()])
            .vortex_expect("operation should succeed in test");
    }

    // Mixed valid/null input yields a matching true/false pattern.
    #[test]
    fn evaluate_mask() {
        let values = vec![Some(1), None, Some(2), None, Some(3)];
        let input = PrimitiveArray::from_option_iter(values).into_array();

        let result = input.clone().apply(&is_not_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (row, want) in [true, false, true, false, true].into_iter().enumerate() {
            let got = result
                .execute_scalar(row, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap();
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    // Input built from plain values: every row is reported as non-null.
    #[test]
    fn evaluate_all_true() {
        let input = buffer![1, 2, 3, 4, 5].into_array();

        let result = input.clone().apply(&is_not_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        let want = Scalar::bool(true, Nullability::NonNullable);
        for row in 0..result.len() {
            let got = result
                .execute_scalar(row, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap();
            assert_eq!(got, want);
        }
    }

    // Input consisting solely of nulls: every row is reported as null.
    #[test]
    fn evaluate_all_false() {
        let input = PrimitiveArray::from_option_iter(vec![None::<i32>; 5]).into_array();

        let result = input.clone().apply(&is_not_null(root())).unwrap();

        assert_eq!(result.len(), input.len());
        let want = Scalar::bool(false, Nullability::NonNullable);
        for row in 0..result.len() {
            let got = result
                .execute_scalar(row, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap();
            assert_eq!(got, want);
        }
    }

    // The check also works on a field extracted from a struct array.
    #[test]
    fn evaluate_struct() {
        let field_a =
            PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
                .into_array();
        let input = StructArray::from_fields(&[("a", field_a)])
            .unwrap()
            .into_array();

        let projection = is_not_null(get_item("a", root()));
        let result = input.clone().apply(&projection).unwrap();

        assert_eq!(result.len(), input.len());
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (row, want) in [true, false, true, false, true].into_iter().enumerate() {
            let got = result
                .execute_scalar(row, &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap();
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    // Display output wraps the child's rendering in `is_not_null(...)`.
    #[test]
    fn test_display() {
        let cases = [
            (is_not_null(get_item("name", root())), "is_not_null($.name)"),
            (is_not_null(root()), "is_not_null($)"),
        ];
        for (expr, rendered) in cases {
            assert_eq!(expr.to_string(), rendered);
        }
    }

    // The expression's signature must advertise null sensitivity.
    #[test]
    fn test_is_not_null_sensitive() {
        let expr = is_not_null(col("a"));
        assert!(expr.signature().is_null_sensitive());
    }

    // Pruning rewrites `is_not_null(a)` into `a_null_count == row_count` and records the
    // null-count stat as required for field `a`.
    #[test]
    fn test_is_not_null_falsification() {
        let expr = is_not_null(col("a"));
        let available_stats = FieldPathSet::from_iter([FieldPath::from_iter([
            Field::Name("a".into()),
            Field::Name("null_count".into()),
        ])]);

        let (pruning_expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        let expected_pruning = eq(col("a_null_count"), RowCount.new_expr(EmptyOptions, []));
        assert_eq!(&pruning_expr, &expected_pruning);

        let expected_stats =
            HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from([Stat::NullCount]))]);
        assert_eq!(st.map(), &expected_stats);
    }
}