vortex_array/compute/
list.rs

1//! List-related compute operations.
2
3use arrow_buffer::BooleanBuffer;
4use arrow_buffer::bit_iterator::BitIndexIterator;
5use num_traits::AsPrimitive;
6use vortex_buffer::Buffer;
7use vortex_dtype::{DType, NativePType, Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_bail};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, ConstantArray, ListArray};
13use crate::compute::{Operator, compare, invert};
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16use crate::{Array, ArrayRef, IntoArray, ToCanonical};
17
18/// Compute a `Bool`-typed array the same length as `array` where elements are `true` if the list
19/// item contains the `value`, or `false` otherwise.
20///
21/// If the ListArray is nullable, then the result will contain nulls matching the null mask
22/// of the original array.
23///
24/// ## Null scalar handling
25///
26/// When the search scalar is `NULL`, then the resulting array will be a `BoolArray` containing
27/// `true` if the list contains any nulls, and `false` if the list does not contain any nulls,
28/// or `NULL` for null lists.
29///
30/// ## Example
31///
32/// ```rust
33/// use vortex_array::{Array, IntoArray, ToCanonical};
34/// use vortex_array::arrays::{ListArray, VarBinArray};
35/// use vortex_array::compute::list_contains;
36/// use vortex_array::validity::Validity;
37/// use vortex_buffer::buffer;
38/// use vortex_dtype::DType;
39/// let elements = VarBinArray::from_vec(
40///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
41/// let offsets = buffer![0u32, 1, 3, 5].into_array();
42/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
43///
44/// let matches = list_contains(list_array.as_ref(), "b".into()).unwrap();
45/// let to_vec: Vec<bool> = matches.to_bool().unwrap().boolean_buffer().iter().collect();
46/// assert_eq!(to_vec, vec![false, true, false]);
47/// ```
48pub fn list_contains(array: &dyn Array, value: Scalar) -> VortexResult<ArrayRef> {
49    let DType::List(elem_dtype, _nullability) = array.dtype() else {
50        vortex_bail!("Array must be of List type");
51    };
52    if &**elem_dtype != value.dtype() {
53        vortex_bail!("Element type of ListArray does not match search value");
54    }
55
56    // If the list array is constant, we perform a single comparison.
57    if array.is_constant() && array.len() > 1 {
58        let contains = list_contains(&array.slice(0, 1)?, value)?;
59        return Ok(ConstantArray::new(contains.scalar_at(0)?, array.len()).into_array());
60    }
61
62    // Canonicalize to a list array.
63    // NOTE(ngates): we may wish to add elements and offsets accessors to the ListArrayTrait.
64    let list_array = array.to_list()?;
65
66    if value.is_null() {
67        return list_contains_null(&list_array);
68    }
69
70    let elems = list_array.elements();
71    let ends = list_array.offsets().to_primitive()?;
72
73    let rhs = ConstantArray::new(value, elems.len());
74    let matching_elements = compare(elems, rhs.as_ref(), Operator::Eq)?;
75    let matches = matching_elements.to_bool()?;
76
77    // Fast path: no elements match.
78    if let Some(pred) = matches.as_constant() {
79        if matches!(pred.as_bool().value(), None | Some(false)) {
80            // TODO(aduffy): how do we handle null?
81            return Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array());
82        }
83    }
84
85    match_each_integer_ptype!(ends.ptype(), |$T| {
86        Ok(reduce_with_ends(ends.as_slice::<$T>(), &matches.boolean_buffer(), list_array.validity().clone()))
87    })
88}
89
90/// Returns a `Bool` array with `true` for lists which contains NULL and `false` if not, or
91/// NULL if the list itself is null.
92fn list_contains_null(list_array: &ListArray) -> VortexResult<ArrayRef> {
93    let elems = list_array.elements();
94
95    // Check element validity. We need to intersect
96    match elems.validity_mask()? {
97        // No NULL elements
98        Mask::AllTrue(_) => match list_array.validity() {
99            Validity::NonNullable => {
100                Ok(ConstantArray::new::<bool>(false, list_array.len()).into_array())
101            }
102            Validity::AllValid => Ok(ConstantArray::new(
103                Scalar::bool(true, Nullability::Nullable),
104                list_array.len(),
105            )
106            .into_array()),
107            Validity::AllInvalid => Ok(ConstantArray::new(
108                Scalar::null(DType::Bool(Nullability::Nullable)),
109                list_array.len(),
110            )
111            .into_array()),
112            Validity::Array(list_mask) => {
113                // Create a new bool array with false, and the provided nulls
114                let buffer = BooleanBuffer::new_unset(list_array.len());
115                Ok(BoolArray::new(buffer, Validity::Array(list_mask.clone())).into_array())
116            }
117        },
118        // All null elements
119        Mask::AllFalse(_) => Ok(ConstantArray::new::<bool>(true, list_array.len()).into_array()),
120        Mask::Values(mask) => {
121            let nulls = invert(&mask.into_array())?.to_bool()?;
122            let ends = list_array.offsets().to_primitive()?;
123            match_each_integer_ptype!(ends.ptype(), |$T| {
124                Ok(reduce_with_ends(
125                    list_array.offsets().to_primitive()?.as_slice::<$T>(),
126                    &nulls.boolean_buffer(),
127                    list_array.validity().clone(),
128                ))
129            })
130        }
131    }
132}
133
134// Reduce each boolean values into a Mask that indicates which elements in the
135// ListArray contain the matching value.
136fn reduce_with_ends<T: NativePType + AsPrimitive<usize>>(
137    ends: &[T],
138    matches: &BooleanBuffer,
139    validity: Validity,
140) -> ArrayRef {
141    let mask: BooleanBuffer = ends
142        .windows(2)
143        .map(|window| {
144            let len = window[1].as_() - window[0].as_();
145            let mut set_bits = BitIndexIterator::new(matches.values(), window[0].as_(), len);
146            set_bits.next().is_some()
147        })
148        .collect();
149
150    BoolArray::new(mask, validity).into_array()
151}
152
153/// Returns a new array of `u64` representing the length of each list element.
154///
155/// ## Example
156///
157/// ```rust
158/// use vortex_array::arrays::{ListArray, VarBinArray};
159/// use vortex_array::{Array, IntoArray};
160/// use vortex_array::compute::{list_elem_len};
161/// use vortex_array::validity::Validity;
162/// use vortex_buffer::buffer;
163/// use vortex_dtype::DType;
164///
165/// let elements = VarBinArray::from_vec(
166///         vec!["a", "a", "b", "a", "c"], DType::Utf8(false.into())).into_array();
167/// let offsets = buffer![0u32, 1, 3, 5].into_array();
168/// let list_array = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();
169///
170/// let lens = list_elem_len(list_array.as_ref()).unwrap();
171/// assert_eq!(lens.scalar_at(0).unwrap(), 1u32.into());
172/// assert_eq!(lens.scalar_at(1).unwrap(), 2u32.into());
173/// assert_eq!(lens.scalar_at(2).unwrap(), 2u32.into());
174/// ```
175pub fn list_elem_len(array: &dyn Array) -> VortexResult<ArrayRef> {
176    if !matches!(array.dtype(), DType::List(..)) {
177        vortex_bail!("Array must be of list type");
178    }
179
180    // Short-circuit for constant list arrays.
181    if array.is_constant() && array.len() > 1 {
182        let elem_lens = list_elem_len(&array.slice(0, 1)?)?;
183        return Ok(ConstantArray::new(elem_lens.scalar_at(0)?, array.len()).into_array());
184    }
185
186    let list_array = array.to_list()?;
187    let offsets = list_array.offsets().to_primitive()?;
188    let lens_array = match_each_integer_ptype!(offsets.ptype(), |$T| {
189        element_lens(offsets.as_slice::<$T>()).into_array()
190    });
191
192    Ok(lens_array)
193}
194
195fn element_lens<T: NativePType>(values: &[T]) -> Buffer<T> {
196    values
197        .windows(2)
198        .map(|window| window[1] - window[0])
199        .collect()
200}
201
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use rstest::rstest;
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, ConstantArray, ConstantVTable, ListArray, VarBinArray};
    use crate::canonical::ToCanonical;
    use crate::compute::list_contains;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{ArrayRef, IntoArray};

    /// Builds a non-nullable list array of non-nullable UTF-8 strings.
    fn nonnull_strings(values: Vec<Vec<&str>>) -> ArrayRef {
        let element_dtype = Arc::new(DType::Utf8(Nullability::NonNullable));
        ListArray::from_iter_slow::<u64, _>(values, element_dtype).unwrap()
    }

    /// Builds a non-nullable list array whose UTF-8 elements may be null.
    fn null_strings(values: Vec<Vec<Option<&str>>>) -> ArrayRef {
        // Offsets are a leading zero followed by the running total of list lengths.
        let mut offsets = Vec::with_capacity(values.len() + 1);
        offsets.push(0u64);
        let mut total = 0u64;
        for list in &values {
            total += list.len() as u64;
            offsets.push(total);
        }
        let offsets = Buffer::from_iter(offsets).into_array();

        let flattened = values.into_iter().flatten().collect_vec();
        let elements =
            VarBinArray::from_iter(flattened, DType::Utf8(Nullability::Nullable)).into_array();

        ListArray::try_new(elements, offsets, Validity::NonNullable)
            .unwrap()
            .into_array()
    }

    /// Builds the expected result; `None` validity means a non-nullable array.
    fn bool_array(values: Vec<bool>, validity: Option<Vec<bool>>) -> BoolArray {
        let validity = validity.map_or(Validity::NonNullable, Validity::from_iter);
        BoolArray::new(values.into_iter().collect(), validity)
    }

    #[rstest]
    // Case 1: list(utf8)
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a", "b"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 2: list(utf8?) with NULL search scalar
    #[case(
        null_strings(vec![vec![], vec![Some("a"), None], vec![Some("a"), None, Some("b")]]),
        None,
        bool_array(vec![false, true, true], None)
    )]
    // Case 3: list(utf8) with all elements matching, but some empty lists
    #[case(
        nonnull_strings(vec![vec![], vec!["a"], vec!["a"]]),
        Some("a"),
        bool_array(vec![false, true, true], None)
    )]
    // Case 4: list(utf8) all lists empty.
    #[case(
        nonnull_strings(vec![vec![], vec![], vec![]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    // Case 5: list(utf8) no elements matching.
    #[case(
        nonnull_strings(vec![vec!["b"], vec![], vec!["b"]]),
        Some("a"),
        bool_array(vec![false, false, false], None)
    )]
    fn test_contains_nullable(
        #[case] list_array: ArrayRef,
        #[case] value: Option<&str>,
        #[case] expected: BoolArray,
    ) {
        // The search scalar must share the element dtype, including its nullability.
        let element_nullability = list_array.dtype().as_list_element().unwrap().nullability();
        let scalar = value.map_or_else(
            || Scalar::null(DType::Utf8(Nullability::Nullable)),
            |v| Scalar::utf8(v, element_nullability),
        );

        let actual = list_contains(&list_array, scalar)
            .expect("list_contains failed")
            .to_bool()
            .expect("to_bool failed");

        let actual_bits = actual.boolean_buffer().iter().collect_vec();
        let expected_bits = expected.boolean_buffer().iter().collect_vec();
        assert_eq!(actual_bits, expected_bits);
        assert_eq!(actual.validity(), expected.validity());
    }

    #[test]
    fn test_constant_list() {
        // Two copies of the constant list [1, 2, 3].
        let list_scalar = Scalar::list(
            Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable)),
            vec![1i32.into(), 2i32.into(), 3i32.into()],
            Nullability::NonNullable,
        );
        let list_array = ConstantArray::new(list_scalar, 2).into_array();

        let contains = list_contains(&list_array, 2i32.into()).unwrap();
        assert!(contains.is::<ConstantVTable>(), "Expected constant result");

        let values = contains
            .to_bool()
            .unwrap()
            .boolean_buffer()
            .iter()
            .collect_vec();
        assert_eq!(values, vec![true, true]);
    }
}