1use std::sync::LazyLock;
10
11use arcref::ArcRef;
12use arrow_buffer::bit_iterator::BitIndexIterator;
13use num_traits::Zero;
14use vortex_buffer::BitBuffer;
15use vortex_dtype::DType;
16use vortex_dtype::IntegerPType;
17use vortex_dtype::Nullability;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexExpect;
20use vortex_error::VortexResult;
21use vortex_error::vortex_bail;
22use vortex_scalar::ListScalar;
23use vortex_scalar::Scalar;
24
25use crate::Array;
26use crate::ArrayRef;
27use crate::IntoArray;
28use crate::ToCanonical;
29use crate::arrays::BoolArray;
30use crate::arrays::ConstantArray;
31use crate::arrays::ListViewArray;
32use crate::arrays::PrimitiveArray;
33use crate::compute::BinaryArgs;
34use crate::compute::ComputeFn;
35use crate::compute::ComputeFnVTable;
36use crate::compute::InvocationArgs;
37use crate::compute::Kernel;
38use crate::compute::Operator;
39use crate::compute::Output;
40use crate::compute::{self};
41use crate::validity::Validity;
42use crate::vtable::VTable;
43use crate::vtable::ValidityHelper;
44
45static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
46 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
47 for kernel in inventory::iter::<ListContainsKernelRef> {
48 compute.register_kernel(kernel.0.clone());
49 }
50 compute
51});
52
53pub(crate) fn warm_up_vtable() -> usize {
54 LIST_CONTAINS_FN.kernels().len()
55}
56
57pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
105 LIST_CONTAINS_FN
106 .invoke(&InvocationArgs {
107 inputs: &[array.into(), value.into()],
108 options: &(),
109 })?
110 .unwrap_array()
111}
112
/// Unit type carrying the [`ComputeFnVTable`] implementation of the `list_contains` compute
/// function; a reference to it is installed as the vtable of [`LIST_CONTAINS_FN`].
pub struct ListContains;
114
115impl ComputeFnVTable for ListContains {
116 fn invoke(
117 &self,
118 args: &InvocationArgs,
119 kernels: &[ArcRef<dyn Kernel>],
120 ) -> VortexResult<Output> {
121 let BinaryArgs {
122 lhs: array,
123 rhs: value,
124 ..
125 } = BinaryArgs::<()>::try_from(args)?;
126
127 let DType::List(elem_dtype, _) = array.dtype() else {
128 vortex_bail!("Array must be of List type");
129 };
130 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
131 vortex_bail!(
132 "Element type {} of `ListViewArray` does not match search value {}",
133 elem_dtype,
134 value.dtype(),
135 );
136 };
137
138 if value.all_invalid() || array.all_invalid() {
139 return Ok(Output::Array(
140 ConstantArray::new(
141 Scalar::null(DType::Bool(Nullability::Nullable)),
142 array.len(),
143 )
144 .to_array(),
145 ));
146 }
147
148 for kernel in kernels {
149 if let Some(output) = kernel.invoke(args)? {
150 return Ok(output);
151 }
152 }
153 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
154 return Ok(output);
155 }
156
157 let nullability = array.dtype().nullability() | value.dtype().nullability();
158
159 let result = if let Some(value_scalar) = value.as_constant() {
160 list_contains_scalar(array, &value_scalar, nullability)
161 } else if let Some(list_scalar) = array.as_constant() {
162 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
163 } else {
164 todo!("unsupported list contains with list and element as arrays")
165 };
166
167 result.map(Output::Array)
168 }
169
170 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
171 let input = BinaryArgs::<()>::try_from(args)?;
172 Ok(DType::Bool(
173 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
174 ))
175 }
176
177 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
178 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
179 }
180
181 fn is_elementwise(&self) -> bool {
182 true
183 }
184}
185
/// Encoding-specific implementation of `list_contains`, keyed on the encoding of the *search
/// value* array.
///
/// Implementations are exposed to the dispatcher by wrapping the vtable in a
/// [`ListContainsKernelAdapter`].
pub trait ListContainsKernel: VTable {
    /// Computes `list_contains` for `list` against `element`, the search values in this
    /// encoding's array type.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the call, which lets the
    /// dispatcher move on to the next kernel or to its fallback.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
193
/// Handle to a type-erased `list_contains` kernel, in the form collected by `inventory`.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
// Every submitted `ListContainsKernelRef` is picked up when `LIST_CONTAINS_FN` is initialized.
inventory::collect!(ListContainsKernelRef);
196
/// Adapter that exposes a vtable implementing [`ListContainsKernel`] as a generic [`Kernel`].
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);
199
200impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
201 pub const fn lift(&'static self) -> ListContainsKernelRef {
202 ListContainsKernelRef(ArcRef::new_ref(self))
203 }
204}
205
206impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
207 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
208 let BinaryArgs {
209 lhs: array,
210 rhs: value,
211 ..
212 } = BinaryArgs::<()>::try_from(args)?;
213 let Some(value) = value.as_opt::<V>() else {
214 return Ok(None);
215 };
216 self.0
217 .list_contains(array, value)
218 .map(|c| c.map(Output::Array))
219 }
220}
221
222fn constant_list_scalar_contains(
224 list_scalar: &ListScalar,
225 values: &dyn Array,
226 nullability: Nullability,
227) -> VortexResult<ArrayRef> {
228 let elements = list_scalar.elements().vortex_expect("non null");
229
230 let len = values.len();
231 let mut result: Option<ArrayRef> = None;
232 let false_scalar = Scalar::bool(false, nullability);
233 for element in elements {
234 let res = compute::fill_null(
235 &compute::compare(
236 ConstantArray::new(element, len).as_ref(),
237 values,
238 Operator::Eq,
239 )?,
240 &false_scalar,
241 )?;
242 if let Some(acc) = result {
243 result = Some(compute::or(&acc, &res)?)
244 } else {
245 result = Some(res);
246 }
247 }
248 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
249}
250
251fn list_contains_scalar(
253 array: &dyn Array,
254 value: &Scalar,
255 nullability: Nullability,
256) -> VortexResult<ArrayRef> {
257 if array.len() > 1 && array.is_constant() {
259 let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
260 return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
261 }
262
263 let list_array = array.to_listview();
264
265 let elems = list_array.elements();
266 if elems.is_empty() {
267 return list_false_or_null(&list_array, nullability);
269 }
270
271 let rhs = ConstantArray::new(value.clone(), elems.len());
272 let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?;
273 let matches = matching_elements.to_bool();
274
275 if let Some(pred) = matches.as_constant() {
277 return match pred.as_bool().value() {
278 None => {
281 assert!(
282 !rhs.scalar().is_null(),
283 "Search value must not be null here"
284 );
285 list_false_or_null(&list_array, nullability)
287 }
288 Some(false) => {
290 Ok(
292 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
293 .into_array(),
294 )
295 }
296 Some(true) => {
298 list_is_not_empty(&list_array, nullability)
300 }
301 };
302 }
303
304 let offsets = list_array.offsets().to_primitive();
306 let sizes = list_array.sizes().to_primitive();
307
308 let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
310 match_each_integer_ptype!(sizes.ptype(), |S| {
311 process_matches::<O, S>(matches, list_array.len(), offsets, sizes)
312 })
313 });
314
315 Ok(BoolArray::from_bit_buffer(
316 list_matches,
317 list_array.validity().clone().union_nullability(nullability),
318 )
319 .into_array())
320}
321
322fn process_matches<O, S>(
325 matches: BoolArray,
326 list_array_len: usize,
327 offsets: PrimitiveArray,
328 sizes: PrimitiveArray,
329) -> BitBuffer
330where
331 O: IntegerPType,
332 S: IntegerPType,
333{
334 let offsets_slice = offsets.as_slice::<O>();
335 let sizes_slice = sizes.as_slice::<S>();
336
337 (0..list_array_len)
338 .map(|i| {
339 let offset = offsets_slice[i].as_();
340 let size = sizes_slice[i].as_();
341
342 let mut set_bits =
345 BitIndexIterator::new(matches.bit_buffer().inner().as_ref(), offset, size);
346 set_bits.next().is_some()
347 })
348 .collect::<BitBuffer>()
349}
350
351fn list_false_or_null(
354 list_array: &ListViewArray,
355 nullability: Nullability,
356) -> VortexResult<ArrayRef> {
357 match list_array.validity() {
358 Validity::NonNullable => {
359 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
361 }
362 Validity::AllValid => {
363 Ok(
365 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
366 .into_array(),
367 )
368 }
369 Validity::AllInvalid => {
370 Ok(ConstantArray::new(
372 Scalar::null(DType::Bool(Nullability::Nullable)),
373 list_array.len(),
374 )
375 .into_array())
376 }
377 Validity::Array(validity_array) => {
378 let buffer = BitBuffer::new_unset(list_array.len());
380 Ok(
381 BoolArray::from_bit_buffer(buffer, Validity::Array(validity_array.clone()))
382 .into_array(),
383 )
384 }
385 }
386}
387
388fn list_is_not_empty(
391 list_array: &ListViewArray,
392 nullability: Nullability,
393) -> VortexResult<ArrayRef> {
394 if matches!(list_array.validity(), Validity::AllInvalid) {
396 return Ok(ConstantArray::new(
397 Scalar::null(DType::Bool(Nullability::Nullable)),
398 list_array.len(),
399 )
400 .into_array());
401 }
402
403 let sizes = list_array.sizes().to_primitive();
404 let buffer = match_each_integer_ptype!(sizes.ptype(), |S| {
405 BitBuffer::from_iter(sizes.as_slice::<S>().iter().map(|&size| size != S::zero()))
406 });
407
408 Ok(BoolArray::from_bit_buffer(
410 buffer,
411 list_array.validity().clone().union_nullability(nullability),
412 )
413 .into_array())
414}
415
// Unit tests for `list_contains`, covering the scalar-search, constant-list and shortcut paths.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_buffer::bitbuffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::ConstantVTable;
    use crate::arrays::ListArray;
    use crate::arrays::ListVTable;
    use crate::arrays::ListViewArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::list_view_from_list;
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    // Builds a list-view array of non-nullable UTF-8 strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        list_view_from_list(
            ListArray::from_iter_slow::<u64, _>(
                values,
                Arc::new(DType::Utf8(Nullability::NonNullable)),
            )
            .unwrap()
            .as_::<ListVTable>()
            .clone(),
        )
        .into_array()
    }

    // Builds a non-nullable list-view array whose UTF-8 elements are nullable.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();

        // Running total of the list lengths, with a leading 0, gives the list offsets.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        list_view_from_list(ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap())
            .into_array()
    }

    // Builds the expected `BoolArray` from plain values and a validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::from_bit_buffer(values.into_iter().collect(), validity)
    }

    // Table-driven check of a constant search value against string lists.
    #[rstest]
    // Non-nullable elements, some lists contain the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements: null elements next to a match do not hide the match.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements: a list with nulls but no match is expected to be `false`.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the search value; only the empty list is `false`.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // No elements at all.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the search value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Null search value: the whole result is expected to be invalid.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // All elements are null, search value is non-null.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        // A non-null search value takes the nullability of the list elements.
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool();
        assert_eq!(bool_result.opt_bool_vec(), expected.opt_bool_vec());
        assert_eq!(bool_result.validity(), expected.validity());
    }

    // A constant list searched with a constant value must produce a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(contains.to_bool().bit_buffer(), &bitbuffer![true, true],);
    }

    // A constant null list yields a constant, entirely invalid result of the same length.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    // A constant list [1, 3, 6] searched with the non-constant values 0..7.
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().opt_bool_vec(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }

    // Four empty lists over an empty elements array: nothing can match.
    #[test]
    fn test_list_contains_empty_listview() {
        let empty_elements = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        let offsets = Buffer::from_iter([0u32, 0, 0, 0]).into_array();
        let sizes = Buffer::from_iter([0u32, 0, 0, 0]).into_array();

        // SAFETY: every offset and size is zero, so no view reaches outside the empty elements.
        // NOTE(review): presumably `with_zero_copy_to_list(true)` is valid for this layout; confirm.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                empty_elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        let search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec(),
            vec![false, false, false, false]
        );
    }

    // All elements are null: a null search gives an invalid result, a non-null one gives false.
    #[test]
    fn test_list_contains_all_null_elements() {
        let elements = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None, None]);
        let offsets = Buffer::from_iter([0u32, 2, 4]).into_array();
        let sizes = Buffer::from_iter([2u32, 2, 1]).into_array();

        // SAFETY: the views cover element ranges [0, 2), [2, 4) and [4, 5), all within the five
        // elements.
        // NOTE(review): presumably `with_zero_copy_to_list(true)` is valid for this layout; confirm.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        let null_search = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            list_array.len(),
        );
        let result = list_contains(list_array.as_ref(), null_search.as_ref()).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.to_bool().validity(), &Validity::AllInvalid);

        let non_null_search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result2 = list_contains(list_array.as_ref(), non_null_search.as_ref()).unwrap();

        assert_eq!(result2.len(), 3);
        assert_eq!(result2.to_bool().bool_vec(), vec![false, false, false]);
    }

    // Views that are not laid out back to back: the lists are [1], [2, 3], [5] and [].
    #[test]
    fn test_list_contains_large_offsets() {
        let elements = Buffer::from_iter([1i32, 2, 3, 4, 5]).into_array();

        let offsets = Buffer::from_iter([0u32, 1, 4, 0]).into_array();
        let sizes = Buffer::from_iter([1u32, 2, 1, 0]).into_array();

        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        let search = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec(),
            vec![false, true, false, false]
        );

        let search5 = ConstantArray::new(Scalar::from(5i32), list_array.len());
        let result5 = list_contains(list_array.as_ref(), search5.as_ref()).unwrap();

        assert_eq!(
            result5.to_bool().bool_vec(),
            vec![false, false, true, false]
        );
    }

    // Offsets and sizes are stored as `u8`; the last view ends at 254 + 2 = 256, which does not
    // fit in a `u8`, so the offset and size must be widened before being combined.
    #[test]
    fn test_list_contains_offset_size_boundary() {
        let elements = Buffer::from_iter(0..256).into_array();
        let offsets = Buffer::from_iter([0u8, 100, 200, 254]).into_array();
        let sizes = Buffer::from_iter([50u8, 50, 54, 2]).into_array();
        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        let search = ConstantArray::new(Scalar::from(255i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.to_bool().bool_vec(), vec![false, false, false, true]);

        let search_zero = ConstantArray::new(Scalar::from(0i32), list_array.len());
        let result_zero = list_contains(list_array.as_ref(), search_zero.as_ref()).unwrap();

        assert_eq!(
            result_zero.to_bool().bool_vec(),
            vec![true, false, false, false]
        );
    }
}