vortex_array/compute/
list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! List-related compute operations.
5
6use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19    BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20    fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27    let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
28    for kernel in inventory::iter::<ListContainsKernelRef> {
29        compute.register_kernel(kernel.0.clone());
30    }
31    compute
32});
33
34/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
35/// item contains the `value`, `false` otherwise.
36///
37/// ## Null scalar handling
38///
39/// If the `value` or `array` is `null` at any index the result at that index is `null`.
40///
41/// ## Format semantics
42/// ```txt
43/// list_contains(list, elem)
44///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
45/// ```
46///
47/// ## Example
48///
49/// ```rust
50/// use vortex_array::{Array, IntoArray, ToCanonical};
51/// use vortex_array::arrays::{ConstantArray, ListArray, VarBinArray};
52/// use vortex_array::compute::list_contains;
53/// use vortex_array::validity::Validity;
54/// use vortex_buffer::buffer;
55/// use vortex_dtype::DType;
56/// use vortex_scalar::Scalar;
57/// let elements = VarBinArray::from_vec(
58///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
59/// let offsets = buffer![0u32, 1, 3, 5].into_array();
60/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
61///
62/// let matches = list_contains(list_array.as_ref(), ConstantArray::new(Scalar::from("b"), list_array.len()).as_ref()).unwrap();
63/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
64/// assert_eq!(to_vec, vec![false, true, false]);
65/// ```
66pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
67    LIST_CONTAINS_FN
68        .invoke(&InvocationArgs {
69            inputs: &[array.into(), value.into()],
70            options: &(),
71        })?
72        .unwrap_array()
73}
74
75pub struct ListContains;
76
77impl ComputeFnVTable for ListContains {
78    fn invoke(
79        &self,
80        args: &InvocationArgs,
81        kernels: &[ArcRef<dyn Kernel>],
82    ) -> VortexResult<Output> {
83        let BinaryArgs {
84            lhs: array,
85            rhs: value,
86            ..
87        } = BinaryArgs::<()>::try_from(args)?;
88
89        let DType::List(elem_dtype, _) = array.dtype() else {
90            vortex_bail!("Array must be of List type");
91        };
92        if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
93            vortex_bail!(
94                "Element type {} of ListArray does not match search value {}",
95                elem_dtype,
96                value.dtype(),
97            );
98        };
99
100        if value.all_invalid() || array.all_invalid() {
101            return Ok(Output::Array(
102                ConstantArray::new(
103                    Scalar::null(DType::Bool(Nullability::Nullable)),
104                    array.len(),
105                )
106                .to_array(),
107            ));
108        }
109
110        for kernel in kernels {
111            if let Some(output) = kernel.invoke(args)? {
112                return Ok(output);
113            }
114        }
115        if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
116            return Ok(output);
117        }
118
119        let nullability = array.dtype().nullability() | value.dtype().nullability();
120
121        let result = if let Some(value_scalar) = value.as_constant() {
122            list_contains_scalar(array, &value_scalar, nullability)
123        } else if let Some(list_scalar) = array.as_constant() {
124            constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
125        } else {
126            todo!("unsupported list contains with list and element as arrays")
127        };
128
129        result.map(Output::Array)
130    }
131
132    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
133        let input = BinaryArgs::<()>::try_from(args)?;
134        Ok(DType::Bool(
135            input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
136        ))
137    }
138
139    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
140        Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
141    }
142
143    fn is_elementwise(&self) -> bool {
144        true
145    }
146}
147
/// Encoding-specific kernel for the `list_contains` compute function.
///
/// The kernel is selected on the type of the `element` (right-hand) input: the adapter only
/// calls it when that input downcasts to the implementing vtable's array type.
pub trait ListContainsKernel: VTable {
    /// Computes `list_contains(list, element)` for this encoding.
    ///
    /// Returning `Ok(None)` signals that the kernel did not handle the inputs, in which case
    /// the compute function moves on to its other implementations.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
155
/// Type-erased handle to a `list_contains` kernel.
///
/// Instances are collected through `inventory` and registered with the compute function when
/// it is first initialised.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
158
159#[derive(Debug)]
160pub struct ListContainsKernelAdapter<V: VTable>(pub V);
161
162impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
163    pub const fn lift(&'static self) -> ListContainsKernelRef {
164        ListContainsKernelRef(ArcRef::new_ref(self))
165    }
166}
167
168impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
169    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
170        let BinaryArgs {
171            lhs: array,
172            rhs: value,
173            ..
174        } = BinaryArgs::<()>::try_from(args)?;
175        let Some(value) = value.as_opt::<V>() else {
176            return Ok(None);
177        };
178        self.0
179            .list_contains(array, value)
180            .map(|c| c.map(Output::Array))
181    }
182}
183
184// Then there is a constant list scalar (haystack) being compared to an array of needles.
185fn constant_list_scalar_contains(
186    list_scalar: &ListScalar,
187    values: &dyn Array,
188    nullability: Nullability,
189) -> VortexResult<ArrayRef> {
190    let elements = list_scalar.elements().vortex_expect("non null");
191
192    let len = values.len();
193    let mut result: Option<ArrayRef> = None;
194    let false_scalar = Scalar::bool(false, nullability);
195    for element in elements {
196        let res = fill_null(
197            &compare(
198                ConstantArray::new(element, len).as_ref(),
199                values,
200                Operator::Eq,
201            )?,
202            &false_scalar,
203        )?;
204        if let Some(acc) = result {
205            result = Some(or(&acc, &res)?)
206        } else {
207            result = Some(res);
208        }
209    }
210    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
211}
212
213fn list_contains_scalar(
214    array: &dyn Array,
215    value: &Scalar,
216    nullability: Nullability,
217) -> VortexResult<ArrayRef> {
218    // If the list array is constant, we perform a single comparison.
219    if array.len() > 1 && array.is_constant() {
220        let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
221        return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
222    }
223
224    // Canonicalize to a list array.
225    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
226    let list_array = array.to_list()?;
227
228    let elems = list_array.elements();
229    if elems.is_empty() {
230        // Must return false when a list is empty (but valid), or null when the list itself is null.
231        return list_false_or_null(&list_array, nullability);
232    }
233
234    let rhs = ConstantArray::new(value.clone(), elems.len());
235    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
236    let matches = matching_elements.to_bool()?;
237
238    // Fast path: no elements match.
239    if let Some(pred) = matches.as_constant() {
240        return match pred.as_bool().value() {
241            // All comparisons are invalid (result in `null`), and search is not null because
242            // we already checked for null above.
243            None => {
244                assert!(
245                    !rhs.scalar().is_null(),
246                    "Search value must not be null here"
247                );
248                // False, unless the list itself is null in which case we return null.
249                list_false_or_null(&list_array, nullability)
250            }
251            // No elements match, and all comparisons are valid (result in `false`).
252            Some(false) => {
253                // False, but match the nullability to the input list array.
254                Ok(
255                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
256                        .into_array(),
257                )
258            }
259            // All elements match, and all comparisons are valid (result in `true`).
260            Some(true) => {
261                // True, unless the list itself is empty or NULL.
262                list_is_not_empty(&list_array, nullability)
263            }
264        };
265    }
266
267    let ends = list_array.offsets().to_primitive()?;
268    match_each_integer_ptype!(ends.ptype(), |T| {
269        Ok(reduce_with_ends(
270            ends.as_slice::<T>(),
271            matches.boolean_buffer(),
272            list_array.validity().clone().union_nullability(nullability),
273        ))
274    })
275}
276
277/// Returns a `Bool` array with `false` for lists that are valid,
278/// or `NULL` if the list itself is null.
279fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
280    match list_array.validity() {
281        Validity::NonNullable => {
282            // All false.
283            Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
284        }
285        Validity::AllValid => {
286            // All false, but nullable.
287            Ok(
288                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
289                    .into_array(),
290            )
291        }
292        Validity::AllInvalid => {
293            // All nulls, must be nullable result.
294            Ok(ConstantArray::new(
295                Scalar::null(DType::Bool(Nullability::Nullable)),
296                list_array.len(),
297            )
298            .into_array())
299        }
300        Validity::Array(validity_array) => {
301            // Create a new bool array with false, and the provided nulls
302            let buffer = BooleanBuffer::new_unset(list_array.len());
303            Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
304        }
305    }
306}
307
308/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
309/// or `NULL` if the list itself is null.
310fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
311    // Short-circuit for all invalid.
312    if matches!(list_array.validity(), Validity::AllInvalid) {
313        return Ok(ConstantArray::new(
314            Scalar::null(DType::Bool(Nullability::Nullable)),
315            list_array.len(),
316        )
317        .into_array());
318    }
319
320    let offsets = list_array.offsets().to_primitive()?;
321    let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
322        element_is_not_empty(offsets.as_slice::<T>())
323    });
324
325    // Copy over the validity mask from the input.
326    Ok(BoolArray::new(
327        buffer,
328        list_array.validity().clone().union_nullability(nullability),
329    )
330    .into_array())
331}
332
333/// Reduces each boolean values into a Mask that indicates which elements in the
334/// ListArray contain the matching value.
335fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
336    ends: &[T],
337    matches: &BooleanBuffer,
338    validity: Validity,
339) -> ArrayRef {
340    let mask: BooleanBuffer = ends
341        .windows(2)
342        .map(|window| {
343            let len = window[1].as_() - window[0].as_();
344            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
345            set_bits.next().is_some()
346        })
347        .collect();
348
349    BoolArray::new(mask, validity).into_array()
350}
351
352/// Returns a new array of `u64` representing the length of each list element.
353///
354/// ## Example
355///
356/// ```rust
357/// use vortex_array::arrays::{ListArray, VarBinArray};
358/// use vortex_array::{Array, IntoArray};
359/// use vortex_array::compute::{list_elem_len};
360/// use vortex_array::validity::Validity;
361/// use vortex_buffer::buffer;
362/// use vortex_dtype::DType;
363///
364/// let elements = VarBinArray::from_vec(
365///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
366/// let offsets = buffer![0u32, 1, 3, 5].into_array();
367/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
368///
369/// let lens = list_elem_len(list_array.as_ref()).unwrap();
370/// assert_eq!(lens.scalar_at(0), 1u32.into());
371/// assert_eq!(lens.scalar_at(1), 2u32.into());
372/// assert_eq!(lens.scalar_at(2), 2u32.into());
373/// ```
374pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
375    if !matches!(array.dtype(), DType::List(..)) {
376        vortex_bail!("Array must be of list type");
377    }
378
379    // Short-circuit for constant list arrays.
380    if array.is_constant() && array.len() > 1 {
381        let elem_lens = list_elem_len(&array.slice(0..1))?;
382        return Ok(ConstantArray::new(elem_lens.scalar_at(0), array.len()).into_array());
383    }
384
385    let list_array = array.to_list()?;
386    let offsets = list_array.offsets().to_primitive()?;
387    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
388        element_lens(offsets.as_slice::<T>()).into_array()
389    });
390
391    Ok(lens_array)
392}
393
394fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
395    values
396        .windows(2)
397        .map(|window| window[1] - window[0])
398        .collect()
399}
400
401fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
402    BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
403}
404
405#[cfg(test)]
406mod tests {
407    use std::sync::Arc;
408
409    use itertools::Itertools;
410    use rstest::rstest;
411    use vortex_buffer::Buffer;
412    use vortex_dtype::{DType, Nullability, PType};
413    use vortex_scalar::Scalar;
414
415    use crate::arrays::{
416        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
417    };
418    use crate::canonical::ToCanonical;
419    use crate::compute::list_contains;
420    use crate::validity::Validity;
421    use crate::vtable::ValidityHelper;
422    use crate::{Array, ArrayRef, IntoArray};
423
424    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
425        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
426            .unwrap()
427    }
428
429    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
430        let elements = values.iter().flatten().cloned().collect_vec();
431        let mut offsets = values
432            .iter()
433            .scan(0u64, |st, v| {
434                *st += v.len() as u64;
435                Some(*st)
436            })
437            .collect_vec();
438        offsets.insert(0, 0u64);
439        let offsets = Buffer::from_iter(offsets).into_array();
440
441        let elements =
442            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();
443
444        ListArray::try_new(elements, offsets, Validity::NonNullable)
445            .unwrap()
446            .into_array()
447    }
448
449    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
450        BoolArray::new(values.into_iter().collect(), validity)
451    }
452
453    #[rstest]
454    #[case(
455        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
456        Some("a"),
457        bool_array(vec![false, true, true], Validity::NonNullable)
458    )]
459    // Cast 2: valid scalar search over nullable list, with all nulls matched
460    #[case(
461        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
462        Some("a"),
463        bool_array(vec![false, true, true], Validity::AllValid)
464    )]
465    // Cast 3: valid scalar search over nullable list, with some nulls not matched (return no nulls)
466    #[case(
467        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
468        Some("a"),
469        bool_array(vec![false, true, false], Validity::AllValid)
470    )]
471    // Case 4: list(utf8) with all elements matching, but some empty lists
472    #[case(
473        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
474        Some("a"),
475        bool_array(vec![false, true, true], Validity::NonNullable)
476    )]
477    // Case 5: list(utf8) all lists empty.
478    #[case(
479        nonnull_strings(vec![vec![], vec![], vec![]]),
480        Some("a"),
481        bool_array(vec![false, false, false], Validity::NonNullable)
482    )]
483    // Case 6: list(utf8) no elements matching.
484    #[case(
485        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
486        Some("a"),
487        bool_array(vec![false, false, false], Validity::NonNullable)
488    )]
489    // Case 7: list(utf8?) with empty + NULL elements and NULL search
490    #[case(
491        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
492        None,
493        bool_array(vec![false, true, true], Validity::AllInvalid)
494    )]
495    // Case 8: list(utf8?) with empty + NULL elements and search scalar
496    #[case(
497        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
498        Some("a"),
499        bool_array(vec![false, false, false], Validity::AllValid)
500    )]
501    fn test_contains_nullable(
502        #[case] list_array: ArrayRef,
503        #[case] value: Option<&str>,
504        #[case] expected: BoolArray,
505    ) {
506        let element_nullability = list_array
507            .dtype()
508            .as_list_element_opt()
509            .unwrap()
510            .nullability();
511        let scalar = match value {
512            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
513            Some(v) => Scalar::utf8(v, element_nullability),
514        };
515        let elem = ConstantArray::new(scalar, list_array.len());
516        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
517        let bool_result = result.to_bool().expect("to_bool failed");
518        assert_eq!(
519            bool_result.opt_bool_vec().unwrap(),
520            expected.opt_bool_vec().unwrap()
521        );
522        assert_eq!(bool_result.validity(), expected.validity());
523    }
524
525    #[test]
526    fn test_constant_list() {
527        let list_array = ConstantArray::new(
528            Scalar::list(
529                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
530                vec![1i32.into(), 2i32.into(), 3i32.into()],
531                Nullability::NonNullable,
532            ),
533            2,
534        )
535        .into_array();
536
537        let contains = list_contains(
538            &list_array,
539            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
540        )
541        .unwrap();
542        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
543        assert_eq!(
544            contains
545                .to_bool()
546                .unwrap()
547                .boolean_buffer()
548                .iter()
549                .collect_vec(),
550            vec![true, true],
551        );
552    }
553
554    #[test]
555    fn test_all_nulls() {
556        let list_array = ConstantArray::new(
557            Scalar::null(DType::List(
558                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
559                Nullability::Nullable,
560            )),
561            5,
562        )
563        .into_array();
564
565        let contains = list_contains(
566            &list_array,
567            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
568        )
569        .unwrap();
570        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
571
572        assert_eq!(contains.len(), 5);
573        assert_eq!(
574            contains.to_bool().unwrap().validity(),
575            &Validity::AllInvalid
576        );
577    }
578
579    #[test]
580    fn test_list_array_element() {
581        let list_scalar = Scalar::list(
582            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
583            vec![1.into(), 3.into(), 6.into()],
584            Nullability::NonNullable,
585        );
586
587        let contains = list_contains(
588            ConstantArray::new(list_scalar, 7).as_ref(),
589            (0..7).collect::<PrimitiveArray>().as_ref(),
590        )
591        .unwrap();
592
593        assert_eq!(contains.len(), 7);
594        assert_eq!(
595            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
596            vec![
597                Some(false),
598                Some(true),
599                Some(false),
600                Some(true),
601                Some(false),
602                Some(false),
603                Some(true)
604            ]
605        );
606    }
607}