vortex_expr/exprs/
list_contains.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::{
13    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, LiteralVTable, Scope, StatsCatalog,
14    VTable, and, gt, lit, lt, or, vtable,
15};
16
17vtable!(ListContains);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
20#[derive(Debug, Clone, Hash, Eq)]
21pub struct ListContainsExpr {
22    list: ExprRef,
23    value: ExprRef,
24}
25
26impl PartialEq for ListContainsExpr {
27    fn eq(&self, other: &Self) -> bool {
28        self.list.eq(&other.list) && self.value.eq(&other.value)
29    }
30}
31
/// Unit marker type used as the `Encoding` associated type of the `list_contains` vtable.
pub struct ListContainsExprEncoding;
33
34impl VTable for ListContainsVTable {
35    type Expr = ListContainsExpr;
36    type Encoding = ListContainsExprEncoding;
37    type Metadata = EmptyMetadata;
38
39    fn id(_encoding: &Self::Encoding) -> ExprId {
40        ExprId::new_ref("list_contains")
41    }
42
43    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
44        ExprEncodingRef::new_ref(ListContainsExprEncoding.as_ref())
45    }
46
47    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
48        Some(EmptyMetadata)
49    }
50
51    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
52        vec![&expr.list, &expr.value]
53    }
54
55    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
56        Ok(ListContainsExpr::new(
57            children[0].clone(),
58            children[1].clone(),
59        ))
60    }
61
62    fn build(
63        _encoding: &Self::Encoding,
64        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
65        children: Vec<ExprRef>,
66    ) -> VortexResult<Self::Expr> {
67        if children.len() != 2 {
68            vortex_bail!(
69                "ListContains expression must have exactly 2 children, got {}",
70                children.len()
71            );
72        }
73        Ok(ListContainsExpr::new(
74            children[0].clone(),
75            children[1].clone(),
76        ))
77    }
78
79    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
80        compute_list_contains(
81            expr.list.evaluate(scope)?.as_ref(),
82            expr.value.evaluate(scope)?.as_ref(),
83        )
84    }
85
86    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
87        Ok(DType::Bool(
88            expr.list.return_dtype(scope)?.nullability()
89                | expr.value.return_dtype(scope)?.nullability(),
90        ))
91    }
92}
93
94impl ListContainsExpr {
95    pub fn new(list: ExprRef, value: ExprRef) -> Self {
96        Self { list, value }
97    }
98
99    pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
100        Self::new(list, value).into_expr()
101    }
102
103    pub fn value(&self) -> &ExprRef {
104        &self.value
105    }
106}
107
108/// Creates an expression that checks if a value is contained in a list.
109///
110/// Returns a boolean array indicating whether the value appears in each list.
111///
112/// ```rust
113/// # use vortex_expr::{list_contains, lit, root};
114/// let expr = list_contains(root(), lit(42));
115/// ```
116pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
117    ListContainsExpr::new(list, value).into_expr()
118}
119
120impl Display for ListContainsExpr {
121    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
122        write!(f, "contains({}, {})", &self.list, &self.value)
123    }
124}
125
126impl AnalysisExpr for ListContainsExpr {
127    // falsification(contains([1,2,5], x)) =>
128    //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
129
130    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
131        let min = self.list.min(catalog)?;
132        let max = self.list.max(catalog)?;
133        // If the list is constant when we can compare each element to the value
134        if min == max {
135            let list_ = min
136                .as_opt::<LiteralVTable>()
137                .and_then(|l| l.value().as_list_opt())
138                .and_then(|l| l.elements())?;
139            if list_.is_empty() {
140                // contains([], x) is always false.
141                return Some(lit(true));
142            }
143            let value_max = self.value.max(catalog)?;
144            let value_min = self.value.min(catalog)?;
145
146            return list_
147                .iter()
148                .map(move |v| {
149                    or(
150                        lt(value_max.clone(), lit(v.clone())),
151                        gt(value_min.clone(), lit(v.clone())),
152                    )
153                })
154                .reduce(and);
155        }
156
157        None
158    }
159}
160
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{Arc, HashSet, Scope, and, col, get_item, gt, lit, lt, or, root};

    /// Builds a list array from flat `i32` elements, list offsets and a validity.
    fn list_array(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        list_array(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `list_contains(root(), lit(needle))` with `lists` as the scope root.
    fn contains(lists: ArrayRef, needle: i32) -> ArrayRef {
        let expr = list_contains(root(), lit(needle));
        expr.evaluate(&Scope::new(lists)).unwrap()
    }

    /// Asserts that `result[index]` is the nullable boolean scalar `expected`.
    #[track_caller]
    fn assert_bool_at(result: &ArrayRef, index: usize, expected: bool) {
        assert_eq!(
            result.scalar_at(index).unwrap(),
            Scalar::bool(expected, Nullability::Nullable)
        );
    }

    #[test]
    pub fn test_one() {
        // 1 occurs only in the first list.
        let result = contains(test_array(), 1);

        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_all() {
        // 2 occurs in both lists.
        let result = contains(test_array(), 2);

        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, true);
    }

    #[test]
    pub fn test_none() {
        // 4 occurs in neither list.
        let result = contains(test_array(), 4);

        assert_bool_at(&result, 0, false);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_empty() {
        // Offsets [0, 5, 5] make the second list empty.
        let lists = list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);

        let result = contains(lists, 2);

        assert_bool_at(&result, 0, true);
        assert_bool_at(&result, 1, false);
    }

    #[test]
    pub fn test_nullable() {
        // The second list is marked invalid (null) by the validity array.
        let validity =
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array());
        let lists = list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);

        let result = contains(lists, 2);

        assert_bool_at(&result, 0, true);
        assert!(!result.is_valid(1).unwrap());
    }

    #[test]
    pub fn test_return_type() {
        let nullable_i32_list = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![nullable_i32_list]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));
        let dtype = expr.return_dtype(&scope).unwrap();

        // Expect nullable, although scope is non-nullable
        assert_eq!(dtype, DType::Bool(Nullability::Nullable));
    }

    #[test]
    pub fn list_falsification() {
        let haystack = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let predicate = list_contains(haystack, col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (pruning_expr, required_stats) =
            checked_pruning_expr(&predicate, &available_stats).unwrap();

        // `element` is outside the [a_min, a_max] range of column `a`.
        let outside_range = |element: i32| {
            or(
                lt(col("a_max"), lit(element)),
                gt(col("a_min"), lit(element)),
            )
        };
        let expected = and(
            and(outside_range(1), outside_range(2)),
            outside_range(3),
        );
        assert_eq!(&pruning_expr, &expected);

        let expected_stats = HashMap::from_iter([(
            FieldPath::from_name("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);
        assert_eq!(required_stats.map(), &expected_stats);
    }
}