vortex_array/expr/exprs/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::BitOr;
6use std::ops::Deref;
7
8use arrow_buffer::bit_iterator::BitIndexIterator;
9use vortex_buffer::BitBuffer;
10use vortex_compute::logical::LogicalOr;
11use vortex_dtype::DType;
12use vortex_dtype::IntegerPType;
13use vortex_dtype::Nullability;
14use vortex_dtype::PTypeDowncastExt;
15use vortex_dtype::match_each_integer_ptype;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_error::vortex_err;
19use vortex_mask::Mask;
20use vortex_vector::BoolDatum;
21use vortex_vector::Datum;
22use vortex_vector::Vector;
23use vortex_vector::VectorOps;
24use vortex_vector::bool::BoolScalar;
25use vortex_vector::bool::BoolVector;
26use vortex_vector::listview::ListViewScalar;
27use vortex_vector::listview::ListViewVector;
28use vortex_vector::primitive::PVector;
29
30use crate::ArrayRef;
31use crate::compute::list_contains as compute_list_contains;
32use crate::expr::Arity;
33use crate::expr::Binary;
34use crate::expr::ChildName;
35use crate::expr::EmptyOptions;
36use crate::expr::ExecutionArgs;
37use crate::expr::ExprId;
38use crate::expr::Expression;
39use crate::expr::StatsCatalog;
40use crate::expr::VTable;
41use crate::expr::VTableExt;
42use crate::expr::exprs::binary::and;
43use crate::expr::exprs::binary::gt;
44use crate::expr::exprs::binary::lt;
45use crate::expr::exprs::binary::or;
46use crate::expr::exprs::literal::Literal;
47use crate::expr::exprs::literal::lit;
48use crate::expr::operators;
49
/// Expression vtable for `vortex.list.contains`: a binary expression whose first child is the
/// `list` and whose second child is the `needle` to search for.
pub struct ListContains;
51
52impl VTable for ListContains {
53    type Options = EmptyOptions;
54
55    fn id(&self) -> ExprId {
56        ExprId::from("vortex.list.contains")
57    }
58
59    fn serialize(&self, _instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60        Ok(Some(vec![]))
61    }
62
63    fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Self::Options> {
64        Ok(EmptyOptions)
65    }
66
67    fn arity(&self, _options: &Self::Options) -> Arity {
68        Arity::Exact(2)
69    }
70
71    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
72        match child_idx {
73            0 => ChildName::from("list"),
74            1 => ChildName::from("needle"),
75            _ => unreachable!(
76                "Invalid child index {} for ListContains expression",
77                child_idx
78            ),
79        }
80    }
81    fn fmt_sql(
82        &self,
83        _options: &Self::Options,
84        expr: &Expression,
85        f: &mut Formatter<'_>,
86    ) -> std::fmt::Result {
87        write!(f, "contains(")?;
88        expr.child(0).fmt_sql(f)?;
89        write!(f, ", ")?;
90        expr.child(1).fmt_sql(f)?;
91        write!(f, ")")
92    }
93
94    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
95        let list_dtype = &arg_dtypes[0];
96        let needle_dtype = &arg_dtypes[0];
97
98        let nullability = match list_dtype {
99            DType::List(_, list_nullability) => list_nullability,
100            _ => {
101                vortex_bail!(
102                    "First argument to ListContains must be a List, got {:?}",
103                    list_dtype
104                );
105            }
106        }
107        .bitor(needle_dtype.nullability());
108
109        Ok(DType::Bool(nullability))
110    }
111
112    fn evaluate(
113        &self,
114        _options: &Self::Options,
115        expr: &Expression,
116        scope: &ArrayRef,
117    ) -> VortexResult<ArrayRef> {
118        let list_array = expr.child(0).evaluate(scope)?;
119        let value_array = expr.child(1).evaluate(scope)?;
120        compute_list_contains(list_array.as_ref(), value_array.as_ref())
121    }
122
123    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
124        let [lhs, rhs]: [Datum; _] = args
125            .datums
126            .try_into()
127            .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;
128
129        match (lhs, rhs) {
130            (Datum::Scalar(list_scalar), Datum::Scalar(needle_scalar)) => {
131                let list = list_scalar.into_list();
132                let found = list_contains_scalar_scalar(&list, &needle_scalar)?;
133                Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
134            }
135            (Datum::Scalar(list_scalar), Datum::Vector(needle_vector)) => {
136                let matches =
137                    constant_list_scalar_contains(list_scalar.into_list(), needle_vector)?;
138                Ok(Datum::Vector(matches.into()))
139            }
140            (Datum::Vector(list_vector), Datum::Scalar(needle_scalar)) => {
141                let matches =
142                    list_contains_scalar(list_vector.into_list(), needle_scalar.into_list())?;
143                Ok(Datum::Vector(matches.into()))
144            }
145            (Datum::Vector(_), Datum::Vector(_)) => {
146                vortex_bail!(
147                    "ListContains currently only supports constant needle (RHS) or constant list (LHS)"
148                )
149            }
150        }
151    }
152
153    fn stat_falsification(
154        &self,
155        _options: &Self::Options,
156        expr: &Expression,
157        catalog: &dyn StatsCatalog,
158    ) -> Option<Expression> {
159        let list = expr.child(0);
160        let needle = expr.child(1);
161
162        // falsification(contains([1,2,5], x)) =>
163        //   falsification(x != 1) and falsification(x != 2) and falsification(x != 5)
164        let min = list.stat_min(catalog)?;
165        let max = list.stat_max(catalog)?;
166        // If the list is constant when we can compare each element to the value
167        if min == max {
168            let list_ = min
169                .as_opt::<Literal>()
170                .and_then(|l| l.as_list_opt())
171                .and_then(|l| l.elements())?;
172            if list_.is_empty() {
173                // contains([], x) is always false.
174                return Some(lit(true));
175            }
176            let value_max = needle.stat_max(catalog)?;
177            let value_min = needle.stat_min(catalog)?;
178
179            return list_
180                .iter()
181                .map(move |v| {
182                    or(
183                        lt(value_max.clone(), lit(v.clone())),
184                        gt(value_min.clone(), lit(v.clone())),
185                    )
186                })
187                .reduce(and);
188        }
189
190        None
191    }
192
193    // Nullability matters for contains([], x) where x is false.
194    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
195        true
196    }
197
198    fn is_fallible(&self, _options: &Self::Options) -> bool {
199        false
200    }
201}
202
203/// Creates an expression that checks if a value is contained in a list.
204///
205/// Returns a boolean array indicating whether the value appears in each list.
206///
207/// ```rust
208/// # use vortex_array::expr::{list_contains, lit, root};
209/// let expr = list_contains(root(), lit(42));
210/// ```
211pub fn list_contains(list: Expression, value: Expression) -> Expression {
212    ListContains.new_expr(EmptyOptions, [list, value])
213}
214
215/// Returns a [`BoolVector`] where each bit represents if a list contains the scalar.
216// FIXME(ngates): test implementation and move to vortex-compute
217fn list_contains_scalar(list: ListViewVector, value: ListViewScalar) -> VortexResult<BoolVector> {
218    // If the list array is constant, we perform a single comparison.
219    // if list.len() > 1 && list.is_constant() {
220    //     let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
221    //     return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
222    // }
223
224    let elems = list.elements();
225    if elems.is_empty() {
226        // Must return false when a list is empty (but valid), or null when the list itself is null.
227        // return crate::compute::list_contains::list_false_or_null(&list_array, nullability);
228        todo!()
229    }
230
231    let matches = Binary
232        .bind(operators::Operator::Eq)
233        .execute(ExecutionArgs {
234            datums: vec![
235                Datum::Vector(elems.deref().clone()),
236                Datum::Scalar(value.into()),
237            ],
238            // FIXME(ngates): dtypes
239            dtypes: vec![],
240            row_count: elems.len(),
241            return_dtype: DType::Bool(Nullability::Nullable),
242        })?
243        .unwrap_into_vector(elems.len())
244        .into_bool()
245        .into_bits();
246
247    // // Fast path: no elements match.
248    // if let Some(pred) = matches.as_constant() {
249    //     return match pred.as_bool().value() {
250    //         // All comparisons are invalid (result in `null`), and search is not null because
251    //         // we already checked for null above.
252    //         None => {
253    //             assert!(
254    //                 !rhs.scalar().is_null(),
255    //                 "Search value must not be null here"
256    //             );
257    //             // False, unless the list itself is null in which case we return null.
258    //             crate::compute::list_contains::list_false_or_null(&list_array, nullability)
259    //         }
260    //         // No elements match, and all comparisons are valid (result in `false`).
261    //         Some(false) => {
262    //             // False, but match the nullability to the input list array.
263    //             Ok(
264    //                 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
265    //                     .into_array(),
266    //             )
267    //         }
268    //         // All elements match, and all comparisons are valid (result in `true`).
269    //         Some(true) => {
270    //             // True, unless the list itself is empty or NULL.
271    //             crate::compute::list_contains::list_is_not_empty(&list_array, nullability)
272    //         }
273    //     };
274    // }
275
276    // Get the offsets and sizes as primitive arrays.
277    let offsets = list.offsets();
278    let sizes = list.sizes();
279
280    // Process based on the offset and size types.
281    let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
282        match_each_integer_ptype!(sizes.ptype(), |S| {
283            process_matches::<O, S>(
284                matches,
285                list.len(),
286                offsets.downcast::<O>(),
287                sizes.downcast::<S>(),
288            )
289        })
290    });
291
292    Ok(BoolVector::new(list_matches, list.validity().clone()))
293}
294
295// Then there is a constant list scalar (haystack) being compared to an array of needles.
296// FIXME(ngates): test implementation and move to vortex-compute
297fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> VortexResult<BoolVector> {
298    let elements = list.value().elements();
299
300    // For each element in the list, we perform a full comparison over the values and OR
301    // the results together.
302    let mut result: BoolVector = BoolVector::new(
303        BitBuffer::new_unset(values.len()),
304        Mask::new(values.len(), true),
305    );
306    for i in 0..elements.len() {
307        let element = Datum::Scalar(elements.scalar_at(i));
308        let compared: BoolDatum = Binary
309            .bind(operators::Operator::Eq)
310            .execute(ExecutionArgs {
311                datums: vec![Datum::Vector(values.clone()), element],
312                dtypes: vec![
313                    // FIXME(ngates): call compute function directly!
314                ],
315                row_count: values.len(),
316                return_dtype: DType::Bool(Nullability::Nullable),
317            })?
318            .into_bool();
319        let compared = Datum::from(compared)
320            .unwrap_into_vector(values.len())
321            .into_bool();
322
323        result = LogicalOr::or(&result, &compared);
324    }
325
326    Ok(result)
327}
328
329/// Used when the needle is a scalar checked for containment in a single list.
330fn list_contains_scalar_scalar(
331    list: &ListViewScalar,
332    needle: &vortex_vector::Scalar,
333) -> VortexResult<bool> {
334    let elements = list.value().elements();
335
336    // Note: If the comparison becomes a bottleneck, look into faster ways to check for list
337    // containment. `execute` allocates the returned vector on the heap. Further, the `eq`
338    // comparison does not short-circuit on the first match found.
339    let found = Binary
340        .bind(operators::Operator::Eq)
341        .execute(ExecutionArgs {
342            datums: vec![
343                Datum::Vector(elements.deref().clone()),
344                Datum::Scalar(needle.clone()),
345            ],
346            dtypes: vec![],
347            row_count: elements.len(),
348            return_dtype: DType::Bool(Nullability::Nullable),
349        })?
350        .unwrap_into_vector(elements.len())
351        .into_bool()
352        .into_bits();
353
354    let mut true_bits = BitIndexIterator::new(found.inner().as_ref(), 0, found.len());
355    Ok(true_bits.next().is_some())
356}
357
358/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
359/// [`BoolArray`] of matches on the child elements array.
360///
361/// TODO(ngates): replace this for aggregation function.
362fn process_matches<O, S>(
363    matches: BitBuffer,
364    list_array_len: usize,
365    offsets: &PVector<O>,
366    sizes: &PVector<S>,
367) -> BitBuffer
368where
369    O: IntegerPType,
370    S: IntegerPType,
371{
372    let offsets_slice = offsets.elements().as_slice();
373    let sizes_slice = sizes.elements().as_slice();
374
375    (0..list_array_len)
376        .map(|i| {
377            // TODO(ngates): does validity render this invalid?
378            let offset = offsets_slice[i].as_();
379            let size = sizes_slice[i].as_();
380
381            // BitIndexIterator yields indices of true bits only. If `.next()` returns
382            // `Some(_)`, at least one element in this list's range matches.
383            let mut set_bits =
384                BitIndexIterator::new(matches.inner().as_slice(), matches.offset() + offset, size);
385            set_bits.next().is_some()
386        })
387        .collect::<BitBuffer>()
388}
389
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::BitBuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Field;
    use vortex_dtype::FieldPath;
    use vortex_dtype::FieldPathSet;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;
    use vortex_utils::aliases::hash_map::HashMap;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::list_contains;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ListArray;
    use crate::arrays::PrimitiveArray;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::root::root;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::stats::Stat;
    use crate::validity::Validity;

    /// Builds a list array of `i32` from flat elements, list offsets and a validity.
    fn int_list_array(elements: Vec<i32>, offsets: Vec<i32>, validity: Validity) -> ArrayRef {
        ListArray::try_new(
            PrimitiveArray::from_iter(elements).into_array(),
            PrimitiveArray::from_iter(offsets).into_array(),
            validity,
        )
        .unwrap()
        .into_array()
    }

    /// Two valid lists: `[1, 1, 2, 2, 2]` and `[2, 2, 3, 3, 3]`.
    fn test_array() -> ArrayRef {
        int_list_array(
            vec![1, 1, 2, 2, 2, 2, 2, 3, 3, 3],
            vec![0, 5, 10],
            Validity::AllValid,
        )
    }

    /// Evaluates `contains($, needle)` against `lists`.
    fn contains_in(lists: &ArrayRef, needle: i32) -> ArrayRef {
        list_contains(root(), lit(needle)).evaluate(lists).unwrap()
    }

    /// The nullable boolean scalar the evaluation is expected to produce.
    fn nullable_bool(value: bool) -> Scalar {
        Scalar::bool(value, Nullability::Nullable)
    }

    #[test]
    pub fn test_one() {
        let result = contains_in(&test_array(), 1);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_all() {
        let result = contains_in(&test_array(), 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(true));
    }

    #[test]
    pub fn test_none() {
        let result = contains_in(&test_array(), 4);

        assert_eq!(result.scalar_at(0), nullable_bool(false));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_empty() {
        // The second list spans offsets 5..5 and is therefore empty.
        let lists = int_list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], Validity::AllValid);

        let result = contains_in(&lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert_eq!(result.scalar_at(1), nullable_bool(false));
    }

    #[test]
    pub fn test_nullable() {
        // The second list is marked invalid by the validity array.
        let validity =
            Validity::Array(BoolArray::from(BitBuffer::from(vec![true, false])).into_array());
        let lists = int_list_array(vec![1, 1, 2, 2, 2], vec![0, 5, 5], validity);

        let result = contains_in(&lists, 2);

        assert_eq!(result.scalar_at(0), nullable_bool(true));
        assert!(!result.is_valid(1));
    }

    #[test]
    pub fn test_return_type() {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let scope = DType::Struct(
            StructFields::new(["array"].into(), vec![list_dtype]),
            Nullability::NonNullable,
        );

        let expr = list_contains(get_item("array", root()), lit(2));

        // Expect nullable, although scope is non-nullable
        let actual = expr.return_dtype(&scope).unwrap();
        assert_eq!(actual, DType::Bool(Nullability::Nullable));
    }

    #[test]
    pub fn list_falsification() {
        let haystack = lit(Scalar::list(
            Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
            vec![1.into(), 2.into(), 3.into()],
            Nullability::NonNullable,
        ));
        let expr = list_contains(haystack, col("a"));

        let available_stats = FieldPathSet::from_iter([
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("max".into())]),
            FieldPath::from_iter([Field::Name("a".into()), Field::Name("min".into())]),
        ]);
        let (expr, st) = checked_pruning_expr(&expr, &available_stats).unwrap();

        // One "needle range excludes v" term per list element, and-ed together left to right.
        let expected = [1i32, 2i32, 3i32]
            .into_iter()
            .map(|v| or(lt(col("a_max"), lit(v)), gt(col("a_min"), lit(v))))
            .reduce(and)
            .unwrap();
        assert_eq!(&expr, &expected);

        let expected_stats = HashMap::from_iter([(
            FieldPath::from_name("a"),
            HashSet::from([Stat::Min, Stat::Max]),
        )]);
        assert_eq!(st.map(), &expected_stats);
    }

    #[test]
    pub fn test_display() {
        let tagged = list_contains(get_item("tags", root()), lit("urgent"));
        assert_eq!(tagged.to_string(), "contains($.tags, \"urgent\")");

        let rooted = list_contains(root(), lit(42));
        assert_eq!(rooted.to_string(), "contains($, 42i32)");
    }
}