// vortex_expr/exprs/list_contains.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::fmt::{Debug, Formatter};
use std::hash::Hash;

use vortex_array::compute::list_contains as compute_list_contains;
use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata};
use vortex_dtype::DType;
use vortex_error::{VortexResult, vortex_bail};

use crate::display::{DisplayAs, DisplayFormat};
use crate::{
    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, LiteralVTable, Scope, StatsCatalog,
    VTable, and, gt, lit, lt, or, vtable,
};

18vtable!(ListContains);
19
20#[allow(clippy::derived_hash_with_manual_eq)]
21#[derive(Debug, Clone, Hash, Eq)]
22pub struct ListContainsExpr {
23    list: ExprRef,
24    value: ExprRef,
25}
26
27impl PartialEq for ListContainsExpr {
28    fn eq(&self, other: &Self) -> bool {
29        self.list.eq(&other.list) && self.value.eq(&other.value)
30    }
31}
32
/// Encoding marker for [`ListContainsExpr`]; its expression id is `"list_contains"`.
pub struct ListContainsExprEncoding;
34
35impl VTable for ListContainsVTable {
36    type Expr = ListContainsExpr;
37    type Encoding = ListContainsExprEncoding;
38    type Metadata = EmptyMetadata;
39
40    fn id(_encoding: &Self::Encoding) -> ExprId {
41        ExprId::new_ref("list_contains")
42    }
43
44    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
45        ExprEncodingRef::new_ref(ListContainsExprEncoding.as_ref())
46    }
47
48    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
49        Some(EmptyMetadata)
50    }
51
52    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
53        vec![&expr.list, &expr.value]
54    }
55
56    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
57        Ok(ListContainsExpr::new(
58            children[0].clone(),
59            children[1].clone(),
60        ))
61    }
62
63    fn build(
64        _encoding: &Self::Encoding,
65        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
66        children: Vec<ExprRef>,
67    ) -> VortexResult<Self::Expr> {
68        if children.len() != 2 {
69            vortex_bail!(
70                "ListContains expression must have exactly 2 children, got {}",
71                children.len()
72            );
73        }
74        Ok(ListContainsExpr::new(
75            children[0].clone(),
76            children[1].clone(),
77        ))
78    }
79
80    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
81        compute_list_contains(
82            expr.list.evaluate(scope)?.as_ref(),
83            expr.value.evaluate(scope)?.as_ref(),
84        )
85    }
86
87    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
88        Ok(DType::Bool(
89            expr.list.return_dtype(scope)?.nullability()
90                | expr.value.return_dtype(scope)?.nullability(),
91        ))
92    }
93}
94
95impl ListContainsExpr {
96    pub fn new(list: ExprRef, value: ExprRef) -> Self {
97        Self { list, value }
98    }
99
100    pub fn new_expr(list: ExprRef, value: ExprRef) -> ExprRef {
101        Self::new(list, value).into_expr()
102    }
103
104    pub fn value(&self) -> &ExprRef {
105        &self.value
106    }
107}
108
109/// Creates an expression that checks if a value is contained in a list.
110///
111/// Returns a boolean array indicating whether the value appears in each list.
112///
113/// ```rust
114/// # use vortex_expr::{list_contains, lit, root};
115/// let expr = list_contains(root(), lit(42));
116/// ```
117pub fn list_contains(list: ExprRef, value: ExprRef) -> ExprRef {
118    ListContainsExpr::new(list, value).into_expr()
119}
120
121impl DisplayAs for ListContainsExpr {
122    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
123        match df {
124            DisplayFormat::Compact => {
125                write!(f, "contains({}, {})", &self.list, &self.value)
126            }
127            DisplayFormat::Tree => {
128                write!(f, "ListContains")
129            }
130        }
131    }
132
133    fn child_names(&self) -> Option<Vec<String>> {
134        Some(vec!["list".to_string(), "value".to_string()])
135    }
136}
137
138impl AnalysisExpr for ListContainsExpr {
139    // falsification(contains([1,2,5], x)) =>
140    //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
141
142    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
143        let min = self.list.min(catalog)?;
144        let max = self.list.max(catalog)?;
145        // If the list is constant when we can compare each element to the value
146        if min == max {
147            let list_ = min
148                .as_opt::<LiteralVTable>()
149                .and_then(|l| l.value().as_list_opt())
150                .and_then(|l| l.elements())?;
151            if list_.is_empty() {
152                // contains([], x) is always false.
153                return Some(lit(true));
154            }
155            let value_max = self.value.max(catalog)?;
156            let value_min = self.value.min(catalog)?;
157
158            return list_
159                .iter()
160                .map(move |v| {
161                    or(
162                        lt(value_max.clone(), lit(v.clone())),
163                        gt(value_min.clone(), lit(v.clone())),
164                    )
165                })
166                .reduce(and);
167        }
168
169        None
170    }
171}
172
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{BoolArray, BooleanBuffer, ListArray};
    use vortex_array::stats::Stat;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Field, FieldPath, FieldPathSet, Nullability, StructFields};
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;

    use crate::list_contains::list_contains;
    use crate::pruning::checked_pruning_expr;
    use crate::{Arc, HashSet, Scope, and, col, get_item, gt, lit, lt, or, root};

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        let elements = buffer![1, 1, 2, 2, 2, 2, 2, 3, 3, 3].into_array();
        let offsets = buffer![0, 5, 10].into_array();
        ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array()
    }

    /// Two lists, `[1, 1, 2, 2, 2]` and `[]`, with the given validity.
    fn lists_with_empty_second(validity: Validity) -> ArrayRef {
        let elements = buffer![1, 1, 2, 2, 2].into_array();
        let offsets = buffer![0, 5, 5].into_array();
        ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array()
    }

    /// Evaluates `contains($, needle)` with `array` as the root scope.
    fn contains_over(array: ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle))
            .evaluate(&Scope::new(array))
            .unwrap()
    }

    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    #[test]
    pub fn test_one() {
        let result = contains_over(test_array(), 1);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_all() {
        let result = contains_over(test_array(), 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    #[test]
    pub fn test_none() {
        let result = contains_over(test_array(), 4);
        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_empty() {
        let result = contains_over(lists_with_empty_second(Validity::AllValid), 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        // An empty list never contains anything.
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_nullable() {
        let validity =
            Validity::Array(BoolArray::from(BooleanBuffer::from(vec![true, false])).into_array());
        let result = contains_over(lists_with_empty_second(validity), 2);
        assert_eq!(result.scalar_at(0), nullable_bool(true));
        // The second list is itself null, so its result is null too.
        assert!(!result.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let array_dtype = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![array_dtype]),
            Nullability::NonNullable,
        );
        let expr = list_contains(get_item("array", root()), lit(2));

        // The list field is nullable, so the result is nullable even though the scope is not.
        assert_eq!(
            expr.return_dtype(&scope).unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    pub fn list_falsification() {
        let constant_list = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let expr = list_contains(constant_list, col("a"));

        let available_stats = FieldPathSet::from_iter(["max", "min"].map(|stat| {
            FieldPath::from_iter([Field::Name("a".into()), Field::Name(stat.into())])
        }));
        let (pruned, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // Each list element `e` contributes `a_max < e OR a_min > e`.
        let outside = |e: i32| or(lt(col("a_max"), lit(e)), gt(col("a_min"), lit(e)));
        assert_eq!(&pruned, &and(and(outside(1), outside(2)), outside(3)));

        assert_eq!(
            st.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from([Stat::Min, Stat::Max])
            )])
        );
    }

    #[test]
    pub fn test_display() {
        let cases = [
            (
                list_contains(get_item("tags", root()), lit("urgent")),
                "contains($.tags, \"urgent\")",
            ),
            (list_contains(root(), lit(42)), "contains($, 42i32)"),
        ];
        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}