Skip to main content

vortex_array/expr/exprs/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::BitOr;
6
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_err;
11use vortex_session::VortexSession;
12
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::arrays::ConstantArray;
16use crate::compute::list_contains as compute_list_contains;
17use crate::expr::Arity;
18use crate::expr::ChildName;
19use crate::expr::EmptyOptions;
20use crate::expr::ExecutionArgs;
21use crate::expr::ExprId;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::VTable;
25use crate::expr::VTableExt;
26use crate::expr::and_collect;
27use crate::expr::exprs::binary::gt;
28use crate::expr::exprs::binary::lt;
29use crate::expr::exprs::binary::or;
30use crate::expr::exprs::literal::Literal;
31use crate::expr::exprs::literal::lit;
32use crate::scalar::Scalar;
33
/// Expression vtable for `vortex.list.contains`.
///
/// The expression has two children, `list` and `needle`, and evaluates to a boolean that
/// tells whether the needle occurs in the list.
pub struct ListContains;
35
36impl VTable for ListContains {
37    type Options = EmptyOptions;
38
39    fn id(&self) -> ExprId {
40        ExprId::from("vortex.list.contains")
41    }
42
43    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
44        Ok(Some(vec![]))
45    }
46
47    fn deserialize(
48        &self,
49        _metadata: &[u8],
50        _session: &VortexSession,
51    ) -> VortexResult<Self::Options> {
52        Ok(EmptyOptions)
53    }
54
55    fn arity(&self, _options: &Self::Options) -> Arity {
56        Arity::Exact(2)
57    }
58
59    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
60        match child_idx {
61            0 => ChildName::from("list"),
62            1 => ChildName::from("needle"),
63            _ => unreachable!(
64                "Invalid child index {} for ListContains expression",
65                child_idx
66            ),
67        }
68    }
69    fn fmt_sql(
70        &self,
71        _options: &Self::Options,
72        expr: &Expression,
73        f: &mut Formatter<'_>,
74    ) -> std::fmt::Result {
75        write!(f, "contains(")?;
76        expr.child(0).fmt_sql(f)?;
77        write!(f, ", ")?;
78        expr.child(1).fmt_sql(f)?;
79        write!(f, ")")
80    }
81
82    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
83        let list_dtype = &arg_dtypes[0];
84        let needle_dtype = &arg_dtypes[0];
85
86        let nullability = match list_dtype {
87            DType::List(_, list_nullability) => list_nullability,
88            _ => {
89                vortex_bail!(
90                    "First argument to ListContains must be a List, got {:?}",
91                    list_dtype
92                );
93            }
94        }
95        .bitor(needle_dtype.nullability());
96
97        Ok(DType::Bool(nullability))
98    }
99
100    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
101        let [list_array, value_array]: [ArrayRef; _] = args
102            .inputs
103            .try_into()
104            .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
105
106        if let Some(list_scalar) = list_array.as_constant()
107            && let Some(value_scalar) = value_array.as_constant()
108        {
109            let result = compute_contains_scalar(&list_scalar, &value_scalar)?;
110            return Ok(ConstantArray::new(result, args.row_count).into_array());
111        }
112
113        compute_list_contains(list_array.as_ref(), value_array.as_ref())?.execute(args.ctx)
114    }
115
116    fn stat_falsification(
117        &self,
118        _options: &Self::Options,
119        expr: &Expression,
120        catalog: &dyn StatsCatalog,
121    ) -> Option<Expression> {
122        let list = expr.child(0);
123        let needle = expr.child(1);
124
125        // falsification(contains([1,2,5], x)) =>
126        //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
127        let min = list.stat_min(catalog)?;
128        let max = list.stat_max(catalog)?;
129        // If the list is constant when we can compare each element to the value
130        if min == max {
131            let list_ = min
132                .as_opt::<Literal>()
133                .and_then(|l| l.as_list_opt())
134                .and_then(|l| l.elements())?;
135            if list_.is_empty() {
136                // contains([], x) is always false.
137                return Some(lit(true));
138            }
139            let value_max = needle.stat_max(catalog)?;
140            let value_min = needle.stat_min(catalog)?;
141
142            return and_collect(list_.iter().map(move |v| {
143                or(
144                    lt(value_max.clone(), lit(v.clone())),
145                    gt(value_min.clone(), lit(v.clone())),
146                )
147            }));
148        }
149
150        None
151    }
152
153    // Nullability matters for contains([], x) where x is false.
154    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
155        true
156    }
157
158    fn is_fallible(&self, _options: &Self::Options) -> bool {
159        false
160    }
161}
162
163fn compute_contains_scalar(list: &Scalar, needle: &Scalar) -> VortexResult<Scalar> {
164    let nullability = list.dtype().nullability() | needle.dtype().nullability();
165
166    // Handle null list or null needle
167    if list.is_null() || needle.is_null() {
168        return Ok(Scalar::null(DType::Bool(nullability)));
169    }
170
171    let list_scalar = list.as_list();
172    let elements = list_scalar
173        .elements()
174        .ok_or_else(|| vortex_err!("Expected non-null list"))?;
175
176    let contains = elements.iter().any(|elem| elem == needle);
177    Ok(Scalar::bool(contains, nullability))
178}
179
180/// Creates an expression that checks if a value is contained in a list.
181///
182/// Returns a boolean array indicating whether the value appears in each list.
183///
184/// ```rust
185/// # use vortex_array::expr::{list_contains, lit, root};
186/// let expr = list_contains(root(), lit(42));
187/// ```
188pub fn list_contains(list: Expression, value: Expression) -> Expression {
189    ListContains.new_expr(EmptyOptions, [list, value])
190}
191
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Shared fixture: ten elements split by the offsets `[0, 5, 10]` into two valid lists.
    fn test_array() -> ArrayRef {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 10]).into_array();

        ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array()
    }

    /// The needle is found in the first list only.
    #[test]
    pub fn test_one() {
        let lists = test_array();

        let expr = list_contains(root(), lit(1));
        let result = lists.apply(&expr).unwrap();

        for (row, expected) in [true, false].into_iter().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(expected, Nullability::Nullable)
            );
        }
    }

    /// The needle is found in both lists.
    #[test]
    pub fn test_all() {
        let lists = test_array();

        let expr = list_contains(root(), lit(2));
        let result = lists.apply(&expr).unwrap();

        for (row, expected) in [true, true].into_iter().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(expected, Nullability::Nullable)
            );
        }
    }

    /// The needle is found in neither list.
    #[test]
    pub fn test_none() {
        let lists = test_array();

        let expr = list_contains(root(), lit(4));
        let result = lists.apply(&expr).unwrap();

        for (row, expected) in [false, false].into_iter().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(expected, Nullability::Nullable)
            );
        }
    }

    /// With offsets `[0, 5, 5]` the second list is empty, so it never contains the needle.
    #[test]
    pub fn test_empty() {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 5]).into_array();
        let lists = ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array();

        let expr = list_contains(root(), lit(2));
        let result = lists.apply(&expr).unwrap();

        for (row, expected) in [true, false].into_iter().enumerate() {
            assert_eq!(
                result.scalar_at(row).unwrap(),
                Scalar::bool(expected, Nullability::Nullable)
            );
        }
    }

    /// A list marked invalid yields an invalid (null) result row.
    #[test]
    pub fn test_nullable() {
        let elements = PrimitiveArray::from_iter(vec![1, 1, 2, 2, 2]).into_array();
        let offsets = PrimitiveArray::from_iter(vec![0, 5, 5]).into_array();
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = ListArray::try_new(elements, offsets, validity)
            .unwrap()
            .into_array();

        let expr = list_contains(root(), lit(2));
        let result = lists.apply(&expr).unwrap();

        assert_eq!(
            result.scalar_at(0).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
        assert!(!result.is_valid(1).unwrap());
    }

    /// A nullable list column produces a nullable boolean result dtype.
    #[test]
    pub fn test_return_type() {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        // Expect nullable, although scope is non-nullable
        let expected = DType::Bool(Nullability::Nullable);
        assert_eq!(expr.return_dtype(&scope).unwrap(), expected);
    }

    /// A literal list against a column needle prunes on the column's min/max statistics.
    #[test]
    pub fn list_falsification() {
        let haystack = Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        );
        let expr = list_contains(lit(haystack), col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (pruning_expr, required_stats) =
            checked_pruning_expr(&expr, &available_stats).unwrap();

        // One clause per list element: the element lies outside the column's [min, max].
        let outside_range = |v: i32| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v)));
        let expected = and(and(outside_range(1), outside_range(2)), outside_range(3));
        assert_eq!(&pruning_expr, &expected);

        let expected_stats = HashMap::from_iter([(
            FieldPath::from_name("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);
        assert_eq!(required_stats.map(), &expected_stats);
    }

    /// The expression renders as `contains(<list>, <needle>)`.
    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let rooted = list_contains(root(), lit(42));
        assert_eq!(rooted.to_string(), "contains($, 42i32)");
    }

    /// Both list and needle are constants, so the scalar optimization should be used.
    #[test]
    pub fn test_constant_scalars() {
        let arr = test_array();

        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        );

        // First case: needle present; second case: needle absent.
        for (needle, expected) in [(2i32, true), (42i32, false)] {
            let expr = list_contains(lit(list_scalar.clone()), lit(needle));
            let result = arr.apply(&expr).unwrap();
            assert_eq!(
                result.scalar_at(0).unwrap(),
                Scalar::bool(expected, Nullability::NonNullable)
            );
        }
    }
}
429}