1use std::sync::LazyLock;
10
11use arcref::ArcRef;
12use arrow_buffer::BooleanBuffer;
13use arrow_buffer::bit_iterator::BitIndexIterator;
14use num_traits::Zero;
15use vortex_dtype::{DType, IntegerPType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect, VortexResult, vortex_bail};
17use vortex_scalar::{ListScalar, Scalar};
18
19use crate::arrays::{BoolArray, ConstantArray, ListViewArray, PrimitiveArray};
20use crate::compute::{
21 self, BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output,
22};
23use crate::validity::Validity;
24use crate::vtable::{VTable, ValidityHelper};
25use crate::{Array, ArrayRef, IntoArray, ToCanonical};
26
27static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
28 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
29 for kernel in inventory::iter::<ListContainsKernelRef> {
30 compute.register_kernel(kernel.0.clone());
31 }
32 compute
33});
34
35pub(crate) fn warm_up_vtable() -> usize {
36 LIST_CONTAINS_FN.kernels().len()
37}
38
39pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
80 LIST_CONTAINS_FN
81 .invoke(&InvocationArgs {
82 inputs: &[array.into(), value.into()],
83 options: &(),
84 })?
85 .unwrap_array()
86}
87
88pub struct ListContains;
89
90impl ComputeFnVTable for ListContains {
91 fn invoke(
92 &self,
93 args: &InvocationArgs,
94 kernels: &[ArcRef<dyn Kernel>],
95 ) -> VortexResult<Output> {
96 let BinaryArgs {
97 lhs: array,
98 rhs: value,
99 ..
100 } = BinaryArgs::<()>::try_from(args)?;
101
102 let DType::List(elem_dtype, _) = array.dtype() else {
103 vortex_bail!("Array must be of List type");
104 };
105 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
106 vortex_bail!(
107 "Element type {} of `ListViewArray` does not match search value {}",
108 elem_dtype,
109 value.dtype(),
110 );
111 };
112
113 if value.all_invalid() || array.all_invalid() {
114 return Ok(Output::Array(
115 ConstantArray::new(
116 Scalar::null(DType::Bool(Nullability::Nullable)),
117 array.len(),
118 )
119 .to_array(),
120 ));
121 }
122
123 for kernel in kernels {
124 if let Some(output) = kernel.invoke(args)? {
125 return Ok(output);
126 }
127 }
128 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
129 return Ok(output);
130 }
131
132 let nullability = array.dtype().nullability() | value.dtype().nullability();
133
134 let result = if let Some(value_scalar) = value.as_constant() {
135 list_contains_scalar(array, &value_scalar, nullability)
136 } else if let Some(list_scalar) = array.as_constant() {
137 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
138 } else {
139 todo!("unsupported list contains with list and element as arrays")
140 };
141
142 result.map(Output::Array)
143 }
144
145 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
146 let input = BinaryArgs::<()>::try_from(args)?;
147 Ok(DType::Bool(
148 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
149 ))
150 }
151
152 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
153 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
154 }
155
156 fn is_elementwise(&self) -> bool {
157 true
158 }
159}
160
/// Encoding-specific kernel for `list_contains`, specialized on the encoding of the
/// *element* (right-hand) array.
pub trait ListContainsKernel: VTable {
    /// Computes `list_contains` for the given `list` array and an `element` array of this
    /// vtable's array type.
    ///
    /// Returning `Ok(None)` means the kernel does not produce a result for these inputs;
    /// the caller then moves on to other kernels or the fallback implementation.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
168
/// Type-erased handle to a `list_contains` kernel, collected via `inventory` so that it can
/// be registered on the `list_contains` compute function.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
171
/// Adapter wrapping a vtable `V` so that its [`ListContainsKernel`] implementation can be
/// used as a generic [`Kernel`].
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ListContainsKernelRef`].
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
180
181impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
182 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
183 let BinaryArgs {
184 lhs: array,
185 rhs: value,
186 ..
187 } = BinaryArgs::<()>::try_from(args)?;
188 let Some(value) = value.as_opt::<V>() else {
189 return Ok(None);
190 };
191 self.0
192 .list_contains(array, value)
193 .map(|c| c.map(Output::Array))
194 }
195}
196
197fn constant_list_scalar_contains(
199 list_scalar: &ListScalar,
200 values: &dyn Array,
201 nullability: Nullability,
202) -> VortexResult<ArrayRef> {
203 let elements = list_scalar.elements().vortex_expect("non null");
204
205 let len = values.len();
206 let mut result: Option<ArrayRef> = None;
207 let false_scalar = Scalar::bool(false, nullability);
208 for element in elements {
209 let res = compute::fill_null(
210 &compute::compare(
211 ConstantArray::new(element, len).as_ref(),
212 values,
213 Operator::Eq,
214 )?,
215 &false_scalar,
216 )?;
217 if let Some(acc) = result {
218 result = Some(compute::or(&acc, &res)?)
219 } else {
220 result = Some(res);
221 }
222 }
223 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
224}
225
226fn list_contains_scalar(
228 array: &dyn Array,
229 value: &Scalar,
230 nullability: Nullability,
231) -> VortexResult<ArrayRef> {
232 if array.len() > 1 && array.is_constant() {
234 let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
235 return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
236 }
237
238 let list_array = array.to_listview();
239
240 let elems = list_array.elements();
241 if elems.is_empty() {
242 return list_false_or_null(&list_array, nullability);
244 }
245
246 let rhs = ConstantArray::new(value.clone(), elems.len());
247 let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?;
248 let matches = matching_elements.to_bool();
249
250 if let Some(pred) = matches.as_constant() {
252 return match pred.as_bool().value() {
253 None => {
256 assert!(
257 !rhs.scalar().is_null(),
258 "Search value must not be null here"
259 );
260 list_false_or_null(&list_array, nullability)
262 }
263 Some(false) => {
265 Ok(
267 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
268 .into_array(),
269 )
270 }
271 Some(true) => {
273 list_is_not_empty(&list_array, nullability)
275 }
276 };
277 }
278
279 let offsets = list_array.offsets().to_primitive();
281 let sizes = list_array.sizes().to_primitive();
282
283 let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
285 match_each_integer_ptype!(sizes.ptype(), |S| {
286 process_matches::<O, S>(matches, list_array.len(), offsets, sizes)
287 })
288 });
289
290 Ok(BoolArray::from_bool_buffer(
291 list_matches,
292 list_array.validity().clone().union_nullability(nullability),
293 )
294 .into_array())
295}
296
297fn process_matches<O, S>(
300 matches: BoolArray,
301 list_array_len: usize,
302 offsets: PrimitiveArray,
303 sizes: PrimitiveArray,
304) -> BooleanBuffer
305where
306 O: IntegerPType,
307 S: IntegerPType,
308{
309 let offsets_slice = offsets.as_slice::<O>();
310 let sizes_slice = sizes.as_slice::<S>();
311
312 (0..list_array_len)
313 .map(|i| {
314 let offset = offsets_slice[i].as_();
315 let size = sizes_slice[i].as_();
316
317 let mut set_bits =
320 BitIndexIterator::new(matches.boolean_buffer().values(), offset, size);
321 set_bits.next().is_some()
322 })
323 .collect::<BooleanBuffer>()
324}
325
326fn list_false_or_null(
329 list_array: &ListViewArray,
330 nullability: Nullability,
331) -> VortexResult<ArrayRef> {
332 match list_array.validity() {
333 Validity::NonNullable => {
334 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
336 }
337 Validity::AllValid => {
338 Ok(
340 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
341 .into_array(),
342 )
343 }
344 Validity::AllInvalid => {
345 Ok(ConstantArray::new(
347 Scalar::null(DType::Bool(Nullability::Nullable)),
348 list_array.len(),
349 )
350 .into_array())
351 }
352 Validity::Array(validity_array) => {
353 let buffer = BooleanBuffer::new_unset(list_array.len());
355 Ok(
356 BoolArray::from_bool_buffer(buffer, Validity::Array(validity_array.clone()))
357 .into_array(),
358 )
359 }
360 }
361}
362
363fn list_is_not_empty(
366 list_array: &ListViewArray,
367 nullability: Nullability,
368) -> VortexResult<ArrayRef> {
369 if matches!(list_array.validity(), Validity::AllInvalid) {
371 return Ok(ConstantArray::new(
372 Scalar::null(DType::Bool(Nullability::Nullable)),
373 list_array.len(),
374 )
375 .into_array());
376 }
377
378 let sizes = list_array.sizes().to_primitive();
379 let buffer = match_each_integer_ptype!(sizes.ptype(), |S| {
380 BooleanBuffer::from_iter(sizes.as_slice::<S>().iter().map(|&size| size != S::zero()))
381 });
382
383 Ok(BoolArray::from_bool_buffer(
385 buffer,
386 list_array.validity().clone().union_nullability(nullability),
387 )
388 .into_array())
389}
390
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, ListVTable, ListViewArray,
        PrimitiveArray, VarBinArray, list_view_from_list,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list-view array of non-nullable UTF-8 strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        list_view_from_list(
            ListArray::from_iter_slow::<u64, _>(
                values,
                Arc::new(DType::Utf8(Nullability::NonNullable)),
            )
            .unwrap()
            .as_::<ListVTable>()
            .clone(),
        )
        .into_array()
    }

    /// Builds a non-nullable list-view array whose elements are nullable UTF-8 strings.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();

        // Running sum of list lengths, with a leading zero, forms the list offsets.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        list_view_from_list(ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap())
            .into_array()
    }

    /// Builds a `BoolArray` from plain booleans and the given validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::from_bool_buffer(values.into_iter().collect(), validity)
    }

    #[rstest]
    #[case(
        // Non-nullable elements; the value is present in the second and third lists.
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    #[case(
        // Nullable elements; the value is present alongside nulls.
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    #[case(
        // Nullable elements; the third list holds only a non-matching value and nulls.
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    #[case(
        // Every element equals the search value; only the empty list is false.
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    #[case(
        // All lists are empty.
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    #[case(
        // No element equals the search value.
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    #[case(
        // Null search value: the expected result is entirely invalid.
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    #[case(
        // All elements are null but the search value is not.
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool();
        // Compare both the (optional) values and the validity of the result.
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    /// A constant list array searched with a constant value yields a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains.to_bool().boolean_buffer().iter().collect_vec(),
            vec![true, true]
        );
    }

    /// A constant null list array yields a constant, entirely invalid result.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    /// A constant list `[1, 3, 6]` searched with the non-constant values `0..7`.
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }

    /// A list view with no elements at all: four empty lists.
    #[test]
    fn test_list_contains_empty_listview() {
        let empty_elements = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        let offsets = Buffer::from_iter([0u32, 0, 0, 0]).into_array();
        let sizes = Buffer::from_iter([0u32, 0, 0, 0]).into_array();

        let list_array = ListViewArray::try_new(
            empty_elements.into_array(),
            offsets,
            sizes,
            Validity::NonNullable,
        )
        .unwrap();

        let search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        // No list has any element, so nothing can match.
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, false, false, false]
        );
    }

    /// A list view whose elements are all null, searched with null and non-null values.
    #[test]
    fn test_list_contains_all_null_elements() {
        let elements = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None, None]);
        let offsets = Buffer::from_iter([0u32, 2, 4]).into_array();
        let sizes = Buffer::from_iter([2u32, 2, 1]).into_array();

        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        // Searching for null yields an entirely invalid result.
        let null_search = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            list_array.len(),
        );
        let result = list_contains(list_array.as_ref(), null_search.as_ref()).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.to_bool().validity(), &Validity::AllInvalid);

        // Searching for a non-null value finds nothing.
        let non_null_search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result2 = list_contains(list_array.as_ref(), non_null_search.as_ref()).unwrap();

        assert_eq!(result2.len(), 3);
        assert_eq!(
            result2.to_bool().bool_vec().unwrap(),
            vec![false, false, false]
        );
    }

    /// A list view with non-contiguous offsets into a shared element buffer.
    #[test]
    fn test_list_contains_large_offsets() {
        let elements = Buffer::from_iter([1i32, 2, 3, 4, 5]).into_array();

        // With these offsets and sizes the lists are: [1], [2, 3], [5], [].
        let offsets = Buffer::from_iter([0u32, 1, 4, 0]).into_array();
        let sizes = Buffer::from_iter([1u32, 2, 1, 0]).into_array();

        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        let search = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, true, false, false]
        );

        let search5 = ConstantArray::new(Scalar::from(5i32), list_array.len());
        let result5 = list_contains(list_array.as_ref(), search5.as_ref()).unwrap();

        assert_eq!(
            result5.to_bool().bool_vec().unwrap(),
            vec![false, false, true, false]
        );
    }

    /// Uses `u8` offsets and sizes whose sum reaches the end of a 256-element buffer.
    #[test]
    fn test_list_contains_offset_size_boundary() {
        let elements = Buffer::from_iter(0..256).into_array();
        // The last list starts at 254 with size 2, covering elements 254 and 255.
        let offsets = Buffer::from_iter([0u8, 100, 200, 254]).into_array();
        let sizes = Buffer::from_iter([50u8, 50, 54, 2]).into_array();
        let list_array =
            ListViewArray::try_new(elements.into_array(), offsets, sizes, Validity::NonNullable)
                .unwrap();

        let search = ConstantArray::new(Scalar::from(255i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec().unwrap(),
            vec![false, false, false, true]
        );

        let search_zero = ConstantArray::new(Scalar::from(0i32), list_array.len());
        let result_zero = list_contains(list_array.as_ref(), search_zero.as_ref()).unwrap();

        assert_eq!(
            result_zero.to_bool().bool_vec().unwrap(),
            vec![true, false, false, false]
        );
    }
}