1use std::sync::LazyLock;
10
11use arcref::ArcRef;
12use arrow_buffer::bit_iterator::BitIndexIterator;
13use num_traits::Zero;
14use vortex_buffer::BitBuffer;
15use vortex_dtype::DType;
16use vortex_dtype::IntegerPType;
17use vortex_dtype::Nullability;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexExpect;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22
23use crate::Array;
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::arrays::BoolArray;
28use crate::arrays::ConstantArray;
29use crate::arrays::ConstantVTable;
30use crate::arrays::ListViewArray;
31use crate::arrays::PrimitiveArray;
32use crate::builtins::ArrayBuiltins;
33use crate::compute::BinaryArgs;
34use crate::compute::ComputeFn;
35use crate::compute::ComputeFnVTable;
36use crate::compute::InvocationArgs;
37use crate::compute::Kernel;
38use crate::compute::Operator;
39use crate::compute::Output;
40use crate::compute::{self};
41use crate::scalar::ListScalar;
42use crate::scalar::Scalar;
43use crate::validity::Validity;
44use crate::vtable::VTable;
45use crate::vtable::ValidityHelper;
46
47static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
48 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
49 for kernel in inventory::iter::<ListContainsKernelRef> {
50 compute.register_kernel(kernel.0.clone());
51 }
52 compute
53});
54
55pub(crate) fn warm_up_vtable() -> usize {
56 LIST_CONTAINS_FN.kernels().len()
57}
58
59pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
109 LIST_CONTAINS_FN
110 .invoke(&InvocationArgs {
111 inputs: &[array.into(), value.into()],
112 options: &(),
113 })?
114 .unwrap_array()
115}
116
117pub struct ListContains;
118
119impl ComputeFnVTable for ListContains {
120 fn invoke(
121 &self,
122 args: &InvocationArgs,
123 kernels: &[ArcRef<dyn Kernel>],
124 ) -> VortexResult<Output> {
125 let BinaryArgs {
126 lhs: array,
127 rhs: value,
128 ..
129 } = BinaryArgs::<()>::try_from(args)?;
130
131 let DType::List(elem_dtype, _) = array.dtype() else {
132 vortex_bail!("Array must be of List type");
133 };
134 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
135 vortex_bail!(
136 "Element type {} of `ListViewArray` does not match search value {}",
137 elem_dtype,
138 value.dtype(),
139 );
140 };
141
142 if value.all_invalid()? || array.all_invalid()? {
143 return Ok(Output::Array(
144 ConstantArray::new(
145 Scalar::null(DType::Bool(Nullability::Nullable)),
146 array.len(),
147 )
148 .to_array(),
149 ));
150 }
151
152 for kernel in kernels {
153 if let Some(output) = kernel.invoke(args)? {
154 return Ok(output);
155 }
156 }
157
158 let nullability = array.dtype().nullability() | value.dtype().nullability();
159
160 let result = if let Some(value_scalar) = value.as_constant() {
161 list_contains_scalar(array, &value_scalar, nullability)
162 } else if let Some(list_scalar) = array.as_constant() {
163 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
164 } else {
165 todo!("unsupported list contains with list and element as arrays")
166 };
167
168 result.map(Output::Array)
169 }
170
171 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
172 let input = BinaryArgs::<()>::try_from(args)?;
173 Ok(DType::Bool(
174 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
175 ))
176 }
177
178 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
179 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
180 }
181
182 fn is_elementwise(&self) -> bool {
183 true
184 }
185}
186
/// Encoding-specific implementation of `list_contains`.
///
/// Kernels are wrapped in `ListContainsKernelAdapter`. The adapter only invokes the kernel when
/// the *search value* array (not the list array) has this vtable's encoding.
pub trait ListContainsKernel: VTable {
    /// Compute whether each list in `list` contains the corresponding entry of `element`.
    ///
    /// Return `Ok(None)` to decline. The next registered kernel, or the generic fallback in
    /// `ListContains::invoke`, is then used instead.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
194
/// Type-erased handle to a list-contains kernel, collected through `inventory`.
///
/// Every collected instance is registered on the `list_contains` compute function when that
/// function is first initialised. Instances are presumably submitted elsewhere via
/// `inventory::submit!`.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);

/// Adapts a `ListContainsKernel` implementation for encoding `V` to the generic `Kernel` trait.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wrap a `'static` adapter into a `ListContainsKernelRef` suitable for registration.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
206
207impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
208 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
209 let BinaryArgs {
210 lhs: array,
211 rhs: value,
212 ..
213 } = BinaryArgs::<()>::try_from(args)?;
214 let Some(value) = value.as_opt::<V>() else {
215 return Ok(None);
216 };
217 self.0
218 .list_contains(array, value)
219 .map(|c| c.map(Output::Array))
220 }
221}
222
223fn constant_list_scalar_contains(
225 list_scalar: &ListScalar,
226 values: &dyn Array,
227 nullability: Nullability,
228) -> VortexResult<ArrayRef> {
229 let elements = list_scalar.elements().vortex_expect("non null");
230
231 let len = values.len();
232 let mut result: Option<ArrayRef> = None;
233 let false_scalar = Scalar::bool(false, nullability);
234
235 for element in elements {
236 let res = compute::compare(
237 ConstantArray::new(element, len).as_ref(),
238 values,
239 Operator::Eq,
240 )?
241 .fill_null(false_scalar.clone())?;
242 if let Some(acc) = result {
243 result = Some(compute::or_kleene(&acc, &res)?)
244 } else {
245 result = Some(res);
246 }
247 }
248 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
249}
250
251fn list_contains_scalar(
253 array: &dyn Array,
254 value: &Scalar,
255 nullability: Nullability,
256) -> VortexResult<ArrayRef> {
257 if array.len() > 1 && array.is::<ConstantVTable>() {
259 let contains = list_contains_scalar(&array.slice(0..1)?, value, nullability)?;
260 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
261 }
262
263 let list_array = array.to_listview();
264
265 let elems = list_array.elements();
266 if elems.is_empty() {
267 return list_false_or_null(&list_array, nullability);
269 }
270
271 let rhs = ConstantArray::new(value.clone(), elems.len());
272 let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?;
273 let matches = matching_elements.to_bool();
274
275 if let Some(pred) = matches.as_constant() {
277 return match pred.as_bool().value() {
278 None => {
281 assert!(
282 !rhs.scalar().is_null(),
283 "Search value must not be null here"
284 );
285 list_false_or_null(&list_array, nullability)
287 }
288 Some(false) => {
290 Ok(
292 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
293 .into_array(),
294 )
295 }
296 Some(true) => {
298 list_is_not_empty(&list_array, nullability)
300 }
301 };
302 }
303
304 let offsets = list_array.offsets().to_primitive();
306 let sizes = list_array.sizes().to_primitive();
307
308 let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
310 match_each_integer_ptype!(sizes.ptype(), |S| {
311 process_matches::<O, S>(matches, list_array.len(), offsets, sizes)
312 })
313 });
314
315 Ok(BoolArray::new(
316 list_matches,
317 list_array.validity().clone().union_nullability(nullability),
318 )
319 .into_array())
320}
321
322fn process_matches<O, S>(
325 matches: BoolArray,
326 list_array_len: usize,
327 offsets: PrimitiveArray,
328 sizes: PrimitiveArray,
329) -> BitBuffer
330where
331 O: IntegerPType,
332 S: IntegerPType,
333{
334 let offsets_slice = offsets.as_slice::<O>();
335 let sizes_slice = sizes.as_slice::<S>();
336
337 (0..list_array_len)
338 .map(|i| {
339 let offset = offsets_slice[i].as_();
340 let size = sizes_slice[i].as_();
341
342 let bits = matches.to_bit_buffer();
345 let mut set_bits = BitIndexIterator::new(bits.inner().as_ref(), offset, size);
346 set_bits.next().is_some()
347 })
348 .collect::<BitBuffer>()
349}
350
351fn list_false_or_null(
354 list_array: &ListViewArray,
355 nullability: Nullability,
356) -> VortexResult<ArrayRef> {
357 match list_array.validity() {
358 Validity::NonNullable => {
359 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
361 }
362 Validity::AllValid => {
363 Ok(
365 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
366 .into_array(),
367 )
368 }
369 Validity::AllInvalid => {
370 Ok(ConstantArray::new(
372 Scalar::null(DType::Bool(Nullability::Nullable)),
373 list_array.len(),
374 )
375 .into_array())
376 }
377 Validity::Array(validity_array) => {
378 let buffer = BitBuffer::new_unset(list_array.len());
380 Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
381 }
382 }
383}
384
385fn list_is_not_empty(
388 list_array: &ListViewArray,
389 nullability: Nullability,
390) -> VortexResult<ArrayRef> {
391 if matches!(list_array.validity(), Validity::AllInvalid) {
393 return Ok(ConstantArray::new(
394 Scalar::null(DType::Bool(Nullability::Nullable)),
395 list_array.len(),
396 )
397 .into_array());
398 }
399
400 let sizes = list_array.sizes().to_primitive();
401 let buffer = match_each_integer_ptype!(sizes.ptype(), |S| {
402 BitBuffer::from_iter(sizes.as_slice::<S>().iter().map(|&size| size != S::zero()))
403 });
404
405 Ok(BoolArray::new(
407 buffer,
408 list_array.validity().clone().union_nullability(nullability),
409 )
410 .into_array())
411}
412
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ConstantVTable;
    use crate::arrays::ListArray;
    use crate::arrays::ListVTable;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::assert_arrays_eq;
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Build a list array of non-nullable UTF-8 strings, converted to list-view form.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
            .as_::<ListVTable>()
            .to_listview()
            .into_array()
    }

    /// Build a non-nullable list array of nullable UTF-8 strings, converted to list-view form.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();

        // Offsets are the running sums of the per-list lengths, with a leading 0.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .to_listview()
            .into_array()
    }

    /// Build a `BoolArray` from plain values and an explicit validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Non-nullable elements: the value is found in both non-empty lists.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements mixed with nulls: a match anywhere in the list wins.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements: the last list holds only non-matching and null elements.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the search value (constant-`true` comparison).
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // All lists are empty, so there are no elements at all.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the search value (constant-`false` comparison).
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // A null search value produces an all-null result.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Only null elements and a non-null search value: null elements never match.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // The search scalar takes the list's element nullability (a null search is nullable).
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_contains_nullable22() {
        // Stand-alone version of the third `test_contains_nullable` case.
        let list_array = null_strings(vec![
            vec![],
            vec![Some("a"), None],
            vec![Some("b"), None, None],
        ]);
        let value = Some("a");
        let expected = bool_array(vec![false, true, false], Validity::AllValid);
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_constant_list() {
        // A constant list array searched with a constant value yields a constant result.
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        let expected = BoolArray::from_iter([true, true]);
        assert_arrays_eq!(contains, expected);
    }

    #[test]
    fn test_all_nulls() {
        // Every list is null, so the result is a constant all-null array.
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        let expected = BoolArray::new(
            [false, false, false, false, false].into_iter().collect(),
            Validity::AllInvalid,
        );
        assert_arrays_eq!(contains, expected);
    }

    #[test]
    fn test_list_array_element() {
        // A constant list searched with a non-constant array of values.
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        let expected = BoolArray::from_iter([false, true, false, true, false, false, true]);
        assert_arrays_eq!(contains, expected);
    }

    #[test]
    fn test_list_contains_empty_listview() {
        // Four empty lists over an empty elements array.
        let empty_elements = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        let offsets = Buffer::from_iter([0u32, 0, 0, 0]).into_array();
        let sizes = Buffer::from_iter([0u32, 0, 0, 0]).into_array();

        // SAFETY: every view has size 0, so no view reads past the (empty) elements array.
        // NOTE(review): assumes zero-length views satisfy the zero-copy-to-list flag; confirm
        // against `with_zero_copy_to_list`.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                empty_elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        let search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        let expected = BoolArray::from_iter([false, false, false, false]);
        assert_arrays_eq!(result, expected);
    }

    #[test]
    fn test_list_contains_all_null_elements() {
        // Three lists (sizes 2, 2, 1) laid out back to back over five null elements.
        let elements = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None, None]);
        let offsets = Buffer::from_iter([0u32, 2, 4]).into_array();
        let sizes = Buffer::from_iter([2u32, 2, 1]).into_array();

        // SAFETY: the views [0, 2), [2, 4) and [4, 5) all lie within the 5 elements and are
        // contiguous and non-overlapping. NOTE(review): assumes this is what
        // `with_zero_copy_to_list` requires; confirm.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        // A null search value gives an all-null result.
        let null_search = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            list_array.len(),
        );
        let result = list_contains(list_array.as_ref(), null_search.as_ref()).unwrap();

        let expected = BoolArray::new(
            [false, false, false].into_iter().collect(),
            Validity::AllInvalid,
        );
        assert_arrays_eq!(result, expected);

        // A non-null search value never matches a null element.
        let non_null_search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result2 = list_contains(list_array.as_ref(), non_null_search.as_ref()).unwrap();

        let expected2 = BoolArray::from_iter([false, false, false]);
        assert_arrays_eq!(result2, expected2);
    }

    #[test]
    fn test_list_contains_large_offsets() {
        let elements = Buffer::from_iter([1i32, 2, 3, 4, 5]).into_array();

        // Views: [1], [2, 3], [5], and an empty view whose offset points back at the start.
        // Element 4 (index 3) is not covered by any list.
        let offsets = Buffer::from_iter([0u32, 1, 4, 0]).into_array();
        let sizes = Buffer::from_iter([1u32, 2, 1, 0]).into_array();

        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        let search = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        let expected = BoolArray::from_iter([false, true, false, false]);
        assert_arrays_eq!(result, expected);

        let search5 = ConstantArray::new(Scalar::from(5i32), list_array.len());
        let result5 = list_contains(list_array.as_ref(), search5.as_ref()).unwrap();

        let expected5 = BoolArray::from_iter([false, false, true, false]);
        assert_arrays_eq!(result5, expected5);
    }

    #[test]
    fn test_list_contains_offset_size_boundary() {
        let elements = Buffer::from_iter(0..256).into_array();

        // u8 offsets and sizes. The last view ends at 254 + 2 = 256, which exceeds `u8::MAX`,
        // so the end bound must be computed in a wider type.
        let offsets = Buffer::from_iter([0u8, 100, 200, 254]).into_array();
        let sizes = Buffer::from_iter([50u8, 50, 54, 2]).into_array();
        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        // 255 is the very last element, covered only by the final view [254, 256).
        let search = ConstantArray::new(Scalar::from(255i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        let expected = BoolArray::from_iter([false, false, false, true]);
        assert_arrays_eq!(result, expected);

        // 0 is the very first element, covered only by the first view [0, 50).
        let search_zero = ConstantArray::new(Scalar::from(0i32), list_array.len());
        let result_zero = list_contains(list_array.as_ref(), search_zero.as_ref()).unwrap();

        let expected_zero = BoolArray::from_iter([true, false, false, false]);
        assert_arrays_eq!(result_zero, expected_zero);
    }
}