vortex_array/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! List-related compute operations.
5
6use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19    BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20    fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27    let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
28    for kernel in inventory::iter::<ListContainsKernelRef> {
29        compute.register_kernel(kernel.0.clone());
30    }
31    compute
32});
33
34pub(crate) fn warm_up_vtable() -> usize {
35    LIST_CONTAINS_FN.kernels().len()
36}
37
38/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
39/// item contains the `value`, `false` otherwise.
40///
41/// ## Null scalar handling
42///
43/// If the `value` or `array` is `null` at any index the result at that index is `null`.
44///
45/// ## Format semantics
46/// ```txt
47/// list_contains(list, elem)
48///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
49/// ```
50///
51/// ## Example
52///
53/// ```rust
54/// use vortex_array::{Array, IntoArray, ToCanonical};
55/// use vortex_array::arrays::{ConstantArray, ListArray, VarBinArray};
56/// use vortex_array::compute::list_contains;
57/// use vortex_array::validity::Validity;
58/// use vortex_buffer::buffer;
59/// use vortex_dtype::DType;
60/// use vortex_scalar::Scalar;
61/// let elements = VarBinArray::from_vec(
62///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
63/// let offsets = buffer![0u32, 1, 3, 5].into_array();
64/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
65///
66/// let matches = list_contains(list_array.as_ref(), ConstantArray::new(Scalar::from("b"), list_array.len()).as_ref()).unwrap();
67/// let to_vec: Vec<bool> = matches.to_bool().boolean_buffer().iter().collect();
68/// assert_eq!(to_vec, vec![false, true, false]);
69/// ```
70pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
71    LIST_CONTAINS_FN
72        .invoke(&InvocationArgs {
73            inputs: &[array.into(), value.into()],
74            options: &(),
75        })?
76        .unwrap_array()
77}
78
79pub struct ListContains;
80
81impl ComputeFnVTable for ListContains {
82    fn invoke(
83        &self,
84        args: &InvocationArgs,
85        kernels: &[ArcRef<dyn Kernel>],
86    ) -> VortexResult<Output> {
87        let BinaryArgs {
88            lhs: array,
89            rhs: value,
90            ..
91        } = BinaryArgs::<()>::try_from(args)?;
92
93        let DType::List(elem_dtype, _) = array.dtype() else {
94            vortex_bail!("Array must be of List type");
95        };
96        if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
97            vortex_bail!(
98                "Element type {} of ListArray does not match search value {}",
99                elem_dtype,
100                value.dtype(),
101            );
102        };
103
104        if value.all_invalid() || array.all_invalid() {
105            return Ok(Output::Array(
106                ConstantArray::new(
107                    Scalar::null(DType::Bool(Nullability::Nullable)),
108                    array.len(),
109                )
110                .to_array(),
111            ));
112        }
113
114        for kernel in kernels {
115            if let Some(output) = kernel.invoke(args)? {
116                return Ok(output);
117            }
118        }
119        if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
120            return Ok(output);
121        }
122
123        let nullability = array.dtype().nullability() | value.dtype().nullability();
124
125        let result = if let Some(value_scalar) = value.as_constant() {
126            list_contains_scalar(array, &value_scalar, nullability)
127        } else if let Some(list_scalar) = array.as_constant() {
128            constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
129        } else {
130            todo!("unsupported list contains with list and element as arrays")
131        };
132
133        result.map(Output::Array)
134    }
135
136    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
137        let input = BinaryArgs::<()>::try_from(args)?;
138        Ok(DType::Bool(
139            input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
140        ))
141    }
142
143    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
144        Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
145    }
146
147    fn is_elementwise(&self) -> bool {
148        true
149    }
150}
151
152pub trait ListContainsKernel: VTable {
153    fn list_contains(
154        &self,
155        list: &dyn Array,
156        element: &Self::Array,
157    ) -> VortexResult<Option<ArrayRef>>;
158}
159
160pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
161inventory::collect!(ListContainsKernelRef);
162
163#[derive(Debug)]
164pub struct ListContainsKernelAdapter<V: VTable>(pub V);
165
166impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
167    pub const fn lift(&'static self) -> ListContainsKernelRef {
168        ListContainsKernelRef(ArcRef::new_ref(self))
169    }
170}
171
172impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
173    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
174        let BinaryArgs {
175            lhs: array,
176            rhs: value,
177            ..
178        } = BinaryArgs::<()>::try_from(args)?;
179        let Some(value) = value.as_opt::<V>() else {
180            return Ok(None);
181        };
182        self.0
183            .list_contains(array, value)
184            .map(|c| c.map(Output::Array))
185    }
186}
187
188// Then there is a constant list scalar (haystack) being compared to an array of needles.
189fn constant_list_scalar_contains(
190    list_scalar: &ListScalar,
191    values: &dyn Array,
192    nullability: Nullability,
193) -> VortexResult<ArrayRef> {
194    let elements = list_scalar.elements().vortex_expect("non null");
195
196    let len = values.len();
197    let mut result: Option<ArrayRef> = None;
198    let false_scalar = Scalar::bool(false, nullability);
199    for element in elements {
200        let res = fill_null(
201            &compare(
202                ConstantArray::new(element, len).as_ref(),
203                values,
204                Operator::Eq,
205            )?,
206            &false_scalar,
207        )?;
208        if let Some(acc) = result {
209            result = Some(or(&acc, &res)?)
210        } else {
211            result = Some(res);
212        }
213    }
214    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
215}
216
217fn list_contains_scalar(
218    array: &dyn Array,
219    value: &Scalar,
220    nullability: Nullability,
221) -> VortexResult<ArrayRef> {
222    // If the list array is constant, we perform a single comparison.
223    if array.len() > 1 && array.is_constant() {
224        let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
225        return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
226    }
227
228    // Canonicalize to a list array.
229    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
230    let list_array = array.to_list();
231
232    let elems = list_array.elements();
233    if elems.is_empty() {
234        // Must return false when a list is empty (but valid), or null when the list itself is null.
235        return list_false_or_null(&list_array, nullability);
236    }
237
238    let rhs = ConstantArray::new(value.clone(), elems.len());
239    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
240    let matches = matching_elements.to_bool();
241
242    // Fast path: no elements match.
243    if let Some(pred) = matches.as_constant() {
244        return match pred.as_bool().value() {
245            // All comparisons are invalid (result in `null`), and search is not null because
246            // we already checked for null above.
247            None => {
248                assert!(
249                    !rhs.scalar().is_null(),
250                    "Search value must not be null here"
251                );
252                // False, unless the list itself is null in which case we return null.
253                list_false_or_null(&list_array, nullability)
254            }
255            // No elements match, and all comparisons are valid (result in `false`).
256            Some(false) => {
257                // False, but match the nullability to the input list array.
258                Ok(
259                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
260                        .into_array(),
261                )
262            }
263            // All elements match, and all comparisons are valid (result in `true`).
264            Some(true) => {
265                // True, unless the list itself is empty or NULL.
266                list_is_not_empty(&list_array, nullability)
267            }
268        };
269    }
270
271    let ends = list_array.offsets().to_primitive();
272    match_each_integer_ptype!(ends.ptype(), |T| {
273        Ok(reduce_with_ends(
274            ends.as_slice::<T>(),
275            matches.boolean_buffer(),
276            list_array.validity().clone().union_nullability(nullability),
277        ))
278    })
279}
280
281/// Returns a `Bool` array with `false` for lists that are valid,
282/// or `NULL` if the list itself is null.
283fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
284    match list_array.validity() {
285        Validity::NonNullable => {
286            // All false.
287            Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
288        }
289        Validity::AllValid => {
290            // All false, but nullable.
291            Ok(
292                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
293                    .into_array(),
294            )
295        }
296        Validity::AllInvalid => {
297            // All nulls, must be nullable result.
298            Ok(ConstantArray::new(
299                Scalar::null(DType::Bool(Nullability::Nullable)),
300                list_array.len(),
301            )
302            .into_array())
303        }
304        Validity::Array(validity_array) => {
305            // Create a new bool array with false, and the provided nulls
306            let buffer = BooleanBuffer::new_unset(list_array.len());
307            Ok(
308                BoolArray::from_bool_buffer(buffer, Validity::Array(validity_array.clone()))
309                    .into_array(),
310            )
311        }
312    }
313}
314
315/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
316/// or `NULL` if the list itself is null.
317fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
318    // Short-circuit for all invalid.
319    if matches!(list_array.validity(), Validity::AllInvalid) {
320        return Ok(ConstantArray::new(
321            Scalar::null(DType::Bool(Nullability::Nullable)),
322            list_array.len(),
323        )
324        .into_array());
325    }
326
327    let offsets = list_array.offsets().to_primitive();
328    let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
329        element_is_not_empty(offsets.as_slice::<T>())
330    });
331
332    // Copy over the validity mask from the input.
333    Ok(BoolArray::from_bool_buffer(
334        buffer,
335        list_array.validity().clone().union_nullability(nullability),
336    )
337    .into_array())
338}
339
340/// Reduces each boolean values into a Mask that indicates which elements in the
341/// ListArray contain the matching value.
342fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
343    ends: &[T],
344    matches: &BooleanBuffer,
345    validity: Validity,
346) -> ArrayRef {
347    let mask: BooleanBuffer = ends
348        .windows(2)
349        .map(|window| {
350            let len = window[1].as_() - window[0].as_();
351            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
352            set_bits.next().is_some()
353        })
354        .collect();
355
356    BoolArray::from_bool_buffer(mask, validity).into_array()
357}
358
359/// Returns a new array of `u64` representing the length of each list element.
360///
361/// ## Example
362///
363/// ```rust
364/// use vortex_array::arrays::{ListArray, VarBinArray};
365/// use vortex_array::{Array, IntoArray};
366/// use vortex_array::compute::{list_elem_len};
367/// use vortex_array::validity::Validity;
368/// use vortex_buffer::buffer;
369/// use vortex_dtype::DType;
370///
371/// let elements = VarBinArray::from_vec(
372///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
373/// let offsets = buffer![0u32, 1, 3, 5].into_array();
374/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
375///
376/// let lens = list_elem_len(list_array.as_ref()).unwrap();
377/// assert_eq!(lens.scalar_at(0), 1u32.into());
378/// assert_eq!(lens.scalar_at(1), 2u32.into());
379/// assert_eq!(lens.scalar_at(2), 2u32.into());
380/// ```
381pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
382    if !matches!(array.dtype(), DType::List(..)) {
383        vortex_bail!("Array must be of list type");
384    }
385
386    // Short-circuit for constant list arrays.
387    if array.is_constant() && array.len() > 1 {
388        let elem_lens = list_elem_len(&array.slice(0..1))?;
389        return Ok(ConstantArray::new(elem_lens.scalar_at(0), array.len()).into_array());
390    }
391
392    let list_array = array.to_list();
393    let offsets = list_array.offsets().to_primitive();
394    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
395        element_lens(offsets.as_slice::<T>()).into_array()
396    });
397
398    Ok(lens_array)
399}
400
401fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
402    values
403        .windows(2)
404        .map(|window| window[1] - window[0])
405        .collect()
406}
407
408fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
409    BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
410}
411
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose string elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let flattened = values.iter().flatten().cloned().collect_vec();

        // Offsets start at zero, followed by the running total of list lengths.
        let mut boundaries = vec![0u64];
        let mut running = 0u64;
        for list in &values {
            running += list.len() as u64;
            boundaries.push(running);
        }
        let offsets = Buffer::from_iter(boundaries).into_array();

        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a bool array from plain values and an explicit validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::from_bool_buffer(values.iter().copied().collect(), validity)
    }

    #[rstest]
    // Case 1: valid scalar search over non-nullable list elements
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 2: valid scalar search over nullable list, with all nulls matched
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Case 3: valid scalar search over nullable list, with some nulls not matched (return no nulls)
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Case 4: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 5: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 6: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 7: list(utf8?) with empty + NULL elements and NULL search
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Case 8: list(utf8?) with empty + NULL elements and search scalar
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();

        // A missing search value becomes a null scalar; otherwise match the element nullability.
        let needle = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );
        let needles = ConstantArray::new(needle, list_array.len());

        let actual = list_contains(&list_array, needles.as_ref())
            .expect("list_contains failed")
            .to_bool();

        assert_eq!(
            actual.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let haystack = Scalar::list(
            element_dtype,
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(haystack, 2).into_array();

        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains.to_bool().boolean_buffer().iter().collect_vec(),
            vec![true, true],
        );
    }

    #[test]
    fn test_all_nulls() {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let list_array = ConstantArray::new(Scalar::null(list_dtype), 5).into_array();

        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        // Constant haystack [1, 3, 6] searched with the needles 0..7.
        let haystack = ConstantArray::new(list_scalar, 7);
        let needles = (0..7).collect::<PrimitiveArray>();
        let contains = list_contains(haystack.as_ref(), needles.as_ref()).unwrap();

        assert_eq!(contains.len(), 7);

        let expected = [false, true, false, true, false, false, true]
            .map(Some)
            .to_vec();
        assert_eq!(contains.to_bool().opt_bool_vec().unwrap(), expected);
    }
}