vortex_array/arrays/list/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod compute;
5mod serde;
6
7use std::sync::Arc;
8
9#[cfg(feature = "test-harness")]
10use itertools::Itertools;
11use num_traits::{AsPrimitive, NumCast, PrimInt};
12use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure};
14use vortex_scalar::Scalar;
15
16use crate::arrays::PrimitiveVTable;
17#[cfg(feature = "test-harness")]
18use crate::builders::{ArrayBuilder, ListBuilder};
19use crate::compute::{min_max, sub_scalar};
20use crate::stats::{ArrayStats, StatsSetRef};
21use crate::validity::Validity;
22use crate::vtable::{
23    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
24    ValidityVTableFromValidityHelper,
25};
26use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
27
28vtable!(List);
29
30impl VTable for ListVTable {
31    type Array = ListArray;
32    type Encoding = ListEncoding;
33
34    type ArrayVTable = Self;
35    type CanonicalVTable = Self;
36    type OperationsVTable = Self;
37    type ValidityVTable = ValidityVTableFromValidityHelper;
38    type VisitorVTable = Self;
39    type ComputeVTable = NotSupported;
40    type EncodeVTable = NotSupported;
41    type SerdeVTable = Self;
42
43    fn id(_encoding: &Self::Encoding) -> EncodingId {
44        EncodingId::new_ref("vortex.list")
45    }
46
47    fn encoding(_array: &Self::Array) -> EncodingRef {
48        EncodingRef::new_ref(ListEncoding.as_ref())
49    }
50}
51
52/// A list array that stores variable-length lists of elements, similar to `Vec<Vec<T>>`.
53///
54/// This mirrors the Apache Arrow List array encoding and provides efficient storage
55/// for nested data where each row contains a list of elements of the same type.
56///
57/// ## Data Layout
58///
59/// The list array uses an offset-based encoding:
60/// - **Elements array**: A flat array containing all list elements concatenated together
61/// - **Offsets array**: Integer array where `offsets[i]` is an (inclusive) start index into
62///   the **elements** and `offsets[i+1]` is the (exclusive) stop index for the `i`th list.
63/// - **Validity**: Optional mask indicating which lists are null
64///
65/// This allows for excellent cascading compression of the elements and offsets, as similar values
66/// are clustered together and the offsets have a predictable pattern and small deltas between
67/// consecutive elements.
68///
69/// ## Offset Semantics
70///
71/// - Offsets must be non-nullable integers (i32, i64, etc.)
72/// - Offsets array has length `n+1` where `n` is the number of lists
73/// - List `i` contains elements from `elements[offsets[i]..offsets[i+1]]`  
74/// - Offsets must be monotonically increasing
75///
76/// # Examples
77///
78/// ```
79/// use vortex_array::arrays::{ListArray, PrimitiveArray};
80/// use vortex_array::validity::Validity;
81/// use vortex_array::IntoArray;
82/// use std::sync::Arc;
83///
84/// // Create a list array representing [[1, 2], [3, 4, 5], []]
85/// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
86/// let offsets = PrimitiveArray::from_iter([0u32, 2, 5, 5]); // 3 lists
87///
88/// let list_array = ListArray::try_new(
89///     elements.into_array(),
90///     offsets.into_array(),
91///     Validity::NonNullable,
92/// ).unwrap();
93///
94/// assert_eq!(list_array.len(), 3);
95///
96/// // Access individual lists
97/// let first_list = list_array.elements_at(0);
98/// assert_eq!(first_list.len(), 2); // [1, 2]
99///
100/// let third_list = list_array.elements_at(2);
101/// assert!(third_list.is_empty()); // []
102/// ```
103#[derive(Clone, Debug)]
104pub struct ListArray {
105    dtype: DType,
106    elements: ArrayRef,
107    offsets: ArrayRef,
108    validity: Validity,
109    stats_set: ArrayStats,
110}
111
/// Field-less marker type standing for the list encoding; it is the `Encoding` associated type
/// of `ListVTable`.
#[derive(Clone, Debug)]
pub struct ListEncoding;
114
115pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
116
117impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
118
// A list array is valid when all of the following hold:
// - the offsets are a non-nullable integer type and contain at least one entry
// - the offsets are sorted in non-decreasing order (so final_offset >= start_offset, and
//   repeated offsets describe empty lists)
// - every offset lies in `[0, elements.len()]`; in particular, when `elements` is empty all
//   offsets must be zero
// - the validity, when it has a length, has length `offsets.len() - 1`
126
127impl ListArray {
128    fn validate(
129        elements: &dyn Array,
130        offsets: &dyn Array,
131        validity: &Validity,
132    ) -> VortexResult<()> {
133        // Offsets must be of integer type, and cannot go lower than 0.
134        vortex_ensure!(
135            offsets.dtype().is_int() && !offsets.dtype().is_nullable(),
136            "offsets have invalid type {}",
137            offsets.dtype()
138        );
139
140        // We can safely unwrap the DType as primitive now
141        let offsets_ptype = offsets.dtype().as_ptype();
142
143        // Offsets must be sorted (but not strictly sorted, zero-length lists are allowed)
144        if let Some(is_sorted) = offsets.statistics().compute_is_sorted() {
145            vortex_ensure!(is_sorted, "offsets must be sorted");
146        } else {
147            vortex_bail!("offsets must report is_sorted statistic");
148        }
149
150        // Validate that offsets min is non-negative, and max does not exceed the length of
151        // the elements array.
152        if let Some(min_max) = min_max(offsets)? {
153            match_each_integer_ptype!(offsets_ptype, |P| {
154                let max_offset = <P as NumCast>::from(elements.len()).unwrap_or(P::MAX);
155
156                #[allow(clippy::absurd_extreme_comparisons, unused_comparisons)]
157                {
158                    if let Some(min) = min_max.min.as_primitive().as_::<P>() {
159                        vortex_ensure!(
160                            min >= 0 && min <= max_offset,
161                            "offsets minimum {min} outside valid range [0, {max_offset}]"
162                        );
163                    }
164
165                    if let Some(max) = min_max.max.as_primitive().as_::<P>() {
166                        vortex_ensure!(
167                            max >= 0 && max <= max_offset,
168                            "offsets maximum {max} outside valid range [0, {max_offset}]"
169                        )
170                    }
171                }
172            })
173        } else {
174            // TODO(aduffy): fallback to slower validation pathway?
175            vortex_bail!(
176                "offsets array with encoding {} must support min_max compute function",
177                offsets.encoding_id()
178            );
179        };
180
181        // If a validity array is present, it must be the same length as the ListArray
182        if let Some(validity_len) = validity.maybe_len() {
183            vortex_ensure!(
184                validity_len == offsets.len() - 1,
185                "validity with size {validity_len} does not match array size {}",
186                offsets.len() - 1
187            );
188        }
189
190        Ok(())
191    }
192}
193
194impl ListArray {
195    pub fn new(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
196        Self::try_new(elements, offsets, validity).vortex_expect("ListArray new")
197    }
198
199    pub fn try_new(
200        elements: ArrayRef,
201        offsets: ArrayRef,
202        validity: Validity,
203    ) -> VortexResult<Self> {
204        let nullability = validity.nullability();
205
206        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
207            vortex_bail!(
208                "Expected offsets to be an non-nullable integer type, got {:?}",
209                offsets.dtype()
210            );
211        }
212
213        if offsets.is_empty() {
214            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
215        }
216
217        Self::validate(&elements, &offsets, &validity)?;
218
219        Ok(Self {
220            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
221            elements,
222            offsets,
223            validity,
224            stats_set: Default::default(),
225        })
226    }
227
228    /// Returns the offset at the given index from the list array.
229    ///
230    /// Panics if the index is out of bounds.
231    pub fn offset_at(&self, index: usize) -> usize {
232        assert!(
233            index <= self.len(),
234            "Index {index} out of bounds 0..={}",
235            self.len()
236        );
237
238        self.offsets()
239            .as_opt::<PrimitiveVTable>()
240            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
241            .unwrap_or_else(|| {
242                self.offsets()
243                    .scalar_at(index)
244                    .as_primitive()
245                    .as_::<usize>()
246                    .vortex_expect("index must fit in usize")
247            })
248    }
249
250    /// Returns the elements at the given index from the list array.
251    pub fn elements_at(&self, index: usize) -> ArrayRef {
252        let start = self.offset_at(index);
253        let end = self.offset_at(index + 1);
254        self.elements().slice(start, end)
255    }
256
257    /// Returns elements of the list array referenced by the offsets array
258    pub fn sliced_elements(&self) -> ArrayRef {
259        let start = self.offset_at(0);
260        let end = self.offset_at(self.len());
261        self.elements().slice(start, end)
262    }
263
264    /// Returns the offsets array.
265    pub fn offsets(&self) -> &ArrayRef {
266        &self.offsets
267    }
268
269    /// Returns the elements array.
270    pub fn elements(&self) -> &ArrayRef {
271        &self.elements
272    }
273
274    /// Create a copy of this array by adjusting offsets to start at 0 and removing elements not referenced by the offsets
275    pub fn reset_offsets(&self) -> VortexResult<Self> {
276        let elements = self.sliced_elements();
277        let offsets = self.offsets();
278        let first_offset = offsets.scalar_at(0);
279        let adjusted_offsets = sub_scalar(offsets, first_offset)?;
280
281        Self::try_new(elements, adjusted_offsets, self.validity.clone())
282    }
283}
284
285impl ArrayVTable<ListVTable> for ListVTable {
286    fn len(array: &ListArray) -> usize {
287        array.offsets.len().saturating_sub(1)
288    }
289
290    fn dtype(array: &ListArray) -> &DType {
291        &array.dtype
292    }
293
294    fn stats(array: &ListArray) -> StatsSetRef<'_> {
295        array.stats_set.to_ref(array.as_ref())
296    }
297}
298
299impl OperationsVTable<ListVTable> for ListVTable {
300    fn slice(array: &ListArray, start: usize, stop: usize) -> ArrayRef {
301        ListArray::new(
302            array.elements().clone(),
303            array.offsets().slice(start, stop + 1),
304            array.validity().slice(start, stop),
305        )
306        .into_array()
307    }
308
309    fn scalar_at(array: &ListArray, index: usize) -> Scalar {
310        let elem = array.elements_at(index);
311        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).collect();
312
313        Scalar::list(
314            Arc::new(elem.dtype().clone()),
315            scalars,
316            array.dtype().nullability(),
317        )
318    }
319}
320
321impl CanonicalVTable<ListVTable> for ListVTable {
322    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
323        Ok(Canonical::List(array.clone()))
324    }
325}
326
327impl ValidityHelper for ListArray {
328    fn validity(&self) -> &Validity {
329        &self.validity
330    }
331}
332
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Convenience constructor that builds a non-nullable list array from an iterator of
    /// iterators.
    ///
    /// `dtype` is the dtype of the list elements and `O` the offset type. This is slow: every
    /// element is first turned into a scalar and then appended through a builder.
    ///
    /// # Errors
    ///
    /// Returns an error if the builder rejects one of the lists.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // The iterator's lower size bound is used as the builder's initial capacity.
        let (capacity_hint, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            capacity_hint,
        );

        for items in lists {
            let values: Vec<Scalar> = items.into_iter().map(Into::into).collect_vec();
            let scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
            builder.append_value(scalar.as_list())?;
        }
        Ok(builder.finish())
    }

    /// Like [`ListArray::from_iter_slow`], but builds a nullable list array: `None` items become
    /// null lists.
    ///
    /// # Errors
    ///
    /// Returns an error if the builder rejects one of the lists.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        let (capacity_hint, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            capacity_hint,
        );

        for maybe_items in lists {
            match maybe_items {
                Some(items) => {
                    let values: Vec<Scalar> = items.into_iter().map(Into::into).collect_vec();
                    let scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
                    builder.append_value(scalar.as_list())?;
                }
                None => builder.append_null(),
            }
        }
        Ok(builder.finish())
    }
}
394
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{ListVTable, PrimitiveArray};
    use crate::builders::{ArrayBuilder, ListBuilder};
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// A lone `[0]` offset with no elements describes an array of zero lists.
    #[test]
    fn test_empty_list_array() {
        let list = ListArray::try_new(
            PrimitiveArray::empty::<u32>(Nullability::NonNullable).into_array(),
            PrimitiveArray::from_iter([0]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        assert_eq!(0, list.len());
    }

    /// Each row read back through `scalar_at` matches the list described by the offsets.
    #[test]
    fn test_simple_list_array() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            PrimitiveArray::from_iter([0, 2, 4, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        // Offsets [0, 2, 4, 5] split the elements into [1, 2], [3, 4] and [5].
        let expected_rows: [Vec<Scalar>; 3] = [
            vec![1.into(), 2.into()],
            vec![3.into(), 4.into()],
            vec![5.into()],
        ];
        for (row, values) in expected_rows.into_iter().enumerate() {
            assert_eq!(
                Scalar::list(Arc::new(I32.into()), values, Nullability::Nullable),
                list.scalar_at(row)
            );
        }
    }

    /// Building through `from_iter_slow` gives the same rows as assembling the parts by hand.
    #[test]
    fn test_simple_list_array_from_iter() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            PrimitiveArray::from_iter([0, 2, 3]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        for row in 0..2 {
            assert_eq!(list.scalar_at(row), list_from_iter.scalar_at(row));
        }
    }

    /// Filtering a list array (including one with a null element) succeeds.
    #[test]
    fn test_simple_list_filter() {
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)])
                .into_array(),
            PrimitiveArray::from_iter([0, 2, 4, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let mask = Mask::from(BooleanBuffer::from(vec![false, true, true]));
        let filtered = filter(&list, &mask);

        assert!(filtered.is_ok());
    }

    /// After slicing, `reset_offsets` rebases the offsets at zero and drops unreferenced
    /// elements.
    #[test]
    fn test_offset_to_0() {
        let mut builder =
            ListBuilder::<u32>::with_capacity(Arc::new(I32.into()), Nullability::NonNullable, 5);

        // Five lists of three consecutive integers each: [1, 2, 3] through [13, 14, 15].
        for first in [1i32, 4, 7, 10, 13] {
            let row = Scalar::list(
                Arc::new(I32.into()),
                (first..first + 3).map(Scalar::from).collect::<Vec<Scalar>>(),
                Nullability::NonNullable,
            );
            builder.append_value(row.as_list()).vortex_unwrap();
        }

        // Keep only the third and fourth lists, then rebase.
        let sliced = builder.finish().slice(2, 4);
        let list = sliced.as_::<ListVTable>().reset_offsets().unwrap();

        assert_eq!(list.len(), 2);
        assert_eq!(list.offsets().len(), 3);
        assert_eq!(list.elements().len(), 6);
        assert_eq!(list.offsets().scalar_at(0), 0u32.into());
    }
}
553}