vortex_expr/
list_contains.rs

1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
12
13#[derive(Debug, Clone, Eq, Hash)]
14#[allow(clippy::derived_hash_with_manual_eq)]
15pub struct ListContains {
16    list: ExprRef,
17    value: ExprRef,
18}
19
20impl ListContains {
21    pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
22        Arc::new(Self { list, value })
23    }
24
25    pub fn value(&self) -> &ExprRef {
26        &self.value
27    }
28}
29
30pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
31    ListContains::new_expr(list, value)
32}
33
34impl Display for ListContains {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        write!(f, "contains({}, {})", &self.list, &self.value)
37    }
38}
39
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::list_contains::ListContains;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serialization entry point for [`ListContains`] expressions.
    pub(crate) struct ListContainsSerde;

    impl Id for ListContainsSerde {
        fn id(&self) -> &'static str {
            "list_contains"
        }
    }

    impl ExprDeserialize for ListContainsSerde {
        /// Rebuilds a `ListContains` from its protobuf kind and its children.
        ///
        /// The children are given in `[list, value]` order.
        ///
        /// # Errors
        /// Fails if `kind` is not `Kind::ListContains`, or if `children` does
        /// not hold exactly two expressions.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::ListContains(kind::ListContains {}) = kind else {
                vortex_bail!("wrong kind {:?}, want list_contains", kind)
            };

            // Children come from serialized (possibly untrusted) input. A wrong
            // arity must surface as an error, not an index-out-of-bounds panic.
            // Destructuring by value also moves the children instead of cloning them.
            let [list, value] = match <[ExprRef; 2]>::try_from(children) {
                Ok(children) => children,
                Err(children) => {
                    vortex_bail!(
                        "list_contains expects 2 children, got {}",
                        children.len()
                    )
                }
            };

            Ok(ListContains::new_expr(list, value))
        }
    }

    impl ExprSerializable for ListContains {
        fn id(&self) -> &'static str {
            ListContainsSerde.id()
        }

        /// `ListContains` carries no parameters besides its children, so the
        /// serialized kind is an empty marker message.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::ListContains(kind::ListContains {}))
        }
    }
}
80
81impl AnalysisExpr for ListContains {}
82
83impl VortexExpr for ListContains {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
89        compute_list_contains(
90            self.list.evaluate(scope)?.as_ref(),
91            self.value.evaluate(scope)?.as_ref(),
92        )
93    }
94
95    fn children(&self) -> Vec<&ExprRef> {
96        vec![&self.list, &self.value]
97    }
98
99    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
100        assert_eq!(children.len(), 2);
101        Self::new_expr(children[0].clone(), children[1].clone())
102    }
103
104    fn return_dtype(&self, scope_dtype: &ScopeDType) -> VortexResult<DType> {
105        Ok(DType::Bool(
106            self.list.return_dtype(scope_dtype)?.nullability()
107                | self.value.return_dtype(scope_dtype)?.nullability(),
108        ))
109    }
110}
111
112impl PartialEq for ListContains {
113    fn eq(&self, other: &ListContains) -> bool {
114        self.value.eq(&other.value) && self.list.eq(&other.list)
115    }
116}
117
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::{FieldNames, Nullability, StructFields};
    use vortex_scalar::Scalar;

    use crate::list_contains::list_contains;
    use crate::{Arc, DType, Scope, ScopeDType, get_item, lit, root};

    /// Builds a `ListArray` from flat `elements`, list boundary `offsets`, and `validity`.
    fn make_list(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        make_list(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `list_contains(root(), lit(needle))` with `array` as the scope.
    fn eval_contains(array: ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle))
            .evaluate(&Scope::new(array))
            .unwrap()
    }

    /// Asserts that row `idx` of `result` is the nullable boolean `expected`.
    fn assert_row(result: &ArrayRef, idx: usize, expected: bool) {
        assert_eq!(
            result.scalar_at(idx).unwrap(),
            Scalar::bool(expected, Nullability::Nullable)
        );
    }

    #[test]
    fn test_one() {
        let result = eval_contains(test_array(), 1);
        assert_row(&result, 0, true);
        assert_row(&result, 1, false);
    }

    #[test]
    fn test_all() {
        let result = eval_contains(test_array(), 2);
        assert_row(&result, 0, true);
        assert_row(&result, 1, true);
    }

    #[test]
    fn test_none() {
        let result = eval_contains(test_array(), 4);
        assert_row(&result, 0, false);
        assert_row(&result, 1, false);
    }

    #[test]
    fn test_empty() {
        // Second list spans offsets 5..5, i.e. it is empty.
        let lists = make_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);
        let result = eval_contains(lists, 2);
        assert_row(&result, 0, true);
        assert_row(&result, 1, false);
    }

    #[test]
    fn test_nullable() {
        // Second list is null (validity bit false).
        let validity =
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array());
        let lists = make_list(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);
        let result = eval_contains(lists, 2);
        assert_row(&result, 0, true);
        assert!(!result.is_valid(1).unwrap());
    }

    #[test]
    fn test_return_type() {
        let element = DType::Primitive(vortex_dtype::PType::I32, Nullability::NonNullable);
        let nullable_list = DType::List(Arc::new(element), Nullability::Nullable);
        let fields = StructFields::new(FieldNames::from(["array".into()]), vec![nullable_list]);
        let scope = ScopeDType::new(DType::Struct(Arc::new(fields), Nullability::NonNullable));

        let expr = list_contains(get_item("array", root()), lit(2));

        // Expect nullable, although scope is non-nullable
        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }
}