vortex_expr/exprs/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6
7use vortex_array::compute::list_contains as compute_list_contains;
8use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata};
9use vortex_dtype::DType;
10use vortex_error::{VortexResult, vortex_bail};
11
12use crate::display::{DisplayAs, DisplayFormat};
13use crate::{
14    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, LiteralVTable, Scope, StatsCatalog,
15    VTable, and, gt, lit, lt, or, vtable,
16};
17
18vtable!(ListContains);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct ListContainsExpr {
23    list: ExprRef,
24    value: ExprRef,
25}
26
27impl PartialEq for ListContainsExpr {
28    fn eq(&self, other: &Self) -> bool {
29        self.list.eq(&other.list) && self.value.eq(&other.value)
30    }
31}
32
33pub struct ListContainsExprEncoding;
34
35impl VTable for ListContainsVTable {
36    type Expr = ListContainsExpr;
37    type Encoding = ListContainsExprEncoding;
38    type Metadata = EmptyMetadata;
39
40    fn id(_encoding: &Self::Encoding) -> ExprId {
41        ExprId::new_ref("list_contains")
42    }
43
44    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
45        ExprEncodingRef::new_ref(ListContainsExprEncoding.as_ref())
46    }
47
48    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
49        Some(EmptyMetadata)
50    }
51
52    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
53        vec![&expr.list, &expr.value]
54    }
55
56    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
57        Ok(ListContainsExpr::new(
58            children[0].clone(),
59            children[1].clone(),
60        ))
61    }
62
63    fn build(
64        _encoding: &Self::Encoding,
65        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
66        children: Vec<ExprRef>,
67    ) -> VortexResult<Self::Expr> {
68        if children.len() != 2 {
69            vortex_bail!(
70                "ListContains expression must have exactly 2 children, got {}",
71                children.len()
72            );
73        }
74        Ok(ListContainsExpr::new(
75            children[0].clone(),
76            children[1].clone(),
77        ))
78    }
79
80    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
81        compute_list_contains(
82            expr.list.evaluate(scope)?.as_ref(),
83            expr.value.evaluate(scope)?.as_ref(),
84        )
85    }
86
87    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
88        Ok(DType::Bool(
89            expr.list.return_dtype(scope)?.nullability()
90                | expr.value.return_dtype(scope)?.nullability(),
91        ))
92    }
93}
94
95impl ListContainsExpr {
96    pub fn new(list: ExprRef, value: ExprRef) -> Self {
97        Self { list, value }
98    }
99
100    pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
101        Self::new(list, value).into_expr()
102    }
103
104    pub fn value(&self) -> &ExprRef {
105        &self.value
106    }
107}
108
109/// Creates an expression that checks if a value is contained in a list.
110///
111/// Returns a boolean array indicating whether the value appears in each list.
112///
113/// ```rust
114/// # use vortex_expr::{list_contains, lit, root};
115/// let expr = list_contains(root(), lit(42));
116/// ```
117pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
118    ListContainsExpr::new(list, value).into_expr()
119}
120
121impl DisplayAs for ListContainsExpr {
122    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
123        match df {
124            DisplayFormat::Compact => {
125                write!(f, "contains({}, {})", &self.list, &self.value)
126            }
127            DisplayFormat::Tree => {
128                write!(f, "ListContains")
129            }
130        }
131    }
132
133    fn child_names(&self) -> Option<Vec<String>> {
134        Some(vec!["list".to_string(), "value".to_string()])
135    }
136}
137
138impl AnalysisExpr for ListContainsExpr {
139    // falsification(contains([1,2,5], x)) =>
140    //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
141
142    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
143        let min = self.list.min(catalog)?;
144        let max = self.list.max(catalog)?;
145        // If the list is constant when we can compare each element to the value
146        if min == max {
147            let list_ = min
148                .as_opt::<LiteralVTable>()
149                .and_then(|l| l.value().as_list_opt())
150                .and_then(|l| l.elements())?;
151            if list_.is_empty() {
152                // contains([], x) is always false.
153                return Some(lit(true));
154            }
155            let value_max = self.value.max(catalog)?;
156            let value_min = self.value.min(catalog)?;
157
158            return list_
159                .iter()
160                .map(move |v| {
161                    or(
162                        lt(value_max.clone(), lit(v.clone())),
163                        gt(value_min.clone(), lit(v.clone())),
164                    )
165                })
166                .reduce(and);
167        }
168
169        None
170    }
171}
172
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray, PrimitiveArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{Arc, HashSet, Scope, and, col, get_item, gt, lit, lt, or, root};

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 10]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array()
    }

    /// Statistics available for pruning tests: `min` and `max` of column `a`.
    fn a_min_max_stats() -> FieldPathSet {
        FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ])
    }

    #[test]
    fn test_one() {
        let arr = test_array();

        let expr = list_contains(root(), lit(1));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(item.scalar_at(0), Scalar::bool(true, Nullability::Nullable));
        assert_eq!(
            item.scalar_at(1),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    #[test]
    fn test_all() {
        let arr = test_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(item.scalar_at(0), Scalar::bool(true, Nullability::Nullable));
        assert_eq!(item.scalar_at(1), Scalar::bool(true, Nullability::Nullable));
    }

    #[test]
    fn test_none() {
        let arr = test_array();

        let expr = list_contains(root(), lit(4));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(
            item.scalar_at(0),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            item.scalar_at(1),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    #[test]
    fn test_empty() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(item.scalar_at(0), Scalar::bool(true, Nullability::Nullable));
        assert_eq!(
            item.scalar_at(1),
            Scalar::bool(false, Nullability::Nullable)
        );
    }

    #[test]
    fn test_nullable() {
        let arr = ListArray::try_new(
            PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array(),
            PrimitiveArray::from_iter(vec![0, 5, 5]).into_array(),
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array()),
        )
        .unwrap()
        .into_array();

        let expr = list_contains(root(), lit(2));
        let item = expr.evaluate(&Scope::new(arr)).unwrap();

        assert_eq!(item.scalar_at(0), Scalar::bool(true, Nullability::Nullable));
        assert!(!item.is_valid(1));
    }

    #[test]
    fn test_return_type() {
        let scope = DType::Struct(
            StructFields::new(
                ["array"].into(),
                vec![DType::List(
                    Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                    Nullability::Nullable,
                )],
            ),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        // Expect nullable, although scope is non-nullable
        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![1.into(), 2.into(), 3.into()],
                Nullability::NonNullable,
            )),
            col("a"),
        );

        let (expr, st) = checked_pruning_expr(&expr, &a_min_max_stats()).unwrap();

        assert_eq!(
            &expr,
            &and(
                and(
                    or(lt(col("a_max"), lit(1i32)), gt(col("a_min"), lit(1i32)),),
                    or(lt(col("a_max"), lit(2i32)), gt(col("a_min"), lit(2i32)),)
                ),
                or(lt(col("a_max"), lit(3i32)), gt(col("a_min"), lit(3i32)),)
            )
        );

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    #[test]
    fn empty_list_falsification() {
        let expr = list_contains(
            lit(Scalar::list(
                Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
                vec![],
                Nullability::NonNullable,
            )),
            col("a"),
        );

        let (expr, _stats) = checked_pruning_expr(&expr, &a_min_max_stats()).unwrap();

        // `contains([], x)` can never be true, so the falsification is unconditional.
        assert_eq!(&expr, &lit(true));
    }

    #[test]
    fn test_display() {
        let expr = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(expr.to_string(), "contains($.tags, \"urgent\")");

        let expr2 = list_contains(root(), lit(42));
        assert_eq!(expr2.to_string(), "contains($, 42i32)");
    }
}