1use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19 BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20 fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
28 for kernel in inventory::iter::<ListContainsKernelRef> {
29 compute.register_kernel(kernel.0.clone());
30 }
31 compute
32});
33
34pub(crate) fn warm_up_vtable() -> usize {
35 LIST_CONTAINS_FN.kernels().len()
36}
37
38pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
71 LIST_CONTAINS_FN
72 .invoke(&InvocationArgs {
73 inputs: &[array.into(), value.into()],
74 options: &(),
75 })?
76 .unwrap_array()
77}
78
79pub struct ListContains;
80
81impl ComputeFnVTable for ListContains {
82 fn invoke(
83 &self,
84 args: &InvocationArgs,
85 kernels: &[ArcRef<dyn Kernel>],
86 ) -> VortexResult<Output> {
87 let BinaryArgs {
88 lhs: array,
89 rhs: value,
90 ..
91 } = BinaryArgs::<()>::try_from(args)?;
92
93 let DType::List(elem_dtype, _) = array.dtype() else {
94 vortex_bail!("Array must be of List type");
95 };
96 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
97 vortex_bail!(
98 "Element type {} of ListArray does not match search value {}",
99 elem_dtype,
100 value.dtype(),
101 );
102 };
103
104 if value.all_invalid() || array.all_invalid() {
105 return Ok(Output::Array(
106 ConstantArray::new(
107 Scalar::null(DType::Bool(Nullability::Nullable)),
108 array.len(),
109 )
110 .to_array(),
111 ));
112 }
113
114 for kernel in kernels {
115 if let Some(output) = kernel.invoke(args)? {
116 return Ok(output);
117 }
118 }
119 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
120 return Ok(output);
121 }
122
123 let nullability = array.dtype().nullability() | value.dtype().nullability();
124
125 let result = if let Some(value_scalar) = value.as_constant() {
126 list_contains_scalar(array, &value_scalar, nullability)
127 } else if let Some(list_scalar) = array.as_constant() {
128 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
129 } else {
130 todo!("unsupported list contains with list and element as arrays")
131 };
132
133 result.map(Output::Array)
134 }
135
136 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
137 let input = BinaryArgs::<()>::try_from(args)?;
138 Ok(DType::Bool(
139 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
140 ))
141 }
142
143 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
144 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
145 }
146
147 fn is_elementwise(&self) -> bool {
148 true
149 }
150}
151
/// Encoding-specific implementation of `list_contains`, selected by the encoding of the
/// search-value (right-hand) array.
pub trait ListContainsKernel: VTable {
    /// Computes whether each list in `list` contains the matching entry of `element`.
    ///
    /// Returning `Ok(None)` declines the request, in which case the caller moves on to the
    /// next kernel or to the default implementation.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
159
/// Type-erased handle to a `list_contains` kernel.
///
/// Instances are gathered through `inventory` and registered on the compute function when it
/// is first initialized.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
162
/// Adapter that exposes a [`ListContainsKernel`] implementation as a generic [`Kernel`].
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ListContainsKernelRef`] so it can be collected
    /// through `inventory`.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
171
172impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
173 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
174 let BinaryArgs {
175 lhs: array,
176 rhs: value,
177 ..
178 } = BinaryArgs::<()>::try_from(args)?;
179 let Some(value) = value.as_opt::<V>() else {
180 return Ok(None);
181 };
182 self.0
183 .list_contains(array, value)
184 .map(|c| c.map(Output::Array))
185 }
186}
187
188fn constant_list_scalar_contains(
190 list_scalar: &ListScalar,
191 values: &dyn Array,
192 nullability: Nullability,
193) -> VortexResult<ArrayRef> {
194 let elements = list_scalar.elements().vortex_expect("non null");
195
196 let len = values.len();
197 let mut result: Option<ArrayRef> = None;
198 let false_scalar = Scalar::bool(false, nullability);
199 for element in elements {
200 let res = fill_null(
201 &compare(
202 ConstantArray::new(element, len).as_ref(),
203 values,
204 Operator::Eq,
205 )?,
206 &false_scalar,
207 )?;
208 if let Some(acc) = result {
209 result = Some(or(&acc, &res)?)
210 } else {
211 result = Some(res);
212 }
213 }
214 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
215}
216
217fn list_contains_scalar(
218 array: &dyn Array,
219 value: &Scalar,
220 nullability: Nullability,
221) -> VortexResult<ArrayRef> {
222 if array.len() > 1 && array.is_constant() {
224 let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
225 return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
226 }
227
228 let list_array = array.to_list();
231
232 let elems = list_array.elements();
233 if elems.is_empty() {
234 return list_false_or_null(&list_array, nullability);
236 }
237
238 let rhs = ConstantArray::new(value.clone(), elems.len());
239 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
240 let matches = matching_elements.to_bool();
241
242 if let Some(pred) = matches.as_constant() {
244 return match pred.as_bool().value() {
245 None => {
248 assert!(
249 !rhs.scalar().is_null(),
250 "Search value must not be null here"
251 );
252 list_false_or_null(&list_array, nullability)
254 }
255 Some(false) => {
257 Ok(
259 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
260 .into_array(),
261 )
262 }
263 Some(true) => {
265 list_is_not_empty(&list_array, nullability)
267 }
268 };
269 }
270
271 let ends = list_array.offsets().to_primitive();
272 match_each_integer_ptype!(ends.ptype(), |T| {
273 Ok(reduce_with_ends(
274 ends.as_slice::<T>(),
275 matches.boolean_buffer(),
276 list_array.validity().clone().union_nullability(nullability),
277 ))
278 })
279}
280
281fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
284 match list_array.validity() {
285 Validity::NonNullable => {
286 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
288 }
289 Validity::AllValid => {
290 Ok(
292 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
293 .into_array(),
294 )
295 }
296 Validity::AllInvalid => {
297 Ok(ConstantArray::new(
299 Scalar::null(DType::Bool(Nullability::Nullable)),
300 list_array.len(),
301 )
302 .into_array())
303 }
304 Validity::Array(validity_array) => {
305 let buffer = BooleanBuffer::new_unset(list_array.len());
307 Ok(
308 BoolArray::from_bool_buffer(buffer, Validity::Array(validity_array.clone()))
309 .into_array(),
310 )
311 }
312 }
313}
314
315fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
318 if matches!(list_array.validity(), Validity::AllInvalid) {
320 return Ok(ConstantArray::new(
321 Scalar::null(DType::Bool(Nullability::Nullable)),
322 list_array.len(),
323 )
324 .into_array());
325 }
326
327 let offsets = list_array.offsets().to_primitive();
328 let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
329 element_is_not_empty(offsets.as_slice::<T>())
330 });
331
332 Ok(BoolArray::from_bool_buffer(
334 buffer,
335 list_array.validity().clone().union_nullability(nullability),
336 )
337 .into_array())
338}
339
340fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
343 ends: &[T],
344 matches: &BooleanBuffer,
345 validity: Validity,
346) -> ArrayRef {
347 let mask: BooleanBuffer = ends
348 .windows(2)
349 .map(|window| {
350 let len = window[1].as_() - window[0].as_();
351 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
352 set_bits.next().is_some()
353 })
354 .collect();
355
356 BoolArray::from_bool_buffer(mask, validity).into_array()
357}
358
359pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
382 if !matches!(array.dtype(), DType::List(..)) {
383 vortex_bail!("Array must be of list type");
384 }
385
386 if array.is_constant() && array.len() > 1 {
388 let elem_lens = list_elem_len(&array.slice(0..1))?;
389 return Ok(ConstantArray::new(elem_lens.scalar_at(0), array.len()).into_array());
390 }
391
392 let list_array = array.to_list();
393 let offsets = list_array.offsets().to_primitive();
394 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
395 element_lens(offsets.as_slice::<T>()).into_array()
396 });
397
398 Ok(lens_array)
399}
400
401fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
402 values
403 .windows(2)
404 .map(|window| window[1] - window[0])
405 .collect()
406}
407
408fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
409 BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
410}
411
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable UTF-8 strings.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose UTF-8 elements are nullable.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets are a leading zero followed by the running total of list lengths.
        let offsets = std::iter::once(0u64)
            .chain(values.iter().scan(0u64, |end, list| {
                *end += list.len() as u64;
                Some(*end)
            }))
            .collect_vec();
        let offsets = Buffer::from_iter(offsets).into_array();

        let flattened = values.iter().flatten().cloned().collect_vec();
        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a boolean array from plain values and an explicit validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        let bits = values.into_iter().collect();
        BoolArray::from_bool_buffer(bits, validity)
    }

    #[rstest]
    // Non-nullable elements, some lists contain the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements, every non-empty list contains the value.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements, only one list contains the value.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Only empty lists.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Null search value: the result is expected to be entirely null.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Every element is null, but the search value is not.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        let needle = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );
        let needles = ConstantArray::new(needle, list_array.len());

        let actual = list_contains(&list_array, needles.as_ref())
            .expect("list_contains failed")
            .to_bool();

        assert_eq!(
            actual.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list_scalar, 2).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains.to_bool().boolean_buffer().iter().collect_vec(),
            vec![true, true],
        );
    }

    #[test]
    fn test_all_nulls() {
        let null_list = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let list_array = ConstantArray::new(null_list, 5).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(contains.to_bool().validity(), &Validity::AllInvalid);
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(list_scalar, 7);
        let needles = (0..7).collect::<PrimitiveArray>();

        let contains = list_contains(lists.as_ref(), needles.as_ref()).unwrap();

        assert_eq!(contains.len(), 7);
        // Only the values 1, 3 and 6 are members of the constant list.
        let expected = vec![
            Some(false),
            Some(true),
            Some(false),
            Some(true),
            Some(false),
            Some(false),
            Some(true),
        ];
        assert_eq!(contains.to_bool().opt_bool_vec().unwrap(), expected);
    }
}