vortex_array/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! List-related compute operations.
5
6use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19    BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20    fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
27/// item contains the `value`, `false` otherwise.
28///
29/// ## Null scalar handling
30///
31/// If the `value` or `array` is `null` at any index the result at that index is `null`.
32///
33/// ## Format semantics
34/// ```txt
35/// list_contains(list, elem)
36///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
37/// ```
38///
39/// ## Example
40///
41/// ```rust
42/// use vortex_array::{Array, IntoArray, ToCanonical};
43/// use vortex_array::arrays::{ConstantArray, ListArray, VarBinArray};
44/// use vortex_array::compute::list_contains;
45/// use vortex_array::validity::Validity;
46/// use vortex_buffer::buffer;
47/// use vortex_dtype::DType;
48/// use vortex_scalar::Scalar;
49/// let elements = VarBinArray::from_vec(
50///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
51/// let offsets = buffer![0u32, 1, 3, 5].into_array();
52/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
53///
54/// let matches = list_contains(list_array.as_ref(), ConstantArray::new(Scalar::from("b"), list_array.len()).as_ref()).unwrap();
55/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
56/// assert_eq!(to_vec, vec![false, true, false]);
57/// ```
58pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
59    LIST_CONTAINS_FN
60        .invoke(&InvocationArgs {
61            inputs: &[array.into(), value.into()],
62            options: &(),
63        })?
64        .unwrap_array()
65}
66
/// Marker type carrying the [`ComputeFnVTable`] implementation of the `list_contains`
/// compute function.
pub struct ListContains;
68
69impl ComputeFnVTable for ListContains {
70    fn invoke(
71        &self,
72        args: &InvocationArgs,
73        kernels: &[ArcRef<dyn Kernel>],
74    ) -> VortexResult<Output> {
75        let BinaryArgs {
76            lhs: array,
77            rhs: value,
78            ..
79        } = BinaryArgs::<()>::try_from(args)?;
80
81        let DType::List(elem_dtype, _) = array.dtype() else {
82            vortex_bail!("Array must be of List type");
83        };
84        if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
85            vortex_bail!(
86                "Element type {} of ListArray does not match search value {}",
87                elem_dtype,
88                value.dtype(),
89            );
90        };
91
92        if value.all_invalid()? || array.all_invalid()? {
93            return Ok(Output::Array(
94                ConstantArray::new(
95                    Scalar::null(DType::Bool(Nullability::Nullable)),
96                    array.len(),
97                )
98                .to_array(),
99            ));
100        }
101
102        for kernel in kernels {
103            if let Some(output) = kernel.invoke(args)? {
104                return Ok(output);
105            }
106        }
107        if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
108            return Ok(output);
109        }
110
111        let nullability = array.dtype().nullability() | value.dtype().nullability();
112
113        let result = if let Some(value_scalar) = value.as_constant() {
114            list_contains_scalar(array, &value_scalar, nullability)
115        } else if let Some(list_scalar) = array.as_constant() {
116            constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
117        } else {
118            todo!("unsupported list contains with list and element as arrays")
119        };
120
121        result.map(Output::Array)
122    }
123
124    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
125        let input = BinaryArgs::<()>::try_from(args)?;
126        Ok(DType::Bool(
127            input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
128        ))
129    }
130
131    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
132        Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
133    }
134
135    fn is_elementwise(&self) -> bool {
136        true
137    }
138}
139
/// Kernel interface for encodings that can evaluate `list_contains` directly on their own
/// array type when it appears as the element (right-hand side) operand.
pub trait ListContainsKernel: VTable {
    /// Evaluates `list_contains(list, element)`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the given inputs; the
    /// adapter forwards that `None` to the dispatcher unchanged.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
147
/// Type-erased handle to a `list_contains` kernel.
///
/// Instances are collected through `inventory` and registered on `LIST_CONTAINS_FN` when that
/// static is first initialized.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
150
/// Adapter that exposes a [`ListContainsKernel`] implementation through the generic
/// [`Kernel`] interface.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);
153
impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ListContainsKernelRef`] so it can be collected via
    /// `inventory`.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
159
160impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
161    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
162        let BinaryArgs {
163            lhs: array,
164            rhs: value,
165            ..
166        } = BinaryArgs::<()>::try_from(args)?;
167        let Some(value) = value.as_opt::<V>() else {
168            return Ok(None);
169        };
170        self.0
171            .list_contains(array, value)
172            .map(|c| c.map(Output::Array))
173    }
174}
175
176pub static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
177    let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
178    for kernel in inventory::iter::<ListContainsKernelRef> {
179        compute.register_kernel(kernel.0.clone());
180    }
181    compute
182});
183
184// Then there is a constant list scalar (haystack) being compared to an array of needles.
185fn constant_list_scalar_contains(
186    list_scalar: &ListScalar,
187    values: &dyn Array,
188    nullability: Nullability,
189) -> VortexResult<ArrayRef> {
190    let elements = list_scalar.elements().vortex_expect("non null");
191
192    let len = values.len();
193    let mut result: Option<ArrayRef> = None;
194    let false_scalar = Scalar::bool(false, nullability);
195    for element in elements {
196        let res = fill_null(
197            &compare(
198                ConstantArray::new(element, len).as_ref(),
199                values,
200                Operator::Eq,
201            )?,
202            &false_scalar,
203        )?;
204        if let Some(acc) = result {
205            result = Some(or(&acc, &res)?)
206        } else {
207            result = Some(res);
208        }
209    }
210    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
211}
212
213fn list_contains_scalar(
214    array: &dyn Array,
215    value: &Scalar,
216    nullability: Nullability,
217) -> VortexResult<ArrayRef> {
218    // If the list array is constant, we perform a single comparison.
219    if array.len() > 1 && array.is_constant() {
220        let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
221        return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
222    }
223
224    // Canonicalize to a list array.
225    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
226    let list_array = array.to_list()?;
227
228    let elems = list_array.elements();
229    if elems.is_empty() {
230        // Must return false when a list is empty (but valid), or null when the list itself is null.
231        return list_false_or_null(&list_array, nullability);
232    }
233
234    let rhs = ConstantArray::new(value.clone(), elems.len());
235    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
236    let matches = matching_elements.to_bool()?;
237
238    // Fast path: no elements match.
239    if let Some(pred) = matches.as_constant() {
240        return match pred.as_bool().value() {
241            // All comparisons are invalid (result in `null`), and search is not null because
242            // we already checked for null above.
243            None => {
244                assert!(
245                    !rhs.scalar().is_null(),
246                    "Search value must not be null here"
247                );
248                // False, unless the list itself is null in which case we return null.
249                list_false_or_null(&list_array, nullability)
250            }
251            // No elements match, and all comparisons are valid (result in `false`).
252            Some(false) => {
253                // False, but match the nullability to the input list array.
254                Ok(
255                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
256                        .into_array(),
257                )
258            }
259            // All elements match, and all comparisons are valid (result in `true`).
260            Some(true) => {
261                // True, unless the list itself is empty or NULL.
262                list_is_not_empty(&list_array, nullability)
263            }
264        };
265    }
266
267    let ends = list_array.offsets().to_primitive()?;
268    match_each_integer_ptype!(ends.ptype(), |T| {
269        Ok(reduce_with_ends(
270            ends.as_slice::<T>(),
271            matches.boolean_buffer(),
272            list_array.validity().clone().union_nullability(nullability),
273        ))
274    })
275}
276
277/// Returns a `Bool` array with `false` for lists that are valid,
278/// or `NULL` if the list itself is null.
279fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
280    match list_array.validity() {
281        Validity::NonNullable => {
282            // All false.
283            Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
284        }
285        Validity::AllValid => {
286            // All false, but nullable.
287            Ok(
288                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
289                    .into_array(),
290            )
291        }
292        Validity::AllInvalid => {
293            // All nulls, must be nullable result.
294            Ok(ConstantArray::new(
295                Scalar::null(DType::Bool(Nullability::Nullable)),
296                list_array.len(),
297            )
298            .into_array())
299        }
300        Validity::Array(validity_array) => {
301            // Create a new bool array with false, and the provided nulls
302            let buffer = BooleanBuffer::new_unset(list_array.len());
303            Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
304        }
305    }
306}
307
308/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
309/// or `NULL` if the list itself is null.
310fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
311    // Short-circuit for all invalid.
312    if matches!(list_array.validity(), Validity::AllInvalid) {
313        return Ok(ConstantArray::new(
314            Scalar::null(DType::Bool(Nullability::Nullable)),
315            list_array.len(),
316        )
317        .into_array());
318    }
319
320    let offsets = list_array.offsets().to_primitive()?;
321    let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
322        element_is_not_empty(offsets.as_slice::<T>())
323    });
324
325    // Copy over the validity mask from the input.
326    Ok(BoolArray::new(
327        buffer,
328        list_array.validity().clone().union_nullability(nullability),
329    )
330    .into_array())
331}
332
333/// Reduces each boolean values into a Mask that indicates which elements in the
334/// ListArray contain the matching value.
335fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
336    ends: &[T],
337    matches: &BooleanBuffer,
338    validity: Validity,
339) -> ArrayRef {
340    let mask: BooleanBuffer = ends
341        .windows(2)
342        .map(|window| {
343            let len = window[1].as_() - window[0].as_();
344            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
345            set_bits.next().is_some()
346        })
347        .collect();
348
349    BoolArray::new(mask, validity).into_array()
350}
351
352/// Returns a new array of `u64` representing the length of each list element.
353///
354/// ## Example
355///
356/// ```rust
357/// use vortex_array::arrays::{ListArray, VarBinArray};
358/// use vortex_array::{Array, IntoArray};
359/// use vortex_array::compute::{list_elem_len};
360/// use vortex_array::validity::Validity;
361/// use vortex_buffer::buffer;
362/// use vortex_dtype::DType;
363///
364/// let elements = VarBinArray::from_vec(
365///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
366/// let offsets = buffer![0u32, 1, 3, 5].into_array();
367/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
368///
369/// let lens = list_elem_len(list_array.as_ref()).unwrap();
370/// assert_eq!(lens.scalar_at(0).unwrap(), 1u32.into());
371/// assert_eq!(lens.scalar_at(1).unwrap(), 2u32.into());
372/// assert_eq!(lens.scalar_at(2).unwrap(), 2u32.into());
373/// ```
374pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
375    if !matches!(array.dtype(), DType::List(..)) {
376        vortex_bail!("Array must be of list type");
377    }
378
379    // Short-circuit for constant list arrays.
380    if array.is_constant() && array.len() > 1 {
381        let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
382        return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
383    }
384
385    let list_array = array.to_list()?;
386    let offsets = list_array.offsets().to_primitive()?;
387    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
388        element_lens(offsets.as_slice::<T>()).into_array()
389    });
390
391    Ok(lens_array)
392}
393
394fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
395    values
396        .windows(2)
397        .map(|window| window[1] - window[0])
398        .collect()
399}
400
401fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
402    BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
403}
404
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array whose elements are non-nullable UTF-8 strings.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
    }

    /// Builds a non-nullable list array whose elements are nullable UTF-8 strings.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();
        // Running sum of the list lengths, with a leading zero, gives the list offsets.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds the expected `BoolArray` from plain values and a validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: valid scalar search over non-nullable list(utf8), including an empty list
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 2: valid scalar search over nullable list, with all nulls matched
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Case 3: valid scalar search over nullable list, with some nulls not matched (return no nulls)
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Case 4: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 5: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 6: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 7: list(utf8?) with empty + NULL elements and NULL search
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Case 8: list(utf8?) with empty + NULL elements and search scalar
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");
        // Both the optional values and the validity representation must match.
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    // A constant list searched for a constant value must produce a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }

    // An all-null list array must produce an all-null constant result.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    // A constant list (haystack) searched for each entry of a non-constant array of needles.
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }
}