vortex_expr/exprs/list_contains.rs

1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::ArrayRef;
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::{
12    AnalysisExpr, ExprRef, Literal, Scope, ScopeDType, StatsCatalog, VortexExpr, and, gt, lit, lt,
13    or,
14};
15
16#[derive(Debug, Clone, Eq, Hash)]
17#[allow(clippy::derived_hash_with_manual_eq)]
18pub struct ListContains {
19    list: ExprRef,
20    value: ExprRef,
21}
22
23impl ListContains {
24    pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
25        Arc::new(Self { list, value })
26    }
27
28    pub fn value(&self) -> &ExprRef {
29        &self.value
30    }
31}
32
33pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
34    ListContains::new_expr(list, value)
35}
36
37impl Display for ListContains {
38    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
39        write!(f, "contains({}, {})", &self.list, &self.value)
40    }
41}
42
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::list_contains::ListContains;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serde handle for [`ListContains`], registered under the id `"list_contains"`.
    pub(crate) struct ListContainsSerde;

    impl Id for ListContainsSerde {
        fn id(&self) -> &'static str {
            "list_contains"
        }
    }

    impl ExprDeserialize for ListContainsSerde {
        /// Rebuilds a [`ListContains`] from its serialized kind and child expressions.
        ///
        /// # Errors
        /// Fails if `kind` is not `Kind::ListContains`, or if `children` does not contain
        /// exactly two expressions (the list expression followed by the value expression).
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::ListContains(kind::ListContains {}) = kind else {
                vortex_bail!("wrong kind {:?}, want list_contains", kind)
            };

            // Serialized input may be malformed: report a wrong child count as an error rather
            // than panicking on an out-of-bounds index. Moving the children out of the Vec also
            // avoids cloning them.
            let [list, value] = match <[ExprRef; 2]>::try_from(children) {
                Ok(pair) => pair,
                Err(children) => {
                    vortex_bail!("list_contains expects 2 children, got {}", children.len())
                }
            };

            Ok(ListContains::new_expr(list, value))
        }
    }

    impl ExprSerializable for ListContains {
        fn id(&self) -> &'static str {
            ListContainsSerde.id()
        }

        /// `ListContains` carries no state of its own; everything lives in its two children.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::ListContains(kind::ListContains {}))
        }
    }
}
83
84impl AnalysisExpr for ListContains {
85    // falsification(contains([1,2,5], x)) =>
86    //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
87
88    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
89        let min = self.list.min(catalog)?;
90        let max = self.list.max(catalog)?;
91        // If the list is constant when we can compare each element to the value
92        if min == max {
93            let list_ = min
94                .as_any()
95                .downcast_ref::<Literal>()
96                .and_then(|l| l.value().as_list_opt())
97                .and_then(|l| l.elements())?;
98            if list_.is_empty() {
99                // contains([], x) is always false.
100                return Some(lit(true));
101            }
102            let value_max = self.value.max(catalog)?;
103            let value_min = self.value.min(catalog)?;
104
105            return list_
106                .iter()
107                .map(move |v| {
108                    or(
109                        lt(value_max.clone(), lit(v.clone())),
110                        gt(value_min.clone(), lit(v.clone())),
111                    )
112                })
113                .reduce(and);
114        }
115
116        None
117    }
118}
119
120impl VortexExpr for ListContains {
121    fn as_any(&self) -> &dyn Any {
122        self
123    }
124
125    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
126        compute_list_contains(
127            self.list.evaluate(scope)?.as_ref(),
128            self.value.evaluate(scope)?.as_ref(),
129        )
130    }
131
132    fn children(&self) -> Vec<&ExprRef> {
133        vec![&self.list, &self.value]
134    }
135
136    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
137        assert_eq!(children.len(), 2);
138        Self::new_expr(children[0].clone(), children[1].clone())
139    }
140
141    fn return_dtype(&self, scope_dtype: &ScopeDType) -> VortexResult<DType> {
142        Ok(DType::Bool(
143            self.list.return_dtype(scope_dtype)?.nullability()
144                | self.value.return_dtype(scope_dtype)?.nullability(),
145        ))
146    }
147}
148
149impl PartialEq for ListContains {
150    fn eq(&self, other: &ListContains) -> bool {
151        self.value.eq(&other.value) && self.list.eq(&other.list)
152    }
153}
154
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{
        AccessPath, Arc, DType, HashSet, Scope, ScopeDType, ScopeFieldPathSet, and, get_item,
        get_item_scope, gt, lit, lt, or, root,
    };

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 10]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array()
    }

    /// Evaluates `contains(root(), needle)` with `array` as the root scope.
    fn contains_root(array: ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle))
            .evaluate(&Scope::new(array))
            .unwrap()
    }

    /// Asserts that `result[index]` equals the nullable-typed boolean scalar `expected`.
    fn assert_bool_at(result: &ArrayRef, index: usize, expected: bool) {
        assert_eq!(
            result.scalar_at(index).unwrap(),
            Scalar::bool(expected, Nullability::Nullable),
            "unexpected value at row {index}"
        );
    }

    #[test]
    fn test_one() {
        let result = contains_root(test_array(), 1);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    fn test_all() {
        let result = contains_root(test_array(), 2);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, true);
    }

    #[test]
    fn test_none() {
        let result = contains_root(test_array(), 4);
        assert_bool_at(&result, 0, false);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    fn test_empty() {
        // Offsets [0, 5, 5]: the second list is empty.
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let result = contains_root(arr, 2);
        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    fn test_nullable() {
        // The validity mask marks the second list as null.
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array()),
        )
        .unwrap()
        .into_array();

        let result = contains_root(arr, 2);
        assert_bool_at(&result, 0, true);
        assert!(!result.is_valid(1).unwrap());
    }

    #[test]
    fn test_return_type() {
        let scope = ScopeDType::new(DType::Struct(
            StructFields::new(
                ["array"].into(),
                vec![DType::List(
                    Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                    Nullability::Nullable,
                )],
            ),
            Nullability::NonNullable,
        ));

        let expr = list_contains(get_item("array", root()), lit(2));

        // Expect nullable, although scope is non-nullable
        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![1.into(), 2.into(), 3.into()],
                Nullability::NonNullable,
            )),
            get_item_scope("a"),
        );

        let (expr, st) = checked_pruning_expr(
            &expr,
            &ScopeFieldPathSet::new(FieldPathSet::from_iter([
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
                FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
            ])),
        )
        .unwrap();

        assert_eq!(
            &expr,
            &and(
                and(
                    or(
                        lt(get_item_scope("a_max"), lit(1i32)),
                        gt(get_item_scope("a_min"), lit(1i32)),
                    ),
                    or(
                        lt(get_item_scope("a_max"), lit(2i32)),
                        gt(get_item_scope("a_min"), lit(2i32)),
                    )
                ),
                or(
                    lt(get_item_scope("a_max"), lit(3i32)),
                    gt(get_item_scope("a_min"), lit(3i32)),
                )
            )
        );

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                AccessPath::root_field("a".into()),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }
}