vortex_array/compute/
list.rs

1//! List-related compute operations.
2
3use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray, ListArray};
13use crate::compute::{Operator, compare, invert};
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical};
17
18/// Compute a `Bool`-typed array the same length as `array` where elements are `true` if the list
19/// item contains the `value`, or `false` otherwise.
20///
21/// If the ListArray is nullable, then the result will contain nulls matching the null mask
22/// of the original array.
23///
24/// ## Null scalar handling
25///
26/// When the search scalar is `NULL`, then the resulting array will be a `BoolArray` containing
27/// `true` if the list contains any nulls, and `false` if the list does not contain any nulls,
28/// or `NULL` for null lists.
29///
30/// ## Example
31///
32/// ```rust
33/// use vortex_array::{Array, IntoArray, ToCanonical};
34/// use vortex_array::arrays::{ListArray, VarBinArray};
35/// use vortex_array::compute::list_contains;
36/// use vortex_array::validity::Validity;
37/// use vortex_buffer::buffer;
38/// use vortex_dtype::DType;
39/// let elements = VarBinArray::from_vec(
40///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
41/// let offsets = buffer![0u32, 1, 3, 5].into_array();
42/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
43///
44/// let matches = list_contains(list_array.as_ref(), "b".into()).unwrap();
45/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
46/// assert_eq!(to_vec, vec![false, true, false]);
47/// ```
48pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49    let DType::List(elem_dtype, _nullability) = array.dtype() else {
50        vortex_bail!("Array must be of List type");
51    };
52    if &**elem_dtype != value.dtype() {
53        vortex_bail!("Element type of ListArray does not match search value");
54    }
55
56    // If the list array is constant, we perform a single comparison.
57    if array.is_constant() && array.len() > 1 {
58        let contains = list_contains(&array.slice(0, 1)?, value)?;
59        return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
60    }
61
62    // Canonicalize to a list array.
63    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
64    let list_array = array.to_list()?;
65
66    if value.is_null() {
67        return list_contains_null(&list_array);
68    }
69
70    let elems = list_array.elements();
71    let ends = list_array.offsets().to_primitive()?;
72
73    let rhs = ConstantArray::new(value, elems.len());
74    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
75    let matches = matching_elements.to_bool()?;
76
77    // Fast path: no elements match.
78    if let Some(pred) = matches.as_constant() {
79        if matches!(pred.as_bool().value(), None | Some(false)) {
80            // TODO(aduffy): how do we handle null?
81            return Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array());
82        }
83    }
84
85    match_each_integer_ptype!(ends.ptype(), |T| {
86        Ok(reduce_with_ends(
87            ends.as_slice::<T>(),
88            matches.boolean_buffer(),
89            list_array.validity().clone(),
90        ))
91    })
92}
93
94/// Returns a `Bool` array with `true` for lists which contains NULL and `false` if not, or
95/// NULL if the list itself is null.
96fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
97    let elems = list_array.elements();
98
99    // Check element validity. We need to intersect
100    match elems.validity_mask()? {
101        // No NULL elements
102        Mask::AllTrue(_) => match list_array.validity() {
103            Validity::NonNullable => {
104                Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
105            }
106            Validity::AllValid => Ok(ConstantArray::new(
107                Scalar::bool(true, Nullability::Nullable),
108                list_array.len(),
109            )
110            .into_array()),
111            Validity::AllInvalid => Ok(ConstantArray::new(
112                Scalar::null(DType::Bool(Nullability::Nullable)),
113                list_array.len(),
114            )
115            .into_array()),
116            Validity::Array(list_mask) => {
117                // Create a new bool array with false, and the provided nulls
118                let buffer = BooleanBuffer::new_unset(list_array.len());
119                Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
120            }
121        },
122        // All null elements
123        Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
124        Mask::Values(mask) => {
125            let nulls = invert(&mask.into_array())?.to_bool()?;
126            let ends = list_array.offsets().to_primitive()?;
127            match_each_integer_ptype!(ends.ptype(), |T| {
128                Ok(reduce_with_ends(
129                    list_array.offsets().to_primitive()?.as_slice::<T>(),
130                    nulls.boolean_buffer(),
131                    list_array.validity().clone(),
132                ))
133            })
134        }
135    }
136}
137
138// Reduce each boolean values into a Mask that indicates which elements in the
139// ListArray contain the matching value.
140fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
141    ends: &[T],
142    matches: &BooleanBuffer,
143    validity: Validity,
144) -> ArrayRef {
145    let mask: BooleanBuffer = ends
146        .windows(2)
147        .map(|window| {
148            let len = window[1].as_() - window[0].as_();
149            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
150            set_bits.next().is_some()
151        })
152        .collect();
153
154    BoolArray::new(mask, validity).into_array()
155}
156
157/// Returns a new array of `u64` representing the length of each list element.
158///
159/// ## Example
160///
161/// ```rust
162/// use vortex_array::arrays::{ListArray, VarBinArray};
163/// use vortex_array::{Array, IntoArray};
164/// use vortex_array::compute::{list_elem_len};
165/// use vortex_array::validity::Validity;
166/// use vortex_buffer::buffer;
167/// use vortex_dtype::DType;
168///
169/// let elements = VarBinArray::from_vec(
170///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
171/// let offsets = buffer![0u32, 1, 3, 5].into_array();
172/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
173///
174/// let lens = list_elem_len(list_array.as_ref()).unwrap();
175/// assert_eq!(lens.scalar_at(0).unwrap(), 1u32.into());
176/// assert_eq!(lens.scalar_at(1).unwrap(), 2u32.into());
177/// assert_eq!(lens.scalar_at(2).unwrap(), 2u32.into());
178/// ```
179pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
180    if !matches!(array.dtype(), DType::List(..)) {
181        vortex_bail!("Array must be of list type");
182    }
183
184    // Short-circuit for constant list arrays.
185    if array.is_constant() && array.len() > 1 {
186        let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
187        return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
188    }
189
190    let list_array = array.to_list()?;
191    let offsets = list_array.offsets().to_primitive()?;
192    let lens_array = match_each_integer_ptype!(offsets.ptype(), |T| {
193        element_lens(offsets.as_slice::<T>()).into_array()
194    });
195
196    Ok(lens_array)
197}
198
199fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
200    values
201        .windows(2)
202        .map(|window| window[1] - window[0])
203        .collect()
204}
205
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, ConstantArray, ConstantVTable, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{ArrayRef, IntoArray};

    /// Builds a `List<Utf8>` array with non-nullable string elements.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose `Utf8` elements may be NULL.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets start at zero and advance by each list's length.
        let mut offsets = Vec::with_capacity(values.len() + 1);
        let mut end = 0u64;
        offsets.push(end);
        for list in &values {
            end += list.len() as u64;
            offsets.push(end);
        }
        let offsets = Buffer::from_iter(offsets).into_array();

        let flat: Vec<Option<&str>> = values.into_iter().flatten().collect();
        let elements =
            VarBinArray::from_iter(flat, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds a `BoolArray`; it is non-nullable when `validity` is `None`.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = validity.map_or(Validity::NonNullable, Validity::from_iter);
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8)
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 2: list(utf8?) with NULL search scalar
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    // Case 3: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 4: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // Case 5: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        // A missing search value means "search for NULL".
        let scalar = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );

        let actual = list_contains(&list_array, scalar)
            .expect("list_contains failed")
            .to_bool()
            .expect("to_bool failed");

        let actual_bits = actual.boolean_buffer().iter().collect_vec();
        let expected_bits = expected.boolean_buffer().iter().collect_vec();
        assert_eq!(actual_bits, expected_bits);
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let list_scalar = Scalar::list(
            element_dtype,
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list_scalar, 2).into_array();

        let contains = list_contains(&list_array, 2i32.into()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        let bits = contains
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(bits, vec![true, true],);
    }
}