vortex_array/compute/list_contains.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! List-related compute operations.
5
// TODO(connor)[ListView]: Should this compute function be moved into `arrays/listview`?
7// TODO(connor)[ListView]: Clean up this file.
8
9use std::sync::LazyLock;
10
11use arcref::ArcRef;
12use arrow_buffer::BooleanBuffer;
13use arrow_buffer::bit_iterator::BitIndexIterator;
14use num_traits::Zero;
15use vortex_dtype::{DType, IntegerPType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect, VortexResult, vortex_bail};
17use vortex_scalar::{ListScalar, Scalar};
18
19use crate::arrays::{BoolArray, ConstantArray, ListViewArray, PrimitiveArray};
20use crate::compute::{
21    self, BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output,
22};
23use crate::validity::Validity;
24use crate::vtable::{VTable, ValidityHelper};
25use crate::{Array, ArrayRef, IntoArray, ToCanonical};
26
27static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
28    let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
29    for kernel in inventory::iter::<ListContainsKernelRef> {
30        compute.register_kernel(kernel.0.clone());
31    }
32    compute
33});
34
35pub(crate) fn warm_up_vtable() -> usize {
36    LIST_CONTAINS_FN.kernels().len()
37}
38
39/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
40/// item contains the `value`, `false` otherwise.
41///
42/// ## Null scalar handling
43///
44/// If the `value` or `array` is `null` at any index the result at that index is `null`.
45///
46/// ## Format semantics
47/// ```txt
48/// list_contains(list, elem)
49///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
50/// ```
51///
52/// ## Example
53///
54/// ```rust
55/// # use vortex_array::{Array, IntoArray, ToCanonical};
56/// # use vortex_array::arrays::{ConstantArray, ListViewArray, VarBinArray};
57/// # use vortex_array::compute;
58/// # use vortex_array::validity::Validity;
59/// # use vortex_buffer::buffer;
60/// # use vortex_dtype::DType;
61/// # use vortex_scalar::Scalar;
62/// #
63/// let elements = VarBinArray::from_vec(
64///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
65/// let offsets = buffer![0u32, 1, 3].into_array();
66/// let sizes = buffer![1u32, 2, 2].into_array();
67/// let list_array =
68///     ListViewArray::try_new(elements, offsets, sizes, Validity::NonNullable).unwrap();
69///
70/// let matches = compute::list_contains(
71///     list_array.as_ref(),
72///     ConstantArray::new(Scalar::from("b"),
73///     list_array.len()).as_ref()
74/// ).unwrap();
75///
76/// let to_vec: Vec<bool> = matches.to_bool().boolean_buffer().iter().collect();
77/// assert_eq!(to_vec, vec![false, true, false]);
78/// ```
79pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
80    LIST_CONTAINS_FN
81        .invoke(&InvocationArgs {
82            inputs: &[array.into(), value.into()],
83            options: &(),
84        })?
85        .unwrap_array()
86}
87
88pub struct ListContains;
89
90impl ComputeFnVTable for ListContains {
91    fn invoke(
92        &self,
93        args: &InvocationArgs,
94        kernels: &[ArcRef<dyn Kernel>],
95    ) -> VortexResult<Output> {
96        let BinaryArgs {
97            lhs: array,
98            rhs: value,
99            ..
100        } = BinaryArgs::<()>::try_from(args)?;
101
102        let DType::List(elem_dtype, _) = array.dtype() else {
103            vortex_bail!("Array must be of List type");
104        };
105        if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
106            vortex_bail!(
107                "Element type {} of `ListViewArray` does not match search value {}",
108                elem_dtype,
109                value.dtype(),
110            );
111        };
112
113        if value.all_invalid() || array.all_invalid() {
114            return Ok(Output::Array(
115                ConstantArray::new(
116                    Scalar::null(DType::Bool(Nullability::Nullable)),
117                    array.len(),
118                )
119                .to_array(),
120            ));
121        }
122
123        for kernel in kernels {
124            if let Some(output) = kernel.invoke(args)? {
125                return Ok(output);
126            }
127        }
128        if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
129            return Ok(output);
130        }
131
132        let nullability = array.dtype().nullability() | value.dtype().nullability();
133
134        let result = if let Some(value_scalar) = value.as_constant() {
135            list_contains_scalar(array, &value_scalar, nullability)
136        } else if let Some(list_scalar) = array.as_constant() {
137            constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
138        } else {
139            todo!("unsupported list contains with list and element as arrays")
140        };
141
142        result.map(Output::Array)
143    }
144
145    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
146        let input = BinaryArgs::<()>::try_from(args)?;
147        Ok(DType::Bool(
148            input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
149        ))
150    }
151
152    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
153        Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
154    }
155
156    fn is_elementwise(&self) -> bool {
157        true
158    }
159}
160
161pub trait ListContainsKernel: VTable {
162    fn list_contains(
163        &self,
164        list: &dyn Array,
165        element: &Self::Array,
166    ) -> VortexResult<Option<ArrayRef>>;
167}
168
169pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
170inventory::collect!(ListContainsKernelRef);
171
172#[derive(Debug)]
173pub struct ListContainsKernelAdapter<V: VTable>(pub V);
174
175impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
176    pub const fn lift(&'static self) -> ListContainsKernelRef {
177        ListContainsKernelRef(ArcRef::new_ref(self))
178    }
179}
180
181impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
182    fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
183        let BinaryArgs {
184            lhs: array,
185            rhs: value,
186            ..
187        } = BinaryArgs::<()>::try_from(args)?;
188        let Some(value) = value.as_opt::<V>() else {
189            return Ok(None);
190        };
191        self.0
192            .list_contains(array, value)
193            .map(|c| c.map(Output::Array))
194    }
195}
196
197// Then there is a constant list scalar (haystack) being compared to an array of needles.
198fn constant_list_scalar_contains(
199    list_scalar: &ListScalar,
200    values: &dyn Array,
201    nullability: Nullability,
202) -> VortexResult<ArrayRef> {
203    let elements = list_scalar.elements().vortex_expect("non null");
204
205    let len = values.len();
206    let mut result: Option<ArrayRef> = None;
207    let false_scalar = Scalar::bool(false, nullability);
208    for element in elements {
209        let res = compute::fill_null(
210            &compute::compare(
211                ConstantArray::new(element, len).as_ref(),
212                values,
213                Operator::Eq,
214            )?,
215            &false_scalar,
216        )?;
217        if let Some(acc) = result {
218            result = Some(compute::or(&acc, &res)?)
219        } else {
220            result = Some(res);
221        }
222    }
223    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
224}
225
226/// Returns a [`BoolArray`] where each bit represents if a list contains the scalar.
227fn list_contains_scalar(
228    array: &dyn Array,
229    value: &Scalar,
230    nullability: Nullability,
231) -> VortexResult<ArrayRef> {
232    // If the list array is constant, we perform a single comparison.
233    if array.len() > 1 && array.is_constant() {
234        let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
235        return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
236    }
237
238    let list_array = array.to_listview();
239
240    let elems = list_array.elements();
241    if elems.is_empty() {
242        // Must return false when a list is empty (but valid), or null when the list itself is null.
243        return list_false_or_null(&list_array, nullability);
244    }
245
246    let rhs = ConstantArray::new(value.clone(), elems.len());
247    let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?;
248    let matches = matching_elements.to_bool();
249
250    // Fast path: no elements match.
251    if let Some(pred) = matches.as_constant() {
252        return match pred.as_bool().value() {
253            // All comparisons are invalid (result in `null`), and search is not null because
254            // we already checked for null above.
255            None => {
256                assert!(
257                    !rhs.scalar().is_null(),
258                    "Search value must not be null here"
259                );
260                // False, unless the list itself is null in which case we return null.
261                list_false_or_null(&list_array, nullability)
262            }
263            // No elements match, and all comparisons are valid (result in `false`).
264            Some(false) => {
265                // False, but match the nullability to the input list array.
266                Ok(
267                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
268                        .into_array(),
269                )
270            }
271            // All elements match, and all comparisons are valid (result in `true`).
272            Some(true) => {
273                // True, unless the list itself is empty or NULL.
274                list_is_not_empty(&list_array, nullability)
275            }
276        };
277    }
278
279    // Get the offsets and sizes as primitive arrays.
280    let offsets = list_array.offsets().to_primitive();
281    let sizes = list_array.sizes().to_primitive();
282
283    // Process based on the offset and size types.
284    let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
285        match_each_integer_ptype!(sizes.ptype(), |S| {
286            process_matches::<O, S>(matches, list_array.len(), offsets, sizes)
287        })
288    });
289
290    Ok(BoolArray::from_bool_buffer(
291        list_matches,
292        list_array.validity().clone().union_nullability(nullability),
293    )
294    .into_array())
295}
296
297/// Returns a [`BooleanBuffer`] where each bit represents if a list contains the scalar, derived
298/// from a [`BoolArray`] of matches on the child elements array.
299fn process_matches<O, S>(
300    matches: BoolArray,
301    list_array_len: usize,
302    offsets: PrimitiveArray,
303    sizes: PrimitiveArray,
304) -> BooleanBuffer
305where
306    O: IntegerPType,
307    S: IntegerPType,
308{
309    let offsets_slice = offsets.as_slice::<O>();
310    let sizes_slice = sizes.as_slice::<S>();
311
312    (0..list_array_len)
313        .map(|i| {
314            let offset = offsets_slice[i].as_();
315            let size = sizes_slice[i].as_();
316
317            // BitIndexIterator yields indices of true bits only. If `.next()` returns
318            // `Some(_)`, at least one element in this list's range matches.
319            let mut set_bits =
320                BitIndexIterator::new(matches.boolean_buffer().values(), offset, size);
321            set_bits.next().is_some()
322        })
323        .collect::<BooleanBuffer>()
324}
325
326/// Returns a `Bool` array with `false` for lists that are valid,
327/// or `NULL` if the list itself is null.
328fn list_false_or_null(
329    list_array: &ListViewArray,
330    nullability: Nullability,
331) -> VortexResult<ArrayRef> {
332    match list_array.validity() {
333        Validity::NonNullable => {
334            // All false.
335            Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
336        }
337        Validity::AllValid => {
338            // All false, but nullable.
339            Ok(
340                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
341                    .into_array(),
342            )
343        }
344        Validity::AllInvalid => {
345            // All nulls, must be nullable result.
346            Ok(ConstantArray::new(
347                Scalar::null(DType::Bool(Nullability::Nullable)),
348                list_array.len(),
349            )
350            .into_array())
351        }
352        Validity::Array(validity_array) => {
353            // Create a new bool array with false, and the provided nulls
354            let buffer = BooleanBuffer::new_unset(list_array.len());
355            Ok(
356                BoolArray::from_bool_buffer(buffer, Validity::Array(validity_array.clone()))
357                    .into_array(),
358            )
359        }
360    }
361}
362
363/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
364/// or `NULL` if the list itself is null.
365fn list_is_not_empty(
366    list_array: &ListViewArray,
367    nullability: Nullability,
368) -> VortexResult<ArrayRef> {
369    // Short-circuit for all invalid.
370    if matches!(list_array.validity(), Validity::AllInvalid) {
371        return Ok(ConstantArray::new(
372            Scalar::null(DType::Bool(Nullability::Nullable)),
373            list_array.len(),
374        )
375        .into_array());
376    }
377
378    let sizes = list_array.sizes().to_primitive();
379    let buffer = match_each_integer_ptype!(sizes.ptype(), |S| {
380        BooleanBuffer::from_iter(sizes.as_slice::<S>().iter().map(|&size| size != S::zero()))
381    });
382
383    // Copy over the validity mask from the input.
384    Ok(BoolArray::from_bool_buffer(
385        buffer,
386        list_array.validity().clone().union_nullability(nullability),
387    )
388    .into_array())
389}
390
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, ListVTable, ListViewArray,
        PrimitiveArray, VarBinArray, list_view_from_list,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a non-nullable `list(utf8)` list view from nested string vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        list_view_from_list(
            ListArray::from_iter_slow::<u64, _>(
                values,
                Arc::new(DType::Utf8(Nullability::NonNullable)),
            )
            .unwrap()
            .as_::<ListVTable>()
            .clone(),
        )
        .into_array()
    }

    /// Builds a non-nullable `list(utf8?)` list view (nullable elements) from nested options.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();

        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        list_view_from_list(ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap())
            .into_array()
    }

    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::from_bool_buffer(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8) with an empty list and matching elements
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 2: valid scalar search over lists with null elements; null elements do not make the
    // result null
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Case 3: valid scalar search over lists with null elements, where some lists have no match
    // (the result has no nulls)
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Case 4: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Case 5: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 6: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Case 7: list(utf8?) with empty + NULL elements and NULL search
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Case 8: list(utf8?) with empty + NULL elements and search scalar
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool();
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains.to_bool().boolean_buffer().iter().collect_vec(),
            vec![true, true]
        );
    }

    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }

    #[test]
    fn test_list_contains_empty_listview() {
        // Create a completely empty ListView with no elements
        let empty_elements = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        let offsets = Buffer::from_iter([0u32, 0, 0, 0]).into_array();
        let sizes = Buffer::from_iter([0u32, 0, 0, 0]).into_array();

        let list_array = ListViewArray::try_new(
            empty_elements.into_array(),
            offsets,
            sizes,
            Validity::NonNullable,
        )
        .unwrap();

        // Test with a non-null search value
        let search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        // All lists are empty, so all should return false
        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, false, false, false]
        );
    }

    #[test]
    fn test_list_contains_all_null_elements() {
        // Create lists containing only null elements
        let elements = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None, None]);
        let offsets = Buffer::from_iter([0u32, 2, 4]).into_array();
        let sizes = Buffer::from_iter([2u32, 2, 1]).into_array();

        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        // Test searching for a null value
        let null_search = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            list_array.len(),
        );
        let result = list_contains(list_array.as_ref(), null_search.as_ref()).unwrap();

        // Searching for null in lists with null elements should return null
        assert_eq!(result.len(), 3);
        assert_eq!(result.to_bool().validity(), &Validity::AllInvalid);

        // Test searching for a non-null value
        let non_null_search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result2 = list_contains(list_array.as_ref(), non_null_search.as_ref()).unwrap();

        // All comparisons result in null, but search is not null, so should return false
        assert_eq!(result2.len(), 3);
        assert_eq!(
            result2.to_bool().bool_vec().unwrap(),
            vec![false, false, false]
        );
    }

    #[test]
    fn test_list_contains_large_offsets() {
        // Test non-contiguous offsets: ListView allows views that skip over elements and do not
        // have to appear in element order.
        let elements = Buffer::from_iter([1i32, 2, 3, 4, 5]).into_array();

        // Create lists with various offsets, testing the flexibility of ListView
        // List 0: element at offset 0 (value 1)
        // List 1: elements at offset 1-2 (values 2, 3)
        // List 2: element at offset 4 (value 5)
        // List 3: empty list
        let offsets = Buffer::from_iter([0u32, 1, 4, 0]).into_array();
        let sizes = Buffer::from_iter([1u32, 2, 1, 0]).into_array();

        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        // Test searching for value 2, which appears only in list 1
        let search = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, true, false, false] // Value 2 is only in list 1
        );

        // Test searching for value 5, which appears only in list 2
        let search5 = ConstantArray::new(Scalar::from(5i32), list_array.len());
        let result5 = list_contains(list_array.as_ref(), search5.as_ref()).unwrap();

        assert_eq!(
            result5.to_bool().bool_vec().unwrap(),
            vec![false, false, true, false] // Value 5 is only in list 2
        );
    }

    #[test]
    fn test_list_contains_offset_size_boundary() {
        // Test edge case where offset + size approaches type boundaries
        // We create lists where the last valid index (offset + size - 1) is at various boundaries

        // For u8 boundary
        let elements = Buffer::from_iter(0..256).into_array();
        let offsets = Buffer::from_iter([0u8, 100, 200, 254]).into_array();
        let sizes = Buffer::from_iter([50u8, 50, 54, 2]).into_array(); // Last list goes to index 255

        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        // Search for value 255 which should only be in the last list
        let search = ConstantArray::new(Scalar::from(255i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, false, false, true]
        );

        // Search for value 0 which should only be in the first list
        let search_zero = ConstantArray::new(Scalar::from(0i32), list_array.len());
        let result_zero = list_contains(list_array.as_ref(), search_zero.as_ref()).unwrap();

        assert_eq!(
            result_zero.to_bool().bool_vec().unwrap(),
            vec![true, false, false, false]
        );
    }
}