vortex_expr/
is_null.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::ops::Not;
4use std::sync::Arc;
5
6use vortex_array::arrays::{BoolArray, ConstantArray};
7use vortex_array::{Array, ArrayRef, IntoArray};
8use vortex_dtype::{DType, Nullability};
9use vortex_error::{VortexExpect, VortexResult};
10use vortex_mask::Mask;
11
12use crate::{ExprRef, VortexExpr};
13
14#[derive(Debug, Eq, Hash)]
15#[allow(clippy::derived_hash_with_manual_eq)]
16pub struct IsNull {
17    child: ExprRef,
18}
19
20impl IsNull {
21    pub fn new_expr(child: ExprRef) -> ExprRef {
22        Arc::new(Self { child })
23    }
24}
25
26impl PartialEq for IsNull {
27    fn eq(&self, other: &Self) -> bool {
28        self.child.eq(&other.child)
29    }
30}
31
32impl Display for IsNull {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "is_null({})", self.child)
35    }
36}
37
/// Protobuf (de)serialization support for [`IsNull`], compiled only when the
/// `proto` feature is enabled.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::is_null::IsNull;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Deserializer handle for `is_null` expressions.
    pub(crate) struct IsNullSerde;

    impl Id for IsNullSerde {
        /// Identifier under which `is_null` expressions are (de)serialized.
        fn id(&self) -> &'static str {
            "is_null"
        }
    }

    impl ExprDeserialize for IsNullSerde {
        /// Rebuilds an [`IsNull`] expression from its serialized `kind` and
        /// already-deserialized `children`.
        ///
        /// # Errors
        ///
        /// Returns an error when `kind` is not `Kind::IsNull`, or when
        /// `children` is empty (an `is_null` node needs a child to test).
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::IsNull(kind::IsNull {}) = kind else {
                vortex_bail!("wrong kind {:?}, want is_null", kind)
            };

            // Serialized input may be malformed; report a missing child as an
            // error instead of panicking on an out-of-bounds index. The first
            // child is moved out of the vector, so no extra clone is needed.
            let Some(child) = children.into_iter().next() else {
                vortex_bail!("is_null expects one child expression, got none")
            };

            Ok(IsNull::new_expr(child))
        }
    }

    impl ExprSerializable for IsNull {
        /// Uses the same identifier as [`IsNullSerde`] so that serialization
        /// and deserialization agree.
        fn id(&self) -> &'static str {
            IsNullSerde.id()
        }

        /// `IsNull` carries no payload of its own; the kind message is empty.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::IsNull(kind::IsNull {}))
        }
    }
}
75
76impl VortexExpr for IsNull {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80
81    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
82        let array = self.child.evaluate(batch)?;
83        match array.validity_mask()? {
84            Mask::AllTrue(len) => Ok(ConstantArray::new(false, len).into_array()),
85            Mask::AllFalse(len) => Ok(ConstantArray::new(true, len).into_array()),
86            Mask::Values(mask) => Ok(BoolArray::from(mask.boolean_buffer().not()).into_array()),
87        }
88    }
89
90    fn children(&self) -> Vec<&ExprRef> {
91        vec![&self.child]
92    }
93
94    fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
95        Self::new_expr(
96            children
97                .pop()
98                .vortex_expect("IsNull::replacing_children should have one child"),
99        )
100    }
101
102    fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
103        Ok(DType::Bool(Nullability::NonNullable))
104    }
105}
106
107pub fn is_null(child: ExprRef) -> ExprRef {
108    IsNull::new_expr(child)
109}
110
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::is_null::is_null;
    use crate::{get_item, ident, test_harness};

    /// Checks that `result` is a non-nullable boolean array whose leading
    /// values match `expected`, position by position.
    fn assert_null_flags(result: &vortex_array::ArrayRef, expected: &[bool]) {
        assert_eq!(result.dtype(), &DType::Bool(Nullability::NonNullable));

        for (row, &want) in expected.iter().enumerate() {
            let got = result.scalar_at(row).unwrap();
            assert_eq!(got, Scalar::bool(want, Nullability::NonNullable));
        }
    }

    /// `is_null` always reports a non-nullable boolean dtype.
    #[test]
    fn dtype() {
        let scope = test_harness::struct_dtype();
        let actual = is_null(ident()).return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
    }

    /// Replacing the single child must not panic.
    #[test]
    fn replace_children() {
        let original = is_null(ident());
        let _ = original.replacing_children(vec![ident()]);
    }

    /// Mixed validity yields a per-row null flag.
    #[test]
    fn evaluate_mask() {
        let values = vec![Some(1), None, Some(2), None, Some(3)];
        let test_array = PrimitiveArray::from_option_iter(values).into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_null_flags(&result, &expected);
    }

    /// An array without nulls evaluates to a constant `false`.
    #[test]
    fn evaluate_all_false() {
        let test_array = PrimitiveArray::from_iter(vec![1, 2, 3, 4, 5]).into_array();

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(false, Nullability::NonNullable));
    }

    /// An all-null array evaluates to a constant `true`.
    #[test]
    fn evaluate_all_true() {
        let values = vec![None::<i32>, None, None, None, None];
        let test_array = PrimitiveArray::from_option_iter(values).into_array();

        let result = is_null(ident()).unchecked_evaluate(&test_array).unwrap();

        assert_eq!(result.len(), test_array.len());
        let constant = result.as_constant().unwrap();
        assert_eq!(constant, Scalar::bool(true, Nullability::NonNullable));
    }

    /// `is_null` composes with field access on a struct array.
    #[test]
    fn evaluate_struct() {
        let field = PrimitiveArray::from_option_iter(vec![Some(1), None, Some(2), None, Some(3)])
            .into_array();
        let test_array = StructArray::from_fields(&[("a", field)])
            .unwrap()
            .into_array();
        let expected = [false, true, false, true, false];

        let result = is_null(get_item("a", ident()))
            .unchecked_evaluate(&test_array)
            .unwrap();

        assert_eq!(result.len(), test_array.len());
        assert_null_flags(&result, &expected);
    }
}