1use std::sync::LazyLock;
10
11use arcref::ArcRef;
12use arrow_buffer::bit_iterator::BitIndexIterator;
13use num_traits::Zero;
14use vortex_buffer::BitBuffer;
15use vortex_dtype::{DType, IntegerPType, Nullability, match_each_integer_ptype};
16use vortex_error::{VortexExpect, VortexResult, vortex_bail};
17use vortex_scalar::{ListScalar, Scalar};
18
19use crate::arrays::{BoolArray, ConstantArray, ListViewArray, PrimitiveArray};
20use crate::compute::{
21 self, BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output,
22};
23use crate::validity::Validity;
24use crate::vtable::{VTable, ValidityHelper};
25use crate::{Array, ArrayRef, IntoArray, ToCanonical};
26
27static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
28 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
29 for kernel in inventory::iter::<ListContainsKernelRef> {
30 compute.register_kernel(kernel.0.clone());
31 }
32 compute
33});
34
35pub(crate) fn warm_up_vtable() -> usize {
36 LIST_CONTAINS_FN.kernels().len()
37}
38
39pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
87 LIST_CONTAINS_FN
88 .invoke(&InvocationArgs {
89 inputs: &[array.into(), value.into()],
90 options: &(),
91 })?
92 .unwrap_array()
93}
94
95pub struct ListContains;
96
97impl ComputeFnVTable for ListContains {
98 fn invoke(
99 &self,
100 args: &InvocationArgs,
101 kernels: &[ArcRef<dyn Kernel>],
102 ) -> VortexResult<Output> {
103 let BinaryArgs {
104 lhs: array,
105 rhs: value,
106 ..
107 } = BinaryArgs::<()>::try_from(args)?;
108
109 let DType::List(elem_dtype, _) = array.dtype() else {
110 vortex_bail!("Array must be of List type");
111 };
112 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
113 vortex_bail!(
114 "Element type {} of `ListViewArray` does not match search value {}",
115 elem_dtype,
116 value.dtype(),
117 );
118 };
119
120 if value.all_invalid() || array.all_invalid() {
121 return Ok(Output::Array(
122 ConstantArray::new(
123 Scalar::null(DType::Bool(Nullability::Nullable)),
124 array.len(),
125 )
126 .to_array(),
127 ));
128 }
129
130 for kernel in kernels {
131 if let Some(output) = kernel.invoke(args)? {
132 return Ok(output);
133 }
134 }
135 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
136 return Ok(output);
137 }
138
139 let nullability = array.dtype().nullability() | value.dtype().nullability();
140
141 let result = if let Some(value_scalar) = value.as_constant() {
142 list_contains_scalar(array, &value_scalar, nullability)
143 } else if let Some(list_scalar) = array.as_constant() {
144 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
145 } else {
146 todo!("unsupported list contains with list and element as arrays")
147 };
148
149 result.map(Output::Array)
150 }
151
152 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
153 let input = BinaryArgs::<()>::try_from(args)?;
154 Ok(DType::Bool(
155 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
156 ))
157 }
158
159 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
160 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
161 }
162
163 fn is_elementwise(&self) -> bool {
164 true
165 }
166}
167
/// Encoding-specific implementation of `list_contains`.
///
/// The kernel is selected by the encoding of the *element* (right-hand) array: the adapter only
/// calls it when that array downcasts to `Self::Array`.
pub trait ListContainsKernel: VTable {
    /// Computes whether each list in `list` contains the corresponding value of `element`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the inputs, in which case
    /// dispatch continues with the next kernel or the fallback implementation.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
175
/// A type-erased `list_contains` kernel, as collected through `inventory`.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
// Makes `ListContainsKernelRef` a collectable type so that `inventory::iter` can enumerate every
// submitted kernel when the compute function is initialised.
inventory::collect!(ListContainsKernelRef);
178
/// Adapts a [`ListContainsKernel`] implementation to the generic [`Kernel`] interface.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ListContainsKernelRef`] without allocating, so that it
    /// can be built in a `const` context.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
187
188impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
189 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
190 let BinaryArgs {
191 lhs: array,
192 rhs: value,
193 ..
194 } = BinaryArgs::<()>::try_from(args)?;
195 let Some(value) = value.as_opt::<V>() else {
196 return Ok(None);
197 };
198 self.0
199 .list_contains(array, value)
200 .map(|c| c.map(Output::Array))
201 }
202}
203
204fn constant_list_scalar_contains(
206 list_scalar: &ListScalar,
207 values: &dyn Array,
208 nullability: Nullability,
209) -> VortexResult<ArrayRef> {
210 let elements = list_scalar.elements().vortex_expect("non null");
211
212 let len = values.len();
213 let mut result: Option<ArrayRef> = None;
214 let false_scalar = Scalar::bool(false, nullability);
215 for element in elements {
216 let res = compute::fill_null(
217 &compute::compare(
218 ConstantArray::new(element, len).as_ref(),
219 values,
220 Operator::Eq,
221 )?,
222 &false_scalar,
223 )?;
224 if let Some(acc) = result {
225 result = Some(compute::or(&acc, &res)?)
226 } else {
227 result = Some(res);
228 }
229 }
230 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
231}
232
233fn list_contains_scalar(
235 array: &dyn Array,
236 value: &Scalar,
237 nullability: Nullability,
238) -> VortexResult<ArrayRef> {
239 if array.len() > 1 && array.is_constant() {
241 let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
242 return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
243 }
244
245 let list_array = array.to_listview();
246
247 let elems = list_array.elements();
248 if elems.is_empty() {
249 return list_false_or_null(&list_array, nullability);
251 }
252
253 let rhs = ConstantArray::new(value.clone(), elems.len());
254 let matching_elements = compute::compare(elems, rhs.as_ref(), Operator::Eq)?;
255 let matches = matching_elements.to_bool();
256
257 if let Some(pred) = matches.as_constant() {
259 return match pred.as_bool().value() {
260 None => {
263 assert!(
264 !rhs.scalar().is_null(),
265 "Search value must not be null here"
266 );
267 list_false_or_null(&list_array, nullability)
269 }
270 Some(false) => {
272 Ok(
274 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
275 .into_array(),
276 )
277 }
278 Some(true) => {
280 list_is_not_empty(&list_array, nullability)
282 }
283 };
284 }
285
286 let offsets = list_array.offsets().to_primitive();
288 let sizes = list_array.sizes().to_primitive();
289
290 let list_matches = match_each_integer_ptype!(offsets.ptype(), |O| {
292 match_each_integer_ptype!(sizes.ptype(), |S| {
293 process_matches::<O, S>(matches, list_array.len(), offsets, sizes)
294 })
295 });
296
297 Ok(BoolArray::from_bit_buffer(
298 list_matches,
299 list_array.validity().clone().union_nullability(nullability),
300 )
301 .into_array())
302}
303
304fn process_matches<O, S>(
307 matches: BoolArray,
308 list_array_len: usize,
309 offsets: PrimitiveArray,
310 sizes: PrimitiveArray,
311) -> BitBuffer
312where
313 O: IntegerPType,
314 S: IntegerPType,
315{
316 let offsets_slice = offsets.as_slice::<O>();
317 let sizes_slice = sizes.as_slice::<S>();
318
319 (0..list_array_len)
320 .map(|i| {
321 let offset = offsets_slice[i].as_();
322 let size = sizes_slice[i].as_();
323
324 let mut set_bits =
327 BitIndexIterator::new(matches.bit_buffer().inner().as_ref(), offset, size);
328 set_bits.next().is_some()
329 })
330 .collect::<BitBuffer>()
331}
332
333fn list_false_or_null(
336 list_array: &ListViewArray,
337 nullability: Nullability,
338) -> VortexResult<ArrayRef> {
339 match list_array.validity() {
340 Validity::NonNullable => {
341 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
343 }
344 Validity::AllValid => {
345 Ok(
347 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
348 .into_array(),
349 )
350 }
351 Validity::AllInvalid => {
352 Ok(ConstantArray::new(
354 Scalar::null(DType::Bool(Nullability::Nullable)),
355 list_array.len(),
356 )
357 .into_array())
358 }
359 Validity::Array(validity_array) => {
360 let buffer = BitBuffer::new_unset(list_array.len());
362 Ok(
363 BoolArray::from_bit_buffer(buffer, Validity::Array(validity_array.clone()))
364 .into_array(),
365 )
366 }
367 }
368}
369
370fn list_is_not_empty(
373 list_array: &ListViewArray,
374 nullability: Nullability,
375) -> VortexResult<ArrayRef> {
376 if matches!(list_array.validity(), Validity::AllInvalid) {
378 return Ok(ConstantArray::new(
379 Scalar::null(DType::Bool(Nullability::Nullable)),
380 list_array.len(),
381 )
382 .into_array());
383 }
384
385 let sizes = list_array.sizes().to_primitive();
386 let buffer = match_each_integer_ptype!(sizes.ptype(), |S| {
387 BitBuffer::from_iter(sizes.as_slice::<S>().iter().map(|&size| size != S::zero()))
388 });
389
390 Ok(BoolArray::from_bit_buffer(
392 buffer,
393 list_array.validity().clone().union_nullability(nullability),
394 )
395 .into_array())
396}
397
// Unit tests for `list_contains`, covering the constant-value path, the constant-list path and
// several `ListViewArray` layouts.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::{Buffer, bitbuffer};
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, ListVTable, ListViewArray,
        PrimitiveArray, VarBinArray, list_view_from_list,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list-view array of non-nullable strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        list_view_from_list(
            ListArray::from_iter_slow::<u64, _>(
                values,
                Arc::new(DType::Utf8(Nullability::NonNullable)),
            )
            .unwrap()
            .as_::<ListVTable>()
            .clone(),
        )
        .into_array()
    }

    /// Builds a non-nullable list-view array whose string elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();

        // Running total of the list lengths gives the end offset of each list; a leading zero
        // is inserted below so that list `i` spans `offsets[i]..offsets[i + 1]`.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        list_view_from_list(ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap())
            .into_array()
    }

    /// Shorthand for building the expected boolean result.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::from_bit_buffer(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Non-nullable elements, some lists contain the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements: null elements do not prevent a match.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements: a list without a match is `false` even if it holds nulls.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the value: the result depends only on list emptiness.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // No elements at all.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Null search value: every result is null, so only the validity is meaningful.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // All elements null, non-null search value: nothing matches.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool();
        assert_eq!(bool_result.opt_bool_vec(), expected.opt_bool_vec());
        assert_eq!(bool_result.validity(), expected.validity());
    }

    /// A constant list searched for a constant value must produce a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(contains.to_bool().bit_buffer(), &bitbuffer![true, true],);
    }

    /// A constant null list yields a constant, all-null result.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    /// A constant list `[1, 3, 6]` searched with a varying value array `0..7`.
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        // Only the values 1, 3 and 6 are members of the list.
        assert_eq!(
            contains.to_bool().opt_bool_vec(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }

    /// Four empty lists over an empty elements array: nothing can match.
    #[test]
    fn test_list_contains_empty_listview() {
        let empty_elements = PrimitiveArray::empty::<i32>(Nullability::NonNullable);
        let offsets = Buffer::from_iter([0u32, 0, 0, 0]).into_array();
        let sizes = Buffer::from_iter([0u32, 0, 0, 0]).into_array();

        // NOTE(review): all views are empty, so no element is referenced out of bounds; the
        // zero-copy-to-list flag is asserted by the test author — confirm its preconditions.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                empty_elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        let search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec(),
            vec![false, false, false, false]
        );
    }

    /// Lists made exclusively of null elements, searched with a null and a non-null value.
    #[test]
    fn test_list_contains_all_null_elements() {
        let elements = PrimitiveArray::from_option_iter::<i32, _>([None, None, None, None, None]);
        let offsets = Buffer::from_iter([0u32, 2, 4]).into_array();
        let sizes = Buffer::from_iter([2u32, 2, 1]).into_array();

        // NOTE(review): the views cover elements 0..2, 2..4 and 4..5, all within the five
        // elements; the zero-copy-to-list flag is asserted by the test author — confirm its
        // preconditions.
        let list_array = unsafe {
            ListViewArray::new_unchecked(
                elements.into_array(),
                offsets,
                sizes,
                Validity::NonNullable,
            )
            .with_zero_copy_to_list(true)
        };

        // A null search value makes every result null.
        let null_search = ConstantArray::new(
            Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
            list_array.len(),
        );
        let result = list_contains(list_array.as_ref(), null_search.as_ref()).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.to_bool().validity(), &Validity::AllInvalid);

        // A non-null search value never matches a null element.
        let non_null_search = ConstantArray::new(Scalar::from(42i32), list_array.len());
        let result2 = list_contains(list_array.as_ref(), non_null_search.as_ref()).unwrap();

        assert_eq!(result2.len(), 3);
        assert_eq!(result2.to_bool().bool_vec(), vec![false, false, false]);
    }

    /// Views that are neither contiguous nor ordered, including a trailing empty list.
    #[test]
    fn test_list_contains_large_offsets() {
        let elements = Buffer::from_iter([1i32, 2, 3, 4, 5]).into_array();

        // With these offsets and sizes the lists are [1], [2, 3], [5] and [].
        let offsets = Buffer::from_iter([0u32, 1, 4, 0]).into_array();
        let sizes = Buffer::from_iter([1u32, 2, 1, 0]).into_array();

        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        let search = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(
            result.to_bool().bool_vec(),
            vec![false, true, false, false]
        );

        let search5 = ConstantArray::new(Scalar::from(5i32), list_array.len());
        let result5 = list_contains(list_array.as_ref(), search5.as_ref()).unwrap();

        assert_eq!(
            result5.to_bool().bool_vec(),
            vec![false, false, true, false]
        );
    }

    /// `u8` offsets and sizes whose sum reaches 256 for the last list.
    #[test]
    fn test_list_contains_offset_size_boundary() {
        let elements = Buffer::from_iter(0..256).into_array();
        // The last list spans 254..256; that end position does not fit in a `u8`, so this
        // presumably guards against computing `offset + size` in the narrow integer type.
        let offsets = Buffer::from_iter([0u8, 100, 200, 254]).into_array();
        let sizes = Buffer::from_iter([50u8, 50, 54, 2]).into_array();
        let list_array =
            ListViewArray::new(elements.into_array(), offsets, sizes, Validity::NonNullable);

        // 255 is the very last element and belongs to the last list only.
        let search = ConstantArray::new(Scalar::from(255i32), list_array.len());
        let result = list_contains(list_array.as_ref(), search.as_ref()).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result.to_bool().bool_vec(), vec![false, false, false, true]);

        // 0 is the very first element and belongs to the first list only.
        let search_zero = ConstantArray::new(Scalar::from(0i32), list_array.len());
        let result_zero = list_contains(list_array.as_ref(), search_zero.as_ref()).unwrap();

        assert_eq!(
            result_zero.to_bool().bool_vec(),
            vec![true, false, false, false]
        );
    }
}