vortex_array/compute/
list_contains.rs

1//! List-related compute operations.
2
3use std::sync::LazyLock;
4
5use arcref::ArcRef;
6use arrow_buffer::BooleanBuffer;
7use arrow_buffer::bit_iterator::BitIndexIterator;
8use num_traits::AsPrimitive;
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12use vortex_scalar::{ListScalar, Scalar};
13
14use crate::arrays::{BoolArray, ConstantArray, ListArray};
15use crate::compute::{
16    BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
17    fill_null, or,
18};
19use crate::validity::Validity;
20use crate::vtable::{VTable, ValidityHelper};
21use crate::{Array, ArrayRef, IntoArray, ToCanonical};
22
23/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
24/// item contains the `value`, `false` otherwise.
25///
26/// ## Null scalar handling
27///
28/// If the `value` or `array` is `null` at any index the result at that index is `null`.
29///
30/// ## Format semantics
31/// ```txt
32/// list_contains(list, elem)
33///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
34/// ```
35///
36/// ## Example
37///
38/// ```rust
39/// use vortex_array::{Array, IntoArray, ToCanonical};
40/// use vortex_array::arrays::{ConstantArray, ListArray, VarBinArray};
41/// use vortex_array::compute::list_contains;
42/// use vortex_array::validity::Validity;
43/// use vortex_buffer::buffer;
44/// use vortex_dtype::DType;
45/// use vortex_scalar::Scalar;
46/// let elements = VarBinArray::from_vec(
47///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
48/// let offsets = buffer![0u32, 1, 3, 5].into_array();
49/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
50///
51/// let matches = list_contains(list_array.as_ref(), ConstantArray::new(Scalar::from("b"), list_array.len()).as_ref()).unwrap();
52/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
53/// assert_eq!(to_vec, vec![false, true, false]);
54/// ```
55pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
56    LIST_CONTAINS_FN
57        .invoke(&InvocationArgs {
58            inputs: &[array.into(), value.into()],
59            options: &(),
60        })?
61        .unwrap_array()
62}
63
64pub struct ListContains;
65
66impl ComputeFnVTable for ListContains {
67    fn invoke(
68        &self,
69        args: &InvocationArgs,
70        kernels: &[ArcRef<dyn Kernel>],
71    ) -> VortexResult<Output> {
72        let BinaryArgs {
73            lhs: array,
74            rhs: value,
75            ..
76        } = BinaryArgs::<()>::try_from(args)?;
77
78        let DType::List(elem_dtype, _) = array.dtype() else {
79            vortex_bail!("Array must be of List type");
80        };
81        if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
82            vortex_bail!(
83                "Element type {} of ListArray does not match search value {}",
84                elem_dtype,
85                value.dtype(),
86            );
87        };
88
89        if value.all_invalid()? || array.all_invalid()? {
90            return Ok(Output::Array(
91                ConstantArray::new(
92                    Scalar::null(DType::Bool(Nullability::Nullable)),
93                    array.len(),
94                )
95                .to_array(),
96            ));
97        }
98
99        for kernel in kernels {
100            if let Some(output) = kernel.invoke(args)? {
101                return Ok(output);
102            }
103        }
104        if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
105            return Ok(output);
106        }
107
108        let nullability = array.dtype().nullability() | value.dtype().nullability();
109
110        let result = if let Some(value_scalar) = value.as_constant() {
111            list_contains_scalar(array, &value_scalar, nullability)
112        } else if let Some(list_scalar) = array.as_constant() {
113            constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
114        } else {
115            todo!("unsupported list contains with list and element as arrays")
116        };
117
118        result.map(Output::Array)
119    }
120
121    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
122        let input = BinaryArgs::<()>::try_from(args)?;
123        Ok(DType::Bool(
124            input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
125        ))
126    }
127
128    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
129        Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
130    }
131
132    fn is_elementwise(&self) -> bool {
133        true
134    }
135}
136
137pub trait ListContainsKernel: VTable {
138    fn list_contains(
139        &self,
140        list: &dyn Array,
141        element: &Self::Array,
142    ) -> VortexResult<Option<ArrayRef>>;
143}
144
145pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
146inventory::collect!(ListContainsKernelRef);
147
148#[derive(Debug)]
149pub struct ListContainsKernelAdapter<V: VTable>(pub V);
150
151impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
152    pub const fn lift(&'static self) -> ListContainsKernelRef {
153        ListContainsKernelRef(ArcRef::new_ref(self))
154    }
155}
156
157impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
158    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
159        let BinaryArgs {
160            lhs: array,
161            rhs: value,
162            ..
163        } = BinaryArgs::<()>::try_from(args)?;
164        let Some(value) = value.as_opt::<V>() else {
165            return Ok(None);
166        };
167        self.0
168            .list_contains(array, value)
169            .map(|c| c.map(Output::Array))
170    }
171}
172
173pub static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
174    let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
175    for kernel in inventory::iter::<ListContainsKernelRef> {
176        compute.register_kernel(kernel.0.clone());
177    }
178    compute
179});
180
181// Then there is a constant list scalar (haystack) being compared to an array of needles.
182fn constant_list_scalar_contains(
183    list_scalar: &ListScalar,
184    values: &dyn Array,
185    nullability: Nullability,
186) -> VortexResult<ArrayRef> {
187    let elements = list_scalar.elements().vortex_expect("non null");
188
189    let len = values.len();
190    let mut result: Option<ArrayRef> = None;
191    let false_scalar = Scalar::bool(false, nullability);
192    for element in elements {
193        let res = fill_null(
194            &compare(
195                ConstantArray::new(element, len).as_ref(),
196                values,
197                Operator::Eq,
198            )?,
199            &false_scalar,
200        )?;
201        if let Some(acc) = result {
202            result = Some(or(&acc, &res)?)
203        } else {
204            result = Some(res);
205        }
206    }
207    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
208}
209
210fn list_contains_scalar(
211    array: &dyn Array,
212    value: &Scalar,
213    nullability: Nullability,
214) -> VortexResult<ArrayRef> {
215    // If the list array is constant, we perform a single comparison.
216    if array.len() > 1 && array.is_constant() {
217        let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
218        return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
219    }
220
221    // Canonicalize to a list array.
222    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
223    let list_array = array.to_list()?;
224
225    let elems = list_array.elements();
226    if elems.is_empty() {
227        // Must return false when a list is empty (but valid), or null when the list itself is null.
228        return list_false_or_null(&list_array, nullability);
229    }
230
231    let rhs = ConstantArray::new(value.clone(), elems.len());
232    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
233    let matches = matching_elements.to_bool()?;
234
235    // Fast path: no elements match.
236    if let Some(pred) = matches.as_constant() {
237        return match pred.as_bool().value() {
238            // All comparisons are invalid (result in `null`), and search is not null because
239            // we already checked for null above.
240            None => {
241                assert!(
242                    !rhs.scalar().is_null(),
243                    "Search value must not be null here"
244                );
245                // False, unless the list itself is null in which case we return null.
246                list_false_or_null(&list_array, nullability)
247            }
248            // No elements match, and all comparisons are valid (result in `false`).
249            Some(false) => {
250                // False, but match the nullability to the input list array.
251                Ok(
252                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
253                        .into_array(),
254                )
255            }
256            // All elements match, and all comparisons are valid (result in `true`).
257            Some(true) => {
258                // True, unless the list itself is empty or NULL.
259                list_is_not_empty(&list_array, nullability)
260            }
261        };
262    }
263
264    let ends = list_array.offsets().to_primitive()?;
265    match_each_integer_ptype!(ends.ptype(), |T| {
266        Ok(reduce_with_ends(
267            ends.as_slice::<T>(),
268            matches.boolean_buffer(),
269            list_array.validity().clone().union_nullability(nullability),
270        ))
271    })
272}
273
274/// Returns a `Bool` array with `false` for lists that are valid,
275/// or `NULL` if the list itself is null.
276fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
277    match list_array.validity() {
278        Validity::NonNullable => {
279            // All false.
280            Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
281        }
282        Validity::AllValid => {
283            // All false, but nullable.
284            Ok(
285                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
286                    .into_array(),
287            )
288        }
289        Validity::AllInvalid => {
290            // All nulls, must be nullable result.
291            Ok(ConstantArray::new(
292                Scalar::null(DType::Bool(Nullability::Nullable)),
293                list_array.len(),
294            )
295            .into_array())
296        }
297        Validity::Array(validity_array) => {
298            // Create a new bool array with false, and the provided nulls
299            let buffer = BooleanBuffer::new_unset(list_array.len());
300            Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
301        }
302    }
303}
304
305/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
306/// or `NULL` if the list itself is null.
307fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
308    // Short-circuit for all invalid.
309    if matches!(list_array.validity(), Validity::AllInvalid) {
310        return Ok(ConstantArray::new(
311            Scalar::null(DType::Bool(Nullability::Nullable)),
312            list_array.len(),
313        )
314        .into_array());
315    }
316
317    let offsets = list_array.offsets().to_primitive()?;
318    let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
319        element_is_not_empty(offsets.as_slice::<T>())
320    });
321
322    // Copy over the validity mask from the input.
323    Ok(BoolArray::new(
324        buffer,
325        list_array.validity().clone().union_nullability(nullability),
326    )
327    .into_array())
328}
329
330/// Reduces each boolean values into a Mask that indicates which elements in the
331/// ListArray contain the matching value.
332fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
333    ends: &[T],
334    matches: &BooleanBuffer,
335    validity: Validity,
336) -> ArrayRef {
337    let mask: BooleanBuffer = ends
338        .windows(2)
339        .map(|window| {
340            let len = window[1].as_() - window[0].as_();
341            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
342            set_bits.next().is_some()
343        })
344        .collect();
345
346    BoolArray::new(mask, validity).into_array()
347}
348
349/// Returns a new array of `u64` representing the length of each list element.
350///
351/// ## Example
352///
353/// ```rust
354/// use vortex_array::arrays::{ListArray, VarBinArray};
355/// use vortex_array::{Array, IntoArray};
356/// use vortex_array::compute::{list_elem_len};
357/// use vortex_array::validity::Validity;
358/// use vortex_buffer::buffer;
359/// use vortex_dtype::DType;
360///
361/// let elements = VarBinArray::from_vec(
362///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
363/// let offsets = buffer![0u32, 1, 3, 5].into_array();
364/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
365///
366/// let lens = list_elem_len(list_array.as_ref()).unwrap();
367/// assert_eq!(lens.scalar_at(0).unwrap(), 1u32.into());
368/// assert_eq!(lens.scalar_at(1).unwrap(), 2u32.into());
369/// assert_eq!(lens.scalar_at(2).unwrap(), 2u32.into());
370/// ```
371pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
372    if !matches!(array.dtype(), DType::List(..)) {
373        vortex_bail!("Array must be of list type");
374    }
375
376    // Short-circuit for constant list arrays.
377    if array.is_constant() && array.len() > 1 {
378        let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
379        return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
380    }
381
382    let list_array = array.to_list()?;
383    let offsets = list_array.offsets().to_primitive()?;
384    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
385        element_lens(offsets.as_slice::<T>()).into_array()
386    });
387
388    Ok(lens_array)
389}
390
391fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
392    values
393        .windows(2)
394        .map(|window| window[1] - window[0])
395        .collect()
396}
397
398fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
399    BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
400}
401
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a non-nullable `list(utf8)` array from nested string vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
    }

    /// Builds a non-nullable `list(utf8?)` array whose string elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8) where every non-empty list contains the search value
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 2: valid scalar search over list(utf8?) with null elements; every non-empty list matches
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Case 3: valid scalar search over list(utf8?) with null elements; some lists have no match
    // (the result contains no nulls)
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Case 4: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 5: list(utf8) where all lists are empty
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 6: list(utf8) with no elements matching
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 7: list(utf8?) with empty + NULL elements and a NULL search value
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Case 8: list(utf8?) with empty + NULL elements and a valid search scalar
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }

    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }
}
598        );
599    }
600}