1use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19 BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20 fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
27 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
28 for kernel in inventory::iter::<ListContainsKernelRef> {
29 compute.register_kernel(kernel.0.clone());
30 }
31 compute
32});
33
34pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
67 LIST_CONTAINS_FN
68 .invoke(&InvocationArgs {
69 inputs: &[array.into(), value.into()],
70 options: &(),
71 })?
72 .unwrap_array()
73}
74
/// Marker type carrying the [`ComputeFnVTable`] implementation of the `list_contains`
/// compute function.
pub struct ListContains;
76
77impl ComputeFnVTable for ListContains {
78 fn invoke(
79 &self,
80 args: &InvocationArgs,
81 kernels: &[ArcRef<dyn Kernel>],
82 ) -> VortexResult<Output> {
83 let BinaryArgs {
84 lhs: array,
85 rhs: value,
86 ..
87 } = BinaryArgs::<()>::try_from(args)?;
88
89 let DType::List(elem_dtype, _) = array.dtype() else {
90 vortex_bail!("Array must be of List type");
91 };
92 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
93 vortex_bail!(
94 "Element type {} of ListArray does not match search value {}",
95 elem_dtype,
96 value.dtype(),
97 );
98 };
99
100 if value.all_invalid() || array.all_invalid() {
101 return Ok(Output::Array(
102 ConstantArray::new(
103 Scalar::null(DType::Bool(Nullability::Nullable)),
104 array.len(),
105 )
106 .to_array(),
107 ));
108 }
109
110 for kernel in kernels {
111 if let Some(output) = kernel.invoke(args)? {
112 return Ok(output);
113 }
114 }
115 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
116 return Ok(output);
117 }
118
119 let nullability = array.dtype().nullability() | value.dtype().nullability();
120
121 let result = if let Some(value_scalar) = value.as_constant() {
122 list_contains_scalar(array, &value_scalar, nullability)
123 } else if let Some(list_scalar) = array.as_constant() {
124 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
125 } else {
126 todo!("unsupported list contains with list and element as arrays")
127 };
128
129 result.map(Output::Array)
130 }
131
132 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
133 let input = BinaryArgs::<()>::try_from(args)?;
134 Ok(DType::Bool(
135 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
136 ))
137 }
138
139 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
140 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
141 }
142
143 fn is_elementwise(&self) -> bool {
144 true
145 }
146}
147
/// Encoding-specific implementation of `list_contains`, selected by the encoding of the
/// *element* (search value) array.
pub trait ListContainsKernel: VTable {
    /// Computes whether each list in `list` contains the matching entry of `element`.
    ///
    /// Returning `Ok(None)` signals that this kernel does not handle the call, which lets the
    /// dispatcher move on to the next candidate implementation.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
155
/// Type-erased handle to a `list_contains` kernel.
///
/// Instances are collected through `inventory` and registered with the compute function the
/// first time it is used.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
158
/// Adapter that exposes a [`ListContainsKernel`] implementation through the generic [`Kernel`]
/// interface.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` adapter in a [`ListContainsKernelRef`] so it can be collected through
    /// `inventory`.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
167
168impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
169 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
170 let BinaryArgs {
171 lhs: array,
172 rhs: value,
173 ..
174 } = BinaryArgs::<()>::try_from(args)?;
175 let Some(value) = value.as_opt::<V>() else {
176 return Ok(None);
177 };
178 self.0
179 .list_contains(array, value)
180 .map(|c| c.map(Output::Array))
181 }
182}
183
184fn constant_list_scalar_contains(
186 list_scalar: &ListScalar,
187 values: &dyn Array,
188 nullability: Nullability,
189) -> VortexResult<ArrayRef> {
190 let elements = list_scalar.elements().vortex_expect("non null");
191
192 let len = values.len();
193 let mut result: Option<ArrayRef> = None;
194 let false_scalar = Scalar::bool(false, nullability);
195 for element in elements {
196 let res = fill_null(
197 &compare(
198 ConstantArray::new(element, len).as_ref(),
199 values,
200 Operator::Eq,
201 )?,
202 &false_scalar,
203 )?;
204 if let Some(acc) = result {
205 result = Some(or(&acc, &res)?)
206 } else {
207 result = Some(res);
208 }
209 }
210 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
211}
212
213fn list_contains_scalar(
214 array: &dyn Array,
215 value: &Scalar,
216 nullability: Nullability,
217) -> VortexResult<ArrayRef> {
218 if array.len() > 1 && array.is_constant() {
220 let contains = list_contains_scalar(&array.slice(0..1), value, nullability)?;
221 return Ok(ConstantArray::new(contains.scalar_at(0), array.len()).into_array());
222 }
223
224 let list_array = array.to_list()?;
227
228 let elems = list_array.elements();
229 if elems.is_empty() {
230 return list_false_or_null(&list_array, nullability);
232 }
233
234 let rhs = ConstantArray::new(value.clone(), elems.len());
235 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
236 let matches = matching_elements.to_bool()?;
237
238 if let Some(pred) = matches.as_constant() {
240 return match pred.as_bool().value() {
241 None => {
244 assert!(
245 !rhs.scalar().is_null(),
246 "Search value must not be null here"
247 );
248 list_false_or_null(&list_array, nullability)
250 }
251 Some(false) => {
253 Ok(
255 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
256 .into_array(),
257 )
258 }
259 Some(true) => {
261 list_is_not_empty(&list_array, nullability)
263 }
264 };
265 }
266
267 let ends = list_array.offsets().to_primitive()?;
268 match_each_integer_ptype!(ends.ptype(), |T| {
269 Ok(reduce_with_ends(
270 ends.as_slice::<T>(),
271 matches.boolean_buffer(),
272 list_array.validity().clone().union_nullability(nullability),
273 ))
274 })
275}
276
277fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
280 match list_array.validity() {
281 Validity::NonNullable => {
282 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
284 }
285 Validity::AllValid => {
286 Ok(
288 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
289 .into_array(),
290 )
291 }
292 Validity::AllInvalid => {
293 Ok(ConstantArray::new(
295 Scalar::null(DType::Bool(Nullability::Nullable)),
296 list_array.len(),
297 )
298 .into_array())
299 }
300 Validity::Array(validity_array) => {
301 let buffer = BooleanBuffer::new_unset(list_array.len());
303 Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
304 }
305 }
306}
307
308fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
311 if matches!(list_array.validity(), Validity::AllInvalid) {
313 return Ok(ConstantArray::new(
314 Scalar::null(DType::Bool(Nullability::Nullable)),
315 list_array.len(),
316 )
317 .into_array());
318 }
319
320 let offsets = list_array.offsets().to_primitive()?;
321 let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
322 element_is_not_empty(offsets.as_slice::<T>())
323 });
324
325 Ok(BoolArray::new(
327 buffer,
328 list_array.validity().clone().union_nullability(nullability),
329 )
330 .into_array())
331}
332
333fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
336 ends: &[T],
337 matches: &BooleanBuffer,
338 validity: Validity,
339) -> ArrayRef {
340 let mask: BooleanBuffer = ends
341 .windows(2)
342 .map(|window| {
343 let len = window[1].as_() - window[0].as_();
344 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
345 set_bits.next().is_some()
346 })
347 .collect();
348
349 BoolArray::new(mask, validity).into_array()
350}
351
352pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
375 if !matches!(array.dtype(), DType::List(..)) {
376 vortex_bail!("Array must be of list type");
377 }
378
379 if array.is_constant() && array.len() > 1 {
381 let elem_lens = list_elem_len(&array.slice(0..1))?;
382 return Ok(ConstantArray::new(elem_lens.scalar_at(0), array.len()).into_array());
383 }
384
385 let list_array = array.to_list()?;
386 let offsets = list_array.offsets().to_primitive()?;
387 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
388 element_lens(offsets.as_slice::<T>()).into_array()
389 });
390
391 Ok(lens_array)
392}
393
394fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
395 values
396 .windows(2)
397 .map(|window| window[1] - window[0])
398 .collect()
399}
400
401fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
402 BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
403}
404
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable UTF-8 strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
    }

    /// Builds a non-nullable list array whose elements are nullable UTF-8 strings.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();
        // Running total of list lengths gives the end offset of each list.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        // The offsets array starts with a leading zero.
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a `BoolArray` from plain booleans and the given validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Non-nullable elements, some lists contain the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements with nulls mixed in; the value is present in both non-empty lists.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements; the last list holds only a non-matching value and nulls.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the search value; only the empty list yields false.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // All lists are empty, so there are no elements to compare.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the search value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Null search value: the expected result is entirely invalid.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // All elements are null but the search value is not: nothing matches.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array
            .dtype()
            .as_list_element_opt()
            .unwrap()
            .nullability();
        // A missing value becomes a null scalar; otherwise the scalar mirrors the element
        // nullability of the list under test.
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");
        // Compare both the optional values and the validity representation.
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    /// A constant list searched for a constant value should produce a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }

    /// A constant null list should produce a constant, entirely invalid result.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    /// A constant list `[1, 3, 6]` searched with a non-constant array of values `0..7`.
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        // Only positions 1, 3 and 6 hold values that are present in the list.
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }
}