1use std::sync::LazyLock;
7
8use arcref::ArcRef;
9use arrow_buffer::BooleanBuffer;
10use arrow_buffer::bit_iterator::BitIndexIterator;
11use num_traits::AsPrimitive;
12use vortex_buffer::Buffer;
13use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail};
15use vortex_scalar::{ListScalar, Scalar};
16
17use crate::arrays::{BoolArray, ConstantArray, ListArray};
18use crate::compute::{
19 BinaryArgs, ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Operator, Output, compare,
20 fill_null, or,
21};
22use crate::validity::Validity;
23use crate::vtable::{VTable, ValidityHelper};
24use crate::{Array, ArrayRef, IntoArray, ToCanonical};
25
26pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
59 LIST_CONTAINS_FN
60 .invoke(&InvocationArgs {
61 inputs: &[array.into(), value.into()],
62 options: &(),
63 })?
64 .unwrap_array()
65}
66
/// The [`ComputeFnVTable`] that implements the `list_contains` compute function.
///
/// It is registered as [`LIST_CONTAINS_FN`].
pub struct ListContains;
68
69impl ComputeFnVTable for ListContains {
70 fn invoke(
71 &self,
72 args: &InvocationArgs,
73 kernels: &[ArcRef<dyn Kernel>],
74 ) -> VortexResult<Output> {
75 let BinaryArgs {
76 lhs: array,
77 rhs: value,
78 ..
79 } = BinaryArgs::<()>::try_from(args)?;
80
81 let DType::List(elem_dtype, _) = array.dtype() else {
82 vortex_bail!("Array must be of List type");
83 };
84 if !elem_dtype.as_ref().eq_ignore_nullability(value.dtype()) {
85 vortex_bail!(
86 "Element type {} of ListArray does not match search value {}",
87 elem_dtype,
88 value.dtype(),
89 );
90 };
91
92 if value.all_invalid()? || array.all_invalid()? {
93 return Ok(Output::Array(
94 ConstantArray::new(
95 Scalar::null(DType::Bool(Nullability::Nullable)),
96 array.len(),
97 )
98 .to_array(),
99 ));
100 }
101
102 for kernel in kernels {
103 if let Some(output) = kernel.invoke(args)? {
104 return Ok(output);
105 }
106 }
107 if let Some(output) = array.invoke(&LIST_CONTAINS_FN, args)? {
108 return Ok(output);
109 }
110
111 let nullability = array.dtype().nullability() | value.dtype().nullability();
112
113 let result = if let Some(value_scalar) = value.as_constant() {
114 list_contains_scalar(array, &value_scalar, nullability)
115 } else if let Some(list_scalar) = array.as_constant() {
116 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
117 } else {
118 todo!("unsupported list contains with list and element as arrays")
119 };
120
121 result.map(Output::Array)
122 }
123
124 fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
125 let input = BinaryArgs::<()>::try_from(args)?;
126 Ok(DType::Bool(
127 input.lhs.dtype().nullability() | input.rhs.dtype().nullability(),
128 ))
129 }
130
131 fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
132 Ok(BinaryArgs::<()>::try_from(args)?.lhs.len())
133 }
134
135 fn is_elementwise(&self) -> bool {
136 true
137 }
138}
139
/// Encoding-specific implementation of `list_contains`.
///
/// It is used when the *search value* array (`element`) is encoded with this VTable.
///
/// Returning `Ok(None)` declines the call, and the compute function then tries its remaining
/// strategies.
pub trait ListContainsKernel: VTable {
    fn list_contains(
        &self,
        list: &dyn Array,
        element: &Self::Array,
    ) -> VortexResult<Option<ArrayRef>>;
}
147
/// A type-erased handle to a `list_contains` kernel.
///
/// Instances are collected with `inventory`. Every collected handle is registered on
/// [`LIST_CONTAINS_FN`] when that static is first initialised.
pub struct ListContainsKernelRef(ArcRef<dyn Kernel>);
inventory::collect!(ListContainsKernelRef);
150
/// Adapts a [`ListContainsKernel`] implemented by the VTable `V` to the generic [`Kernel`]
/// interface.
#[derive(Debug)]
pub struct ListContainsKernelAdapter<V: VTable>(pub V);
153
impl<V: VTable + ListContainsKernel> ListContainsKernelAdapter<V> {
    /// Wraps this `'static` adapter in a [`ListContainsKernelRef`], so that it can be
    /// collected via `inventory`.
    pub const fn lift(&'static self) -> ListContainsKernelRef {
        ListContainsKernelRef(ArcRef::new_ref(self))
    }
}
159
160impl<V: VTable + ListContainsKernel> Kernel for ListContainsKernelAdapter<V> {
161 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
162 let BinaryArgs {
163 lhs: array,
164 rhs: value,
165 ..
166 } = BinaryArgs::<()>::try_from(args)?;
167 let Some(value) = value.as_opt::<V>() else {
168 return Ok(None);
169 };
170 self.0
171 .list_contains(array, value)
172 .map(|c| c.map(Output::Array))
173 }
174}
175
176pub static LIST_CONTAINS_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
177 let compute = ComputeFn::new("list_contains".into(), ArcRef::new_ref(&ListContains));
178 for kernel in inventory::iter::<ListContainsKernelRef> {
179 compute.register_kernel(kernel.0.clone());
180 }
181 compute
182});
183
184fn constant_list_scalar_contains(
186 list_scalar: &ListScalar,
187 values: &dyn Array,
188 nullability: Nullability,
189) -> VortexResult<ArrayRef> {
190 let elements = list_scalar.elements().vortex_expect("non null");
191
192 let len = values.len();
193 let mut result: Option<ArrayRef> = None;
194 let false_scalar = Scalar::bool(false, nullability);
195 for element in elements {
196 let res = fill_null(
197 &compare(
198 ConstantArray::new(element, len).as_ref(),
199 values,
200 Operator::Eq,
201 )?,
202 &false_scalar,
203 )?;
204 if let Some(acc) = result {
205 result = Some(or(&acc, &res)?)
206 } else {
207 result = Some(res);
208 }
209 }
210 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
211}
212
213fn list_contains_scalar(
214 array: &dyn Array,
215 value: &Scalar,
216 nullability: Nullability,
217) -> VortexResult<ArrayRef> {
218 if array.len() > 1 && array.is_constant() {
220 let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
221 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
222 }
223
224 let list_array = array.to_list()?;
227
228 let elems = list_array.elements();
229 if elems.is_empty() {
230 return list_false_or_null(&list_array, nullability);
232 }
233
234 let rhs = ConstantArray::new(value.clone(), elems.len());
235 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
236 let matches = matching_elements.to_bool()?;
237
238 if let Some(pred) = matches.as_constant() {
240 return match pred.as_bool().value() {
241 None => {
244 assert!(
245 !rhs.scalar().is_null(),
246 "Search value must not be null here"
247 );
248 list_false_or_null(&list_array, nullability)
250 }
251 Some(false) => {
253 Ok(
255 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
256 .into_array(),
257 )
258 }
259 Some(true) => {
261 list_is_not_empty(&list_array, nullability)
263 }
264 };
265 }
266
267 let ends = list_array.offsets().to_primitive()?;
268 match_each_integer_ptype!(ends.ptype(), |T| {
269 Ok(reduce_with_ends(
270 ends.as_slice::<T>(),
271 matches.boolean_buffer(),
272 list_array.validity().clone().union_nullability(nullability),
273 ))
274 })
275}
276
277fn list_false_or_null(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
280 match list_array.validity() {
281 Validity::NonNullable => {
282 Ok(ConstantArray::new(Scalar::bool(false, nullability), list_array.len()).into_array())
284 }
285 Validity::AllValid => {
286 Ok(
288 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
289 .into_array(),
290 )
291 }
292 Validity::AllInvalid => {
293 Ok(ConstantArray::new(
295 Scalar::null(DType::Bool(Nullability::Nullable)),
296 list_array.len(),
297 )
298 .into_array())
299 }
300 Validity::Array(validity_array) => {
301 let buffer = BooleanBuffer::new_unset(list_array.len());
303 Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
304 }
305 }
306}
307
308fn list_is_not_empty(list_array: &ListArray, nullability: Nullability) -> VortexResult<ArrayRef> {
311 if matches!(list_array.validity(), Validity::AllInvalid) {
313 return Ok(ConstantArray::new(
314 Scalar::null(DType::Bool(Nullability::Nullable)),
315 list_array.len(),
316 )
317 .into_array());
318 }
319
320 let offsets = list_array.offsets().to_primitive()?;
321 let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
322 element_is_not_empty(offsets.as_slice::<T>())
323 });
324
325 Ok(BoolArray::new(
327 buffer,
328 list_array.validity().clone().union_nullability(nullability),
329 )
330 .into_array())
331}
332
333fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
336 ends: &[T],
337 matches: &BooleanBuffer,
338 validity: Validity,
339) -> ArrayRef {
340 let mask: BooleanBuffer = ends
341 .windows(2)
342 .map(|window| {
343 let len = window[1].as_() - window[0].as_();
344 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
345 set_bits.next().is_some()
346 })
347 .collect();
348
349 BoolArray::new(mask, validity).into_array()
350}
351
352pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
375 if !matches!(array.dtype(), DType::List(..)) {
376 vortex_bail!("Array must be of list type");
377 }
378
379 if array.is_constant() && array.len() > 1 {
381 let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
382 return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
383 }
384
385 let list_array = array.to_list()?;
386 let offsets = list_array.offsets().to_primitive()?;
387 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
388 element_lens(offsets.as_slice::<T>()).into_array()
389 });
390
391 Ok(lens_array)
392}
393
394fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
395 values
396 .windows(2)
397 .map(|window| window[1] - window[0])
398 .collect()
399}
400
401fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
402 BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
403}
404
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable `Utf8` elements with `u64` offsets.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        ListArray::from_iter_slow::<u64, _>(values, Arc::new(DType::Utf8(Nullability::NonNullable)))
            .unwrap()
    }

    /// Builds a non-nullable list array of nullable `Utf8` elements with `u64` offsets.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        let elements = values.iter().flatten().cloned().collect_vec();
        // Running sums of the list lengths give each list's end offset. The leading 0 inserted
        // below is the start offset of the first list.
        let mut offsets = values
            .iter()
            .scan(0u64, |st, v| {
                *st += v.len() as u64;
                Some(*st)
            })
            .collect_vec();
        offsets.insert(0, 0u64);
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(elements, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a `BoolArray` from `values` with the given `validity`.
    fn bool_array(values: Vec<bool>, validity: Validity) -> BoolArray {
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    #[case(
        // Null elements are skipped; "a" is still found next to them.
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::AllValid)
    )]
    #[case(
        // A list holding only "b" and nulls does not contain "a".
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Validity::AllValid)
    )]
    #[case(
        // Every element equals the search value (constant-true comparison).
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], Validity::NonNullable)
    )]
    #[case(
        // Every list is empty, so there are no elements to compare.
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    #[case(
        // No element equals the search value (constant-false comparison).
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::NonNullable)
    )]
    #[case(
        // A null search value makes the whole result null.
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Validity::AllInvalid)
    )]
    #[case(
        // Lists of only null elements never contain a non-null value.
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], Validity::AllValid)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // The search scalar uses the list's element nullability so the dtypes line up.
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = match value {
            None => Scalar::null(DType::Utf8(Nullability::Nullable)),
            Some(v) => Scalar::utf8(v, element_nullability),
        };
        let elem = ConstantArray::new(scalar, list_array.len());
        let result = list_contains(&list_array, elem.as_ref()).expect("list_contains failed");
        let bool_result = result.to_bool().expect("to_bool failed");
        assert_eq!(
            bool_result.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(bool_result.validity(), expected.validity());
    }

    /// A constant list searched with a constant value yields a constant result.
    #[test]
    fn test_constant_list() {
        let list_array = ConstantArray::new(
            Scalar::list(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                vec![1i32.into(), 2i32.into(), 3i32.into()],
                Nullability::NonNullable,
            ),
            2,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }

    /// An all-null list array yields an all-null constant result.
    #[test]
    fn test_all_nulls() {
        let list_array = ConstantArray::new(
            Scalar::null(DType::List(
                Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
                Nullability::Nullable,
            )),
            5,
        )
        .into_array();

        let contains = list_contains(
            &list_array,
            ConstantArray::new(Scalar::from(2i32), list_array.len()).as_ref(),
        )
        .unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    /// A constant list searched with a non-constant array of values (the constant-list path).
    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );

        let contains = list_contains(
            ConstantArray::new(list_scalar, 7).as_ref(),
            (0..7).collect::<PrimitiveArray>().as_ref(),
        )
        .unwrap();

        assert_eq!(contains.len(), 7);
        // Only 1, 3 and 6 are present in the list.
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }
}