vortex_array/arrays/list/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod compute;
5mod serde;
6
7use std::sync::Arc;
8
9use itertools::Itertools;
10use num_traits::{AsPrimitive, PrimInt};
11use vortex_dtype::{DType, NativePType, match_each_native_ptype};
12use vortex_error::{VortexResult, VortexUnwrap, vortex_bail};
13use vortex_scalar::Scalar;
14
15use crate::arrays::PrimitiveVTable;
16#[cfg(feature = "test-harness")]
17use crate::builders::{ArrayBuilder, ListBuilder};
18use crate::compute::sub_scalar;
19use crate::stats::{ArrayStats, StatsSetRef};
20use crate::validity::Validity;
21use crate::vtable::{
22    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
23    ValidityVTableFromValidityHelper,
24};
25use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
26
27vtable!(List);
28
29impl VTable for ListVTable {
30    type Array = ListArray;
31    type Encoding = ListEncoding;
32
33    type ArrayVTable = Self;
34    type CanonicalVTable = Self;
35    type OperationsVTable = Self;
36    type ValidityVTable = ValidityVTableFromValidityHelper;
37    type VisitorVTable = Self;
38    type ComputeVTable = NotSupported;
39    type EncodeVTable = NotSupported;
40    type SerdeVTable = Self;
41
42    fn id(_encoding: &Self::Encoding) -> EncodingId {
43        EncodingId::new_ref("vortex.list")
44    }
45
46    fn encoding(_array: &Self::Array) -> EncodingRef {
47        EncodingRef::new_ref(ListEncoding.as_ref())
48    }
49}
50
51/// A list array that stores variable-length lists of elements, similar to `Vec<Vec<T>>`.
52///
53/// This mirrors the Apache Arrow List array encoding and provides efficient storage
54/// for nested data where each row contains a list of elements of the same type.
55///
56/// ## Data Layout
57///
58/// The list array uses an offset-based encoding:
59/// - **Elements array**: A flat array containing all list elements concatenated together
60/// - **Offsets array**: Integer array where `offsets[i]` is an (inclusive) start index into
61///   the **elements** and `offsets[i+1]` is the (exclusive) stop index for the `i`th list.
62/// - **Validity**: Optional mask indicating which lists are null
63///
64/// This allows for excellent cascading compression of the elements and offsets, as similar values
65/// are clustered together and the offsets have a predictable pattern and small deltas between
66/// consecutive elements.
67///
68/// ## Offset Semantics
69///
70/// - Offsets must be non-nullable integers (i32, i64, etc.)
71/// - Offsets array has length `n+1` where `n` is the number of lists
72/// - List `i` contains elements from `elements[offsets[i]..offsets[i+1]]`  
73/// - Offsets must be monotonically increasing
74///
75/// # Examples
76///
77/// ```
78/// use vortex_array::arrays::{ListArray, PrimitiveArray};
79/// use vortex_array::validity::Validity;
80/// use vortex_array::IntoArray;
81/// use std::sync::Arc;
82///
83/// // Create a list array representing [[1, 2], [3, 4, 5], []]
84/// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
85/// let offsets = PrimitiveArray::from_iter([0u32, 2, 5, 5]); // 3 lists
86///
87/// let list_array = ListArray::try_new(
88///     elements.into_array(),
89///     offsets.into_array(),
90///     Validity::NonNullable,
91/// ).unwrap();
92///
93/// assert_eq!(list_array.len(), 3);
94///
95/// // Access individual lists
96/// let first_list = list_array.elements_at(0).unwrap();
97/// assert_eq!(first_list.len(), 2); // [1, 2]
98///
99/// let third_list = list_array.elements_at(2).unwrap();
100/// assert!(third_list.is_empty()); // []
101/// ```
102#[derive(Clone, Debug)]
103pub struct ListArray {
104    dtype: DType,
105    elements: ArrayRef,
106    offsets: ArrayRef,
107    validity: Validity,
108    stats_set: ArrayStats,
109}
110
/// Field-less marker type for the list encoding; it is the `Encoding` associated type of
/// `ListVTable`.
#[derive(Debug, Clone)]
pub struct ListEncoding;
113
114pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
115
116impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
117
// A list array is valid if:
// - the first offset points to a position in elements
// - the offsets are sorted
// - the final offset points to a position in the elements array, and is zero
//   if elements is empty
// - final_offset >= start_offset
// - the length of the validity is one less than the length of the offsets array
125
126impl ListArray {
127    pub fn try_new(
128        elements: ArrayRef,
129        offsets: ArrayRef,
130        validity: Validity,
131    ) -> VortexResult<Self> {
132        let nullability = validity.nullability();
133
134        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
135            vortex_bail!(
136                "Expected offsets to be an non-nullable integer type, got {:?}",
137                offsets.dtype()
138            );
139        }
140
141        if offsets.is_empty() {
142            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
143        }
144
145        Ok(Self {
146            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
147            elements,
148            offsets,
149            validity,
150            stats_set: Default::default(),
151        })
152    }
153
154    /// Returns the offset at the given index from the list array.
155    ///
156    /// Panics if the index is out of bounds.
157    pub fn offset_at(&self, index: usize) -> usize {
158        assert!(
159            index <= self.len(),
160            "Index {index} out of bounds 0..={}",
161            self.len()
162        );
163
164        self.offsets()
165            .as_opt::<PrimitiveVTable>()
166            .map(|p| {
167                Ok(match_each_native_ptype!(p.ptype(), |P| {
168                    p.as_slice::<P>()[index].as_()
169                }))
170            })
171            .unwrap_or_else(|| {
172                self.offsets()
173                    .scalar_at(index)
174                    .and_then(|s| usize::try_from(&s))
175            })
176            .vortex_unwrap()
177    }
178
179    /// Returns the elements at the given index from the list array.
180    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
181        let start = self.offset_at(index);
182        let end = self.offset_at(index + 1);
183        self.elements().slice(start, end)
184    }
185
186    /// Returns elements of the list array referenced by the offsets array
187    pub fn sliced_elements(&self) -> VortexResult<ArrayRef> {
188        let start = self.offset_at(0);
189        let end = self.offset_at(self.len());
190        self.elements().slice(start, end)
191    }
192
193    /// Returns the offsets array.
194    pub fn offsets(&self) -> &ArrayRef {
195        &self.offsets
196    }
197
198    /// Returns the elements array.
199    pub fn elements(&self) -> &ArrayRef {
200        &self.elements
201    }
202
203    /// Create a copy of this array by adjusting offsets to start at 0 and removing elements not referenced by the offsets
204    pub fn reset_offsets(&self) -> VortexResult<Self> {
205        let elements = self.sliced_elements()?;
206        let offsets = self.offsets();
207        let first_offset = offsets.scalar_at(0)?;
208        let adjusted_offsets = sub_scalar(offsets, first_offset)?;
209
210        Self::try_new(elements, adjusted_offsets, self.validity.clone())
211    }
212}
213
214impl ArrayVTable<ListVTable> for ListVTable {
215    fn len(array: &ListArray) -> usize {
216        array.offsets.len().saturating_sub(1)
217    }
218
219    fn dtype(array: &ListArray) -> &DType {
220        &array.dtype
221    }
222
223    fn stats(array: &ListArray) -> StatsSetRef<'_> {
224        array.stats_set.to_ref(array.as_ref())
225    }
226}
227
228impl OperationsVTable<ListVTable> for ListVTable {
229    fn slice(array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
230        Ok(ListArray::try_new(
231            array.elements().clone(),
232            array.offsets().slice(start, stop + 1)?,
233            array.validity().slice(start, stop)?,
234        )?
235        .into_array())
236    }
237
238    fn scalar_at(array: &ListArray, index: usize) -> VortexResult<Scalar> {
239        let elem = array.elements_at(index)?;
240        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).try_collect()?;
241
242        Ok(Scalar::list(
243            Arc::new(elem.dtype().clone()),
244            scalars,
245            array.dtype().nullability(),
246        ))
247    }
248}
249
250impl CanonicalVTable<ListVTable> for ListVTable {
251    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
252        Ok(Canonical::List(array.clone()))
253    }
254}
255
256impl ValidityHelper for ListArray {
257    fn validity(&self) -> &Validity {
258        &self.validity
259    }
260}
261
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Convenience constructor that builds a non-nullable list array from an iterator of
    /// iterators, where `dtype` is the dtype of the list elements and `O` the offset type.
    ///
    /// This is slow: every element is first turned into a scalar, each inner iterator is
    /// wrapped into a list scalar, and only then appended to the builder.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Use the iterator's lower size bound as the initial capacity.
        let (expected_lists, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            expected_lists,
        );

        for items in lists {
            let values = items.into_iter().map(Into::into).collect::<Vec<Scalar>>();
            let list_scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
            builder.append_value(list_scalar.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Like [`Self::from_iter_slow`], but builds a nullable list array: a `None` item appends
    /// a null list, while `Some(items)` appends the list of those items.
    ///
    /// This is equally slow, since every element goes through a scalar first.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Use the iterator's lower size bound as the initial capacity.
        let (expected_lists, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            expected_lists,
        );

        for maybe_items in lists {
            match maybe_items {
                Some(items) => {
                    let values = items.into_iter().map(Into::into).collect::<Vec<Scalar>>();
                    let list_scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
                    builder.append_value(list_scalar.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }
}
323
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::list::ListArray;
    use crate::arrays::{ListVTable, PrimitiveArray};
    use crate::builders::{ArrayBuilder, ListBuilder};
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// Builds a list scalar of `i32` elements with the given list nullability.
    fn i32_list(values: impl IntoIterator<Item = i32>, nullability: Nullability) -> Scalar {
        Scalar::list(
            Arc::new(I32.into()),
            values.into_iter().map(Scalar::from).collect(),
            nullability,
        )
    }

    /// A single `[0]` offset over no elements yields an array with zero lists.
    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable).into_array();
        let offsets = PrimitiveArray::from_iter([0]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid).unwrap();

        assert_eq!(0, list.len());
    }

    /// `scalar_at` returns each list as a nullable list scalar holding the expected elements.
    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid).unwrap();

        assert_eq!(
            i32_list([1, 2], Nullability::Nullable),
            list.scalar_at(0).unwrap()
        );
        assert_eq!(
            i32_list([3, 4], Nullability::Nullable),
            list.scalar_at(1).unwrap()
        );
        assert_eq!(
            i32_list([5], Nullability::Nullable),
            list.scalar_at(2).unwrap()
        );
    }

    /// An array built by `from_iter_slow` matches one assembled by hand from elements/offsets.
    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 3]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        for index in 0..2 {
            assert_eq!(
                list.scalar_at(index).unwrap(),
                list_from_iter.scalar_at(index).unwrap()
            );
        }
    }

    /// Filtering a list array whose elements contain a null succeeds.
    #[test]
    fn test_simple_list_filter() {
        let elements =
            PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)])
                .into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array();

        let mask = Mask::from(BooleanBuffer::from(vec![false, true, true]));
        let filtered = filter(&list, &mask);

        assert!(filtered.is_ok())
    }

    /// After slicing out the middle lists, `reset_offsets` rebases the offsets to start at 0 and
    /// keeps only the referenced elements.
    #[test]
    fn test_offset_to_0() {
        let mut builder =
            ListBuilder::<u32>::with_capacity(Arc::new(I32.into()), Nullability::NonNullable, 5);

        // Five lists of three consecutive values each: [1, 2, 3], [4, 5, 6], ..., [13, 14, 15].
        for first in [1, 4, 7, 10, 13] {
            builder
                .append_value(i32_list(first..first + 3, Nullability::NonNullable).as_list())
                .vortex_unwrap();
        }

        let sliced = builder.finish().slice(2, 4).vortex_unwrap();
        let list = sliced.as_::<ListVTable>().reset_offsets().unwrap();

        assert_eq!(list.len(), 2);
        assert_eq!(list.offsets().len(), 3);
        assert_eq!(list.elements().len(), 6);
        assert_eq!(list.offsets().scalar_at(0).unwrap(), 0u32.into());
    }
}