1use std::sync::LazyLock;
4
5use arcref::ArcRef;
6use arrow_buffer::BooleanBuffer;
7use arrow_buffer::bit_iterator::BitIndexIterator;
8use num_traits::AsPrimitive;
9use vortex_buffer::Buffer;
10use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail};
12use vortex_scalar::{ListScalar, Scalar};
13
14use crate::arrays::{BoolArray, ConstantArray, ListArray};
15use crate::compute::{
16 BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
17 fill_null, or,
18};
19use crate::validity::Validity;
20use crate::vtable::{VTable, ValidityHelper};
21use crate::{Array, ArrayRef, IntoArray, ToCanonical};
22
23pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
56 LIST_CONTAINS_FN
57 .invoke(&InvocationArgs {
58 inputs: &[array.into(), value.into()],
59 options: &(),
60 })?
61 .unwrap_array()
62}
63
64pub struct ListContains;
65
66impl ComputeFnVTable for ListContains {
67 fn invoke(
68 &self,
69 args: &InvocationArgs,
70 kernels: &[ArcRef<dyn Kernel>],
71 ) -> VortexResult<Output> {
72 let BinaryArgs {
73 lhs: array,
74 rhs: value,
75 ..
76 } = BinaryArgs::<()>::try_from(args)?;
77
78 let DType::List(elem_dtype, _) = array.dtype() else {
79 vortex_bail!("Array must be of List type");
80 };
81 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
82 vortex_bail!(
83 "Element type {} of ListArray does not match search value {}",
84 elem_dtype,
85 value.dtype(),
86 );
87 };
88
89 if value.all_invalid()? || array.all_invalid()? {
90 return Ok(Output::Array(
91 ConstantArray::new(
92 Scalar::null(DType::Bool(Nullability::Nullable)),
93 array.len(),
94 )
95 .to_array(),
96 ));
97 }
98
99 for kernel in kernels {
100 if let Some(output) = kernel.invoke(args)? {
101 return Ok(output);
102 }
103 }
104 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
105 return Ok(output);
106 }
107
108 let nullability = array.dtype().nullability() | value.dtype().nullability();
109
110 let result = if let Some(value_scalar) = value.as_constant() {
111 list_contains_scalar(array, &value_scalar, nullability)
112 } else if let Some(list_scalar) = array.as_constant() {
113 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
114 } else {
115 todo!("unsupported list contains with list and element as arrays")
116 };
117
118 result.map(Output::Array)
119 }
120
121 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
122 let input = BinaryArgs::<()>::try_from(args)?;
123 Ok(DType::Bool(
124 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
125 ))
126 }
127
128 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
129 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
130 }
131
132 fn is_elementwise(&self) -> bool {
133 true
134 }
135}
136
/// Kernel interface for vtables that provide their own `list_contains` implementation.
///
/// [`ListContainsKernelAdapter`] calls this only when the search-value input is an array of
/// this vtable's array type.
pub trait ListContainsKernel: VTable {
    /// Checks whether the lists in `list` contain the corresponding entries of `element`.
    ///
    /// Returning `Ok(None)` signals that this kernel produced no output for the inputs.
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
144
/// Handle to a type-erased `list_contains` kernel.
///
/// Instances are gathered through `inventory` and registered on [`LIST_CONTAINS_FN`] when it
/// is first initialized.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
147
/// Adapts a [`ListContainsKernel`] implementation on vtable `V` to the generic [`Kernel`]
/// interface.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);

impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps a `'static` reference to this adapter in a [`ListContainsKernelRef`].
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
156
157impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
158 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
159 let BinaryArgs {
160 lhs: array,
161 rhs: value,
162 ..
163 } = BinaryArgs::<()>::try_from(args)?;
164 let Some(value) = value.as_opt::<V>() else {
165 return Ok(None);
166 };
167 self.0
168 .list_contains(array, value)
169 .map(|c| c.map(Output::Array))
170 }
171}
172
173pub static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
174 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
175 for kernel in inventory::iter::<ListContainsKernelRef> {
176 compute.register_kernel(kernel.0.clone());
177 }
178 compute
179});
180
181fn constant_list_scalar_contains(
183 list_scalar: &ListScalar,
184 values: &dyn Array,
185 nullability: Nullability,
186) -> VortexResult<ArrayRef> {
187 let elements = list_scalar.elements().vortex_expect("non null");
188
189 let len = values.len();
190 let mut result: Option<ArrayRef> = None;
191 let false_scalar = Scalar::bool(false, nullability);
192 for element in elements {
193 let res = fill_null(
194 &compare(
195 ConstantArray::new(element, len).as_ref(),
196 values,
197 Operator::Eq,
198 )?,
199 &false_scalar,
200 )?;
201 if let Some(acc) = result {
202 result = Some(or(&acc, &res)?)
203 } else {
204 result = Some(res);
205 }
206 }
207 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
208}
209
210fn list_contains_scalar(
211 array: &dyn Array,
212 value: &Scalar,
213 nullability: Nullability,
214) -> VortexResult<ArrayRef> {
215 if array.len() > 1 && array.is_constant() {
217 let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
218 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
219 }
220
221 let list_array = array.to_list()?;
224
225 let elems = list_array.elements();
226 if elems.is_empty() {
227 return list_false_or_null(&list_array, nullability);
229 }
230
231 let rhs = ConstantArray::new(value.clone(), elems.len());
232 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
233 let matches = matching_elements.to_bool()?;
234
235 if let Some(pred) = matches.as_constant() {
237 return match pred.as_bool().value() {
238 None => {
241 assert!(
242 !rhs.scalar().is_null(),
243 "Search value must not be null here"
244 );
245 list_false_or_null(&list_array, nullability)
247 }
248 Some(false) => {
250 Ok(
252 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
253 .into_array(),
254 )
255 }
256 Some(true) => {
258 list_is_not_empty(&list_array, nullability)
260 }
261 };
262 }
263
264 let ends = list_array.offsets().to_primitive()?;
265 match_each_integer_ptype!(ends.ptype(), |T| {
266 Ok(reduce_with_ends(
267 ends.as_slice::<T>(),
268 matches.boolean_buffer(),
269 list_array.validity().clone().union_nullability(nullability),
270 ))
271 })
272}
273
274fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
277 match list_array.validity() {
278 Validity::NonNullable => {
279 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
281 }
282 Validity::AllValid => {
283 Ok(
285 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
286 .into_array(),
287 )
288 }
289 Validity::AllInvalid => {
290 Ok(ConstantArray::new(
292 Scalar::null(DType::Bool(Nullability::Nullable)),
293 list_array.len(),
294 )
295 .into_array())
296 }
297 Validity::Array(validity_array) => {
298 let buffer = BooleanBuffer::new_unset(list_array.len());
300 Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
301 }
302 }
303}
304
305fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
308 if matches!(list_array.validity(), Validity::AllInvalid) {
310 return Ok(ConstantArray::new(
311 Scalar::null(DType::Bool(Nullability::Nullable)),
312 list_array.len(),
313 )
314 .into_array());
315 }
316
317 let offsets = list_array.offsets().to_primitive()?;
318 let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
319 element_is_not_empty(offsets.as_slice::<T>())
320 });
321
322 Ok(BoolArray::new(
324 buffer,
325 list_array.validity().clone().union_nullability(nullability),
326 )
327 .into_array())
328}
329
330fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
333 ends: &[T],
334 matches: &BooleanBuffer,
335 validity: Validity,
336) -> ArrayRef {
337 let mask: BooleanBuffer = ends
338 .windows(2)
339 .map(|window| {
340 let len = window[1].as_() - window[0].as_();
341 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
342 set_bits.next().is_some()
343 })
344 .collect();
345
346 BoolArray::new(mask, validity).into_array()
347}
348
349pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
372 if !matches!(array.dtype(), DType::List(..)) {
373 vortex_bail!("Array must be of list type");
374 }
375
376 if array.is_constant() && array.len() > 1 {
378 let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
379 return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
380 }
381
382 let list_array = array.to_list()?;
383 let offsets = list_array.offsets().to_primitive()?;
384 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
385 element_lens(offsets.as_slice::<T>()).into_array()
386 });
387
388 Ok(lens_array)
389}
390
391fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
392 values
393 .windows(2)
394 .map(|window| window[1] - window[0])
395 .collect()
396}
397
398fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
399 BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
400}
401
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable strings with `u64` offsets.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose string elements are nullable.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets start at zero and accumulate the length of each list.
        let mut offsets = Vec::with_capacity(values.len() + 1);
        let mut end = 0u64;
        offsets.push(end);
        for list in &values {
            end += list.len() as u64;
            offsets.push(end);
        }
        let offsets = Buffer::from_iter(offsets).into_array();

        let flattened = values.iter().flatten().cloned().collect_vec();
        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a boolean array from plain values and an explicit validity.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        let bits = values.into_iter().collect();
        BoolArray::new(bits, validity)
    }

    #[rstest]
    // Non-nullable elements, some lists match.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Nullable elements, every non-empty list contains the value.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    // Nullable elements, only one list contains the value.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    // Every element equals the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    // Only empty lists.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // No element equals the value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    // Null search value.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    // Only null elements, non-null search value.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let needle = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );
        let needles = ConstantArray::new(needle, list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).expect("list_contains failed");
        let actual = contains.to_bool().expect("to_bool failed");

        assert_eq!(
            actual.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list_scalar, 2).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        let bits = contains
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(bits, vec![true, true],);
    }

    #[test]
    fn test_all_nulls() {
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let list_array = ConstantArray::new(Scalar::null(list_dtype), 5).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(list_scalar, 7);
        let needles = (0..7).collect::<PrimitiveArray>();

        let contains = list_contains(lists.as_ref(), needles.as_ref()).unwrap();

        assert_eq!(contains.len(), 7);
        // Only the values 1, 3 and 6 are members of the constant list.
        let expected = (0..7)
            .map(|needle| Some(matches!(needle, 1 | 3 | 6)))
            .collect_vec();
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            expected
        );
    }
}