1use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::{ListScalar, Scalar};
10
11use crate::arrays::{BoolArray, ConstantArray, ListArray};
12use crate::compute::{Operator, compare, fill_null, or};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical};
16
17pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
51 let DType::List(elem_dtype, _) = array.dtype() else {
52 vortex_bail!("Array must be of List type");
53 };
54 if !elem_dtype.eq_ignore_nullability(value.dtype()) {
55 vortex_bail!("Element type of ListArray does not match search value");
56 }
57
58 if value.all_invalid()? || array.all_invalid()? {
59 return Ok(ConstantArray::new(
60 Scalar::null(DType::Bool(Nullability::Nullable)),
61 array.len(),
62 )
63 .to_array());
64 }
65
66 let nullability = array.dtype().nullability() | value.dtype().nullability();
67
68 if let Some(value_scalar) = value.as_constant() {
69 list_contains_scalar(array, &value_scalar, nullability)
70 } else if let Some(list_scalar) = array.as_constant() {
71 constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
72 } else {
73 todo!("unsupported list contains with list and element as arrays")
74 }
75}
76
77fn constant_list_scalar_contains(
79 list_scalar: &ListScalar,
80 values: &dyn Array,
81 nullability: Nullability,
82) -> VortexResult<ArrayRef> {
83 let elements = list_scalar.elements().vortex_expect("non null");
84
85 let len = values.len();
86 let mut result: Option<ArrayRef> = None;
87 let false_scalar = Scalar::bool(false, nullability);
88 for element in elements {
89 let res = fill_null(
90 &compare(
91 ConstantArray::new(element, len).as_ref(),
92 values,
93 Operator::Eq,
94 )?,
95 &false_scalar,
96 )?;
97 if let Some(acc) = result {
98 result = Some(or(&acc, &res)?)
99 } else {
100 result = Some(res);
101 }
102 }
103 Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
104}
105
106fn list_contains_scalar(
107 array: &dyn Array,
108 value: &Scalar,
109 nullability: Nullability,
110) -> VortexResult<ArrayRef> {
111 if array.len() > 1 && array.is_constant() {
113 let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
114 return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
115 }
116
117 let list_array = array.to_list()?;
120
121 let elems = list_array.elements();
122 if elems.is_empty() {
123 return list_false_or_null(&list_array);
125 }
126
127 let rhs = ConstantArray::new(value.clone(), elems.len());
128 let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
129 let matches = matching_elements.to_bool()?;
130
131 if let Some(pred) = matches.as_constant() {
133 return match pred.as_bool().value() {
134 None => {
137 assert!(
138 !rhs.scalar().is_null(),
139 "Search value must not be null here"
140 );
141 list_false_or_null(&list_array)
143 }
144 Some(false) => {
146 Ok(
148 ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
149 .into_array(),
150 )
151 }
152 Some(true) => {
154 list_is_not_empty(&list_array)
156 }
157 };
158 }
159
160 let ends = list_array.offsets().to_primitive()?;
161 match_each_integer_ptype!(ends.ptype(), |T| {
162 Ok(reduce_with_ends(
163 ends.as_slice::<T>(),
164 matches.boolean_buffer(),
165 list_array.validity().clone().union_nullability(nullability),
166 ))
167 })
168}
169
170fn list_false_or_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
173 match list_array.validity() {
174 Validity::NonNullable => {
175 Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
177 }
178 Validity::AllValid => {
179 Ok(
181 ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
182 .into_array(),
183 )
184 }
185 Validity::AllInvalid => {
186 Ok(ConstantArray::new(
188 Scalar::null(DType::Bool(Nullability::Nullable)),
189 list_array.len(),
190 )
191 .into_array())
192 }
193 Validity::Array(validity_array) => {
194 let buffer = BooleanBuffer::new_unset(list_array.len());
196 Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
197 }
198 }
199}
200
201fn list_is_not_empty(list_array: &ListArray) -> VortexResult<ArrayRef> {
204 if matches!(list_array.validity(), Validity::AllInvalid) {
206 return Ok(ConstantArray::new(
207 Scalar::null(DType::Bool(Nullability::Nullable)),
208 list_array.len(),
209 )
210 .into_array());
211 }
212
213 let offsets = list_array.offsets().to_primitive()?;
214 let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
215 element_is_not_empty(offsets.as_slice::<T>())
216 });
217
218 Ok(BoolArray::new(buffer, list_array.validity().clone()).into_array())
220}
221
222fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
225 ends: &[T],
226 matches: &BooleanBuffer,
227 validity: Validity,
228) -> ArrayRef {
229 let mask: BooleanBuffer = ends
230 .windows(2)
231 .map(|window| {
232 let len = window[1].as_() - window[0].as_();
233 let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
234 set_bits.next().is_some()
235 })
236 .collect();
237
238 BoolArray::new(mask, validity).into_array()
239}
240
241pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
264 if !matches!(array.dtype(), DType::List(..)) {
265 vortex_bail!("Array must be of list type");
266 }
267
268 if array.is_constant() && array.len() > 1 {
270 let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
271 return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
272 }
273
274 let list_array = array.to_list()?;
275 let offsets = list_array.offsets().to_primitive()?;
276 let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
277 element_lens(offsets.as_slice::<T>()).into_array()
278 });
279
280 Ok(lens_array)
281}
282
283fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
284 values
285 .windows(2)
286 .map(|window| window[1] - window[0])
287 .collect()
288}
289
290fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
291 BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
292}
293
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose string elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Flatten all lists into a single element array.
        let flattened = values.iter().flatten().copied().collect_vec();

        // Offsets are a leading zero followed by the running total of list lengths.
        let running_ends = values.iter().scan(0u64, |end, list| {
            *end += list.len() as u64;
            Some(*end)
        });
        let offsets = std::iter::once(0u64).chain(running_ends).collect_vec();
        let offsets = Buffer::from_iter(offsets).into_array();

        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a bool array; `validity` of `None` means the array is non-nullable.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = validity.map_or(Validity::NonNullable, Validity::from_iter);
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Non-nullable elements, with some lists containing the value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Nullable elements, with the value present in every non-empty list.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Some(vec![true, true, true]))
    )]
    // Nullable elements, with the value missing from the last list.
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Some(vec![true, true, true]))
    )]
    // Every element equals the search value.
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Only empty lists.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // No element equals the search value.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // Null search value: every result entry is null.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Some(vec![false, false, false]))
    )]
    // Only null elements, searched with a non-null value.
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let needle = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |text| Scalar::utf8(text, element_nullability),
        );

        let needles = ConstantArray::new(needle, list_array.len());
        let actual = list_contains(&list_array, needles.as_ref())
            .expect("list_contains failed")
            .to_bool()
            .expect("to_bool failed");

        assert_eq!(
            actual.opt_bool_vec().unwrap(),
            expected.opt_bool_vec().unwrap()
        );
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        // The same list `[1, 2, 3]` repeated twice.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list_scalar, 2).into_array();

        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");
        assert_eq!(
            contains
                .to_bool()
                .unwrap()
                .boolean_buffer()
                .iter()
                .collect_vec(),
            vec![true, true],
        );
    }

    #[test]
    fn test_all_nulls() {
        // Five null lists of non-nullable i32 elements.
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        );
        let list_array = ConstantArray::new(Scalar::null(list_dtype), 5).into_array();

        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());
        let contains = list_contains(&list_array, needles.as_ref()).unwrap();

        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        assert_eq!(
            contains.to_bool().unwrap().validity(),
            &Validity::AllInvalid
        );
    }

    #[test]
    fn test_list_array_element() {
        // A constant list `[1, 3, 6]` probed with the values `0..7`.
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );
        let lists = ConstantArray::new(list_scalar, 7);
        let needles: PrimitiveArray = (0..7).collect();

        let contains = list_contains(lists.as_ref(), needles.as_ref()).unwrap();

        assert_eq!(contains.len(), 7);
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            vec![
                Some(false),
                Some(true),
                Some(false),
                Some(true),
                Some(false),
                Some(false),
                Some(true)
            ]
        );
    }
}