vortex_array/compute/
list.rs

1//! List-related compute operations.
2
3use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexExpect, VortexResult, vortex_bail};
9use vortex_scalar::{ListScalar, Scalar};
10
11use crate::arrays::{BoolArray, ConstantArray, ListArray};
12use crate::compute::{Operator, compare, fill_null, or};
13use crate::validity::Validity;
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical};
16
17/// Compute a `Bool`-typed array the same length as `array` where elements is `true` if the list
18/// item contains the `value`, `false` otherwise.
19///
20/// ## Null scalar handling
21///
22/// If the `value` or `array` is `null` at any index the result at that index is `null`.
23///
24/// ## Format semantics
25/// ```txt
26/// list_contains(list, elem)
27///   ==> (!is_null(list) or NULL) and (!is_null(elem) or NULL) and any({elem = elem_i | elem_i in list}),
28/// ```
29///
30/// ## Example
31///
32/// ```rust
33/// use vortex_array::{Array, IntoArray, ToCanonical};
34/// use vortex_array::arrays::{ConstantArray, ListArray, VarBinArray};
35/// use vortex_array::compute::list_contains;
36/// use vortex_array::validity::Validity;
37/// use vortex_buffer::buffer;
38/// use vortex_dtype::DType;
39/// use vortex_scalar::Scalar;
40/// let elements = VarBinArray::from_vec(
41///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
42/// let offsets = buffer![0u32, 1, 3, 5].into_array();
43/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
44///
45/// let matches = list_contains(list_array.as_ref(), ConstantArray::new(Scalar::from("b"), list_array.len()).as_ref()).unwrap();
46/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
47/// assert_eq!(to_vec, vec![false, true, false]);
48/// ```
49// TODO(joe): promote to compute fn.
50pub fn list_contains(array: &dyn Array, value: &dyn Array) -> VortexResult<ArrayRef> {
51    let DType::List(elem_dtype, _) = array.dtype() else {
52        vortex_bail!("Array must be of List type");
53    };
54    if !elem_dtype.eq_ignore_nullability(value.dtype()) {
55        vortex_bail!("Element type of ListArray does not match search value");
56    }
57
58    if value.all_invalid()? || array.all_invalid()? {
59        return Ok(ConstantArray::new(
60            Scalar::null(DType::Bool(Nullability::Nullable)),
61            array.len(),
62        )
63        .to_array());
64    }
65
66    let nullability = array.dtype().nullability() | value.dtype().nullability();
67
68    if let Some(value_scalar) = value.as_constant() {
69        list_contains_scalar(array, &value_scalar, nullability)
70    } else if let Some(list_scalar) = array.as_constant() {
71        constant_list_scalar_contains(&list_scalar.as_list(), value, nullability)
72    } else {
73        todo!("unsupported list contains with list and element as arrays")
74    }
75}
76
77// Then there is a constant list scalar (haystack) being compared to an array of needles.
78fn constant_list_scalar_contains(
79    list_scalar: &ListScalar,
80    values: &dyn Array,
81    nullability: Nullability,
82) -> VortexResult<ArrayRef> {
83    let elements = list_scalar.elements().vortex_expect("non null");
84
85    let len = values.len();
86    let mut result: Option<ArrayRef> = None;
87    let false_scalar = Scalar::bool(false, nullability);
88    for element in elements {
89        let res = fill_null(
90            &compare(
91                ConstantArray::new(element, len).as_ref(),
92                values,
93                Operator::Eq,
94            )?,
95            &false_scalar,
96        )?;
97        if let Some(acc) = result {
98            result = Some(or(&acc, &res)?)
99        } else {
100            result = Some(res);
101        }
102    }
103    Ok(result.unwrap_or_else(|| ConstantArray::new(false_scalar, len).to_array()))
104}
105
106fn list_contains_scalar(
107    array: &dyn Array,
108    value: &Scalar,
109    nullability: Nullability,
110) -> VortexResult<ArrayRef> {
111    // If the list array is constant, we perform a single comparison.
112    if array.len() > 1 && array.is_constant() {
113        let contains = list_contains_scalar(&array.slice(0, 1)?, value, nullability)?;
114        return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
115    }
116
117    // Canonicalize to a list array.
118    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
119    let list_array = array.to_list()?;
120
121    let elems = list_array.elements();
122    if elems.is_empty() {
123        // Must return false when a list is empty (but valid), or null when the list itself is null.
124        return list_false_or_null(&list_array);
125    }
126
127    let rhs = ConstantArray::new(value.clone(), elems.len());
128    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
129    let matches = matching_elements.to_bool()?;
130
131    // Fast path: no elements match.
132    if let Some(pred) = matches.as_constant() {
133        return match pred.as_bool().value() {
134            // All comparisons are invalid (result in `null`), and search is not null because
135            // we already checked for null above.
136            None => {
137                assert!(
138                    !rhs.scalar().is_null(),
139                    "Search value must not be null here"
140                );
141                // False, unless the list itself is null in which case we return null.
142                list_false_or_null(&list_array)
143            }
144            // No elements match, and all comparisons are valid (result in `false`).
145            Some(false) => {
146                // False, but match the nullability to the input list array.
147                Ok(
148                    ConstantArray::new(Scalar::bool(false, nullability), list_array.len())
149                        .into_array(),
150                )
151            }
152            // All elements match, and all comparisons are valid (result in `true`).
153            Some(true) => {
154                // True, unless the list itself is empty or NULL.
155                list_is_not_empty(&list_array)
156            }
157        };
158    }
159
160    let ends = list_array.offsets().to_primitive()?;
161    match_each_integer_ptype!(ends.ptype(), |T| {
162        Ok(reduce_with_ends(
163            ends.as_slice::<T>(),
164            matches.boolean_buffer(),
165            list_array.validity().clone().union_nullability(nullability),
166        ))
167    })
168}
169
170/// Returns a `Bool` array with `false` for lists that are valid,
171/// or `NULL` if the list itself is null.
172fn list_false_or_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
173    match list_array.validity() {
174        Validity::NonNullable => {
175            // All false.
176            Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
177        }
178        Validity::AllValid => {
179            // All false, but nullable.
180            Ok(
181                ConstantArray::new(Scalar::bool(false, Nullability::Nullable), list_array.len())
182                    .into_array(),
183            )
184        }
185        Validity::AllInvalid => {
186            // All nulls, must be nullable result.
187            Ok(ConstantArray::new(
188                Scalar::null(DType::Bool(Nullability::Nullable)),
189                list_array.len(),
190            )
191            .into_array())
192        }
193        Validity::Array(validity_array) => {
194            // Create a new bool array with false, and the provided nulls
195            let buffer = BooleanBuffer::new_unset(list_array.len());
196            Ok(BoolArray::new(buffer, Validity::Array(validity_array.clone())).into_array())
197        }
198    }
199}
200
201/// Returns a `Bool` array with `true` for lists which are NOT empty, or `false` if they are empty,
202/// or `NULL` if the list itself is null.
203fn list_is_not_empty(list_array: &ListArray) -> VortexResult<ArrayRef> {
204    // Short-circuit for all invalid.
205    if matches!(list_array.validity(), Validity::AllInvalid) {
206        return Ok(ConstantArray::new(
207            Scalar::null(DType::Bool(Nullability::Nullable)),
208            list_array.len(),
209        )
210        .into_array());
211    }
212
213    let offsets = list_array.offsets().to_primitive()?;
214    let buffer = match_each_integer_ptype!(offsets.ptype(), |T| {
215        element_is_not_empty(offsets.as_slice::<T>())
216    });
217
218    // Copy over the validity mask from the input.
219    Ok(BoolArray::new(buffer, list_array.validity().clone()).into_array())
220}
221
222/// Reduces each boolean values into a Mask that indicates which elements in the
223/// ListArray contain the matching value.
224fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
225    ends: &[T],
226    matches: &BooleanBuffer,
227    validity: Validity,
228) -> ArrayRef {
229    let mask: BooleanBuffer = ends
230        .windows(2)
231        .map(|window| {
232            let len = window[1].as_() - window[0].as_();
233            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
234            set_bits.next().is_some()
235        })
236        .collect();
237
238    BoolArray::new(mask, validity).into_array()
239}
240
241/// Returns a new array of `u64` representing the length of each list element.
242///
243/// ## Example
244///
245/// ```rust
246/// use vortex_array::arrays::{ListArray, VarBinArray};
247/// use vortex_array::{Array, IntoArray};
248/// use vortex_array::compute::{list_elem_len};
249/// use vortex_array::validity::Validity;
250/// use vortex_buffer::buffer;
251/// use vortex_dtype::DType;
252///
253/// let elements = VarBinArray::from_vec(
254///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
255/// let offsets = buffer![0u32, 1, 3, 5].into_array();
256/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
257///
258/// let lens = list_elem_len(list_array.as_ref()).unwrap();
259/// assert_eq!(lens.scalar_at(0).unwrap(), 1u32.into());
260/// assert_eq!(lens.scalar_at(1).unwrap(), 2u32.into());
261/// assert_eq!(lens.scalar_at(2).unwrap(), 2u32.into());
262/// ```
263pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
264    if !matches!(array.dtype(), DType::List(..)) {
265        vortex_bail!("Array must be of list type");
266    }
267
268    // Short-circuit for constant list arrays.
269    if array.is_constant() && array.len() > 1 {
270        let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
271        return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
272    }
273
274    let list_array = array.to_list()?;
275    let offsets = list_array.offsets().to_primitive()?;
276    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
277        element_lens(offsets.as_slice::<T>()).into_array()
278    });
279
280    Ok(lens_array)
281}
282
283fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
284    values
285        .windows(2)
286        .map(|window| window[1] - window[0])
287        .collect()
288}
289
290fn element_is_not_empty<T: NativePType>(values: &[T]) -> BooleanBuffer {
291    BooleanBuffer::from_iter(values.windows(2).map(|window| window[1] != window[0]))
292}
293
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{
        BoolArray, ConstantArray, ConstantVTable, ListArray, PrimitiveArray, VarBinArray,
    };
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray};

    /// Builds a list array of non-nullable strings from nested vectors.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose string elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets are the running total of list lengths, starting from zero.
        let mut offsets = Vec::with_capacity(values.len() + 1);
        let mut end = 0u64;
        offsets.push(end);
        for list in &values {
            end += list.len() as u64;
            offsets.push(end);
        }
        let offsets = Buffer::from_iter(offsets).into_array();

        let flat = values.iter().flatten().cloned().collect_vec();
        let elements =
            VarBinArray::from_iter(flat, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds the expected bool array; `None` validity means a non-nullable array.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = if let Some(bits) = validity {
            Validity::from_iter(bits)
        } else {
            Validity::NonNullable
        };

        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: valid scalar search over non-nullable list, including an empty list
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 2: valid scalar search over nullable list, with all nulls matched
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        Some("a"),
        bool_array(vec![false, true, true], Some(vec![true, true, true]))
    )]
    // Case 3: valid scalar search over nullable list, with some nulls not matched (return no nulls)
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("b"), None, None]]),
        Some("a"),
        bool_array(vec![false, true, false], Some(vec![true, true, true]))
    )]
    // Case 4: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 5: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // Case 6: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // Case 7: list(utf8?) with empty + NULL elements and NULL search
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        None,
        bool_array(vec![false, true, true], Some(vec![false, false, false]))
    )]
    // Case 8: list(utf8?) with empty + NULL elements and search scalar
    #[case(
        null_strings(vec![vec![], vec![None, None], vec![None, None, None]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let needle = if let Some(text) = value {
            Scalar::utf8(text, element_nullability)
        } else {
            Scalar::null(DType::Utf8(Nullability::Nullable))
        };
        let needles = ConstantArray::new(needle, list_array.len());

        let actual = list_contains(&list_array, needles.as_ref())
            .expect("list_contains failed")
            .to_bool()
            .expect("to_bool failed");

        let actual_values = actual.opt_bool_vec().unwrap();
        let expected_values = expected.opt_bool_vec().unwrap();
        assert_eq!(actual_values, expected_values);
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let haystack = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(haystack, 2).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        let bools = contains.to_bool().unwrap();
        assert_eq!(
            bools.boolean_buffer().iter().collect_vec(),
            vec![true, true],
        );
    }

    #[test]
    fn test_all_nulls() {
        let null_list = Scalar::null(DType::List(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            Nullability::Nullable,
        ));
        let list_array = ConstantArray::new(null_list, 5).into_array();
        let needles = ConstantArray::new(Scalar::from(2i32), list_array.len());

        let contains = list_contains(&list_array, needles.as_ref()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        assert_eq!(contains.len(), 5);
        let bools = contains.to_bool().unwrap();
        assert_eq!(bools.validity(), &Validity::AllInvalid);
    }

    #[test]
    fn test_list_array_element() {
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1.into(), 3.into(), 6.into()],
            Nullability::NonNullable,
        );
        let haystacks = ConstantArray::new(list_scalar, 7);
        let needles = (0..7).collect::<PrimitiveArray>();

        let contains = list_contains(haystacks.as_ref(), needles.as_ref()).unwrap();

        assert_eq!(contains.len(), 7);
        // Only the needles 1, 3 and 6 are present in the constant list.
        let expected = vec![
            Some(false),
            Some(true),
            Some(false),
            Some(true),
            Some(false),
            Some(false),
            Some(true),
        ];
        assert_eq!(
            contains.to_bool().unwrap().opt_bool_vec().unwrap(),
            expected
        );
    }
}