vortex_array/arrays/list/mod.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

mod compute;
mod serde;

use std::sync::Arc;

#[cfg(feature = "test-harness")]
use itertools::Itertools;
use num_traits::{AsPrimitive, NumCast, PrimInt};
use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure};
use vortex_scalar::Scalar;

use crate::arrays::PrimitiveVTable;
#[cfg(feature = "test-harness")]
use crate::builders::{ArrayBuilder, ListBuilder};
use crate::compute::{min_max, sub_scalar};
use crate::stats::{ArrayStats, StatsSetRef};
use crate::validity::Validity;
use crate::vtable::{
    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
    ValidityVTableFromValidityHelper,
};
use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};

vtable!(List);

impl VTable for ListVTable {
    type Array = ListArray;
    type Encoding = ListEncoding;

    type ArrayVTable = Self;
    type CanonicalVTable = Self;
    type OperationsVTable = Self;
    type ValidityVTable = ValidityVTableFromValidityHelper;
    type VisitorVTable = Self;
    type ComputeVTable = NotSupported;
    type EncodeVTable = NotSupported;
    type PipelineVTable = NotSupported;
    type SerdeVTable = Self;

    fn id(_encoding: &Self::Encoding) -> EncodingId {
        EncodingId::new_ref("vortex.list")
    }

    fn encoding(_array: &Self::Array) -> EncodingRef {
        EncodingRef::new_ref(ListEncoding.as_ref())
    }
}

/// A list array that stores variable-length lists of elements, similar to `Vec<Vec<T>>`.
///
/// This mirrors the Apache Arrow List array encoding and provides efficient storage
/// for nested data where each row contains a list of elements of the same type.
///
/// ## Data Layout
///
/// The list array uses an offset-based encoding:
/// - **Elements array**: A flat array containing all list elements concatenated together
/// - **Offsets array**: Integer array where `offsets[i]` is an (inclusive) start index into
///   the **elements** and `offsets[i+1]` is the (exclusive) stop index for the `i`th list.
/// - **Validity**: Optional mask indicating which lists are null
///
/// This layout allows for excellent cascading compression of the elements and offsets: similar
/// values are clustered together, and the offsets follow a predictable pattern with small deltas
/// between consecutive values.
///
/// ## Offset Semantics
///
/// - Offsets must be non-nullable integers (i32, i64, etc.)
/// - Offsets array has length `n+1` where `n` is the number of lists
/// - List `i` contains elements from `elements[offsets[i]..offsets[i+1]]`
/// - Offsets must be monotonically increasing
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::{ListArray, PrimitiveArray};
/// use vortex_array::validity::Validity;
/// use vortex_array::IntoArray;
/// use std::sync::Arc;
///
/// // Create a list array representing [[1, 2], [3, 4, 5], []]
/// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
/// let offsets = PrimitiveArray::from_iter([0u32, 2, 5, 5]); // 3 lists
///
/// let list_array = ListArray::try_new(
///     elements.into_array(),
///     offsets.into_array(),
///     Validity::NonNullable,
/// ).unwrap();
///
/// assert_eq!(list_array.len(), 3);
///
/// // Access individual lists
/// let first_list = list_array.elements_at(0);
/// assert_eq!(first_list.len(), 2); // [1, 2]
///
/// let third_list = list_array.elements_at(2);
/// assert!(third_list.is_empty()); // []
/// ```
#[derive(Clone, Debug)]
pub struct ListArray {
    dtype: DType,
    elements: ArrayRef,
    offsets: ArrayRef,
    validity: Validity,
    stats_set: ArrayStats,
}

#[derive(Clone, Debug)]
pub struct ListEncoding;

pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}

impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}

impl ListArray {
    pub fn new(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
        Self::try_new(elements, offsets, validity).vortex_expect("ListArray new")
    }

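    /// Constructs a new list array from `elements`, `offsets`, and `validity`, validating the
    /// offsets along the way.
    ///
    /// # Example
    ///
    /// A minimal sketch of the validation behaviour, reusing the constructors from the
    /// type-level example above: offsets that are not monotonically increasing are rejected.
    ///
    /// ```
    /// use vortex_array::arrays::{ListArray, PrimitiveArray};
    /// use vortex_array::validity::Validity;
    /// use vortex_array::IntoArray;
    ///
    /// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
    /// // Offsets must be monotonically increasing; [0, 5, 2] is not.
    /// let offsets = PrimitiveArray::from_iter([0u32, 5, 2]);
    ///
    /// let result = ListArray::try_new(
    ///     elements.into_array(),
    ///     offsets.into_array(),
    ///     Validity::NonNullable,
    /// );
    /// assert!(result.is_err());
    /// ```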
    pub fn try_new(
        elements: ArrayRef,
        offsets: ArrayRef,
        validity: Validity,
    ) -> VortexResult<Self> {
        let nullability = validity.nullability();

        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
            vortex_bail!(
                "Expected offsets to be a non-nullable integer type, got {:?}",
                offsets.dtype()
            );
        }

        if offsets.is_empty() {
            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
        }

        Self::validate(&elements, &offsets, &validity)?;

        Ok(Self {
            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
            elements,
            offsets,
            validity,
            stats_set: Default::default(),
        })
    }

    /// Returns the offset at the given index from the list array.
    ///
    /// Panics if the index is out of bounds.
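    ///
    /// # Example
    ///
    /// A minimal sketch reusing the constructors from the type-level example above:
    ///
    /// ```
    /// use vortex_array::arrays::{ListArray, PrimitiveArray};
    /// use vortex_array::validity::Validity;
    /// use vortex_array::IntoArray;
    ///
    /// // Offsets [0, 2, 5] describe two lists: [1, 2] and [3, 4, 5].
    /// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
    /// let offsets = PrimitiveArray::from_iter([0u32, 2, 5]);
    /// let list = ListArray::try_new(
    ///     elements.into_array(),
    ///     offsets.into_array(),
    ///     Validity::NonNullable,
    /// ).unwrap();
    ///
    /// assert_eq!(list.offset_at(0), 0);
    /// assert_eq!(list.offset_at(1), 2);
    /// // `index == len()` is allowed: it returns the final (exclusive) end offset.
    /// assert_eq!(list.offset_at(2), 5);
    /// ```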
    pub fn offset_at(&self, index: usize) -> usize {
        assert!(
            index <= self.len(),
            "Index {index} out of bounds 0..={}",
            self.len()
        );

        // Fast path: read the offset directly out of a primitive offsets array.
        self.offsets()
            .as_opt::<PrimitiveVTable>()
            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
            // Fallback: go through scalar_at for non-primitive offset encodings.
            .unwrap_or_else(|| {
                self.offsets()
                    .scalar_at(index)
                    .as_primitive()
                    .as_::<usize>()
                    .vortex_expect("index must fit in usize")
            })
    }

    /// Returns the elements of the list at the given index.
    pub fn elements_at(&self, index: usize) -> ArrayRef {
        let start = self.offset_at(index);
        let end = self.offset_at(index + 1);
        self.elements().slice(start, end)
    }

    /// Returns the elements of the list array that are referenced by the offsets array.
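    ///
    /// # Example
    ///
    /// A minimal sketch reusing the constructors from the type-level example above; offsets that
    /// do not start at zero leave some leading (and possibly trailing) elements unreferenced.
    ///
    /// ```
    /// use vortex_array::arrays::{ListArray, PrimitiveArray};
    /// use vortex_array::validity::Validity;
    /// use vortex_array::IntoArray;
    ///
    /// // Offsets [1, 3, 4] describe two lists: [20, 30] and [40].
    /// let elements = PrimitiveArray::from_iter([10i32, 20, 30, 40, 50]);
    /// let offsets = PrimitiveArray::from_iter([1u32, 3, 4]);
    /// let list = ListArray::try_new(
    ///     elements.into_array(),
    ///     offsets.into_array(),
    ///     Validity::NonNullable,
    /// ).unwrap();
    ///
    /// // Only elements[1..4] are referenced by the offsets.
    /// assert_eq!(list.sliced_elements().len(), 3);
    /// ```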
    pub fn sliced_elements(&self) -> ArrayRef {
        let start = self.offset_at(0);
        let end = self.offset_at(self.len());
        self.elements().slice(start, end)
    }

    /// Returns the offsets array.
    pub fn offsets(&self) -> &ArrayRef {
        &self.offsets
    }

    /// Returns the elements array.
    pub fn elements(&self) -> &ArrayRef {
        &self.elements
    }

    /// Creates a copy of this array whose offsets start at 0, dropping any elements not
    /// referenced by the offsets.
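    ///
    /// # Example
    ///
    /// A minimal sketch reusing the constructors from the type-level example above:
    ///
    /// ```
    /// use vortex_array::arrays::{ListArray, PrimitiveArray};
    /// use vortex_array::validity::Validity;
    /// use vortex_array::IntoArray;
    ///
    /// // Offsets [2, 4, 5] describe two lists, [3, 4] and [5], leaving elements[0..2] unreferenced.
    /// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
    /// let offsets = PrimitiveArray::from_iter([2u32, 4, 5]);
    /// let list = ListArray::try_new(
    ///     elements.into_array(),
    ///     offsets.into_array(),
    ///     Validity::NonNullable,
    /// ).unwrap();
    ///
    /// let reset = list.reset_offsets().unwrap();
    /// assert_eq!(reset.len(), 2);
    /// // Offsets now start at zero, and the unreferenced elements are gone.
    /// assert_eq!(reset.offset_at(0), 0);
    /// assert_eq!(reset.elements().len(), 3);
    /// ```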
    pub fn reset_offsets(&self) -> VortexResult<Self> {
        let elements = self.sliced_elements();
        let offsets = self.offsets();
        let first_offset = offsets.scalar_at(0);
        let adjusted_offsets = sub_scalar(offsets, first_offset)?;

        Self::try_new(elements, adjusted_offsets, self.validity.clone())
    }

    /// A list is valid if:
    /// - the offsets start at a valid index into the elements array
    /// - the offsets are sorted (non-strictly, so the final offset is at least the first)
    /// - the final offset does not exceed the length of the elements array, and is therefore
    ///   zero when the elements array is empty
    /// - the validity length, if present, is one less than the length of the offsets array
    fn validate(
        elements: &dyn Array,
        offsets: &dyn Array,
        validity: &Validity,
    ) -> VortexResult<()> {
        // Offsets must be of integer type, and cannot go lower than 0.
        vortex_ensure!(
            offsets.dtype().is_int() && !offsets.dtype().is_nullable(),
            "offsets have invalid type {}",
            offsets.dtype()
        );

        // We can safely unwrap the DType as primitive now
        let offsets_ptype = offsets.dtype().as_ptype();

        // Offsets must be sorted (but not strictly sorted, zero-length lists are allowed)
        if let Some(is_sorted) = offsets.statistics().compute_is_sorted() {
            vortex_ensure!(is_sorted, "offsets must be sorted");
        } else {
            vortex_bail!("offsets must report is_sorted statistic");
        }

        // Validate that offsets min is non-negative, and max does not exceed the length of
        // the elements array.
        if let Some(min_max) = min_max(offsets)? {
            match_each_integer_ptype!(offsets_ptype, |P| {
                let max_offset = <P as NumCast>::from(elements.len()).unwrap_or(P::MAX);

                #[allow(clippy::absurd_extreme_comparisons, unused_comparisons)]
                {
                    if let Some(min) = min_max.min.as_primitive().as_::<P>() {
                        vortex_ensure!(
                            min >= 0 && min <= max_offset,
                            "offsets minimum {min} outside valid range [0, {max_offset}]"
                        );
                    }

                    if let Some(max) = min_max.max.as_primitive().as_::<P>() {
                        vortex_ensure!(
                            max >= 0 && max <= max_offset,
                            "offsets maximum {max} outside valid range [0, {max_offset}]"
                        )
                    }
                }
            })
        } else {
            // TODO(aduffy): fallback to slower validation pathway?
            vortex_bail!(
                "offsets array with encoding {} must support min_max compute function",
                offsets.encoding_id()
            );
        };

        // If a validity array is present, it must be the same length as the ListArray
        if let Some(validity_len) = validity.maybe_len() {
            vortex_ensure!(
                validity_len == offsets.len() - 1,
                "validity with size {validity_len} does not match array size {}",
                offsets.len() - 1
            );
        }

        Ok(())
    }
}

impl ArrayVTable<ListVTable> for ListVTable {
    fn len(array: &ListArray) -> usize {
        array.offsets.len().saturating_sub(1)
    }

    fn dtype(array: &ListArray) -> &DType {
        &array.dtype
    }

    fn stats(array: &ListArray) -> StatsSetRef<'_> {
        array.stats_set.to_ref(array.as_ref())
    }
}

impl OperationsVTable<ListVTable> for ListVTable {
    fn slice(array: &ListArray, start: usize, stop: usize) -> ArrayRef {
        // Slicing lists `start..stop` needs offsets `start..=stop`, hence `stop + 1`; the
        // elements array is left untouched and stays indexed by the sliced offsets.
        ListArray::new(
            array.elements().clone(),
            array.offsets().slice(start, stop + 1),
            array.validity().slice(start, stop),
        )
        .into_array()
    }

    fn scalar_at(array: &ListArray, index: usize) -> Scalar {
        // Materialize the list at `index` as a list scalar of its element scalars.
        let elem = array.elements_at(index);
        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).collect();

        Scalar::list(
            Arc::new(elem.dtype().clone()),
            scalars,
            array.dtype().nullability(),
        )
    }
}

impl CanonicalVTable<ListVTable> for ListVTable {
    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
        Ok(Canonical::List(array.clone()))
    }
}

impl ValidityHelper for ListArray {
    fn validity(&self) -> &Validity {
        &self.validity
    }
}

#[cfg(feature = "test-harness")]
impl ListArray {
    /// A convenience method to create a list array from an iterator of iterators.
    ///
    /// This method is slow, however, since each element is first converted to a scalar and then
    /// appended to the array.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            iter.size_hint().0,
        );

        for v in iter {
            let elem = Scalar::list(
                dtype.clone(),
                v.into_iter().map(|x| x.into()).collect_vec(),
                dtype.nullability(),
            );
            builder.append_value(elem.as_list())?
        }
        Ok(builder.finish())
    }

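    /// Like `from_iter_slow`, but takes an iterator of `Option`s and appends a null list for
    /// each `None`.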
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            iter.size_hint().0,
        );

        for v in iter {
            if let Some(v) = v {
                let elem = Scalar::list(
                    dtype.clone(),
                    v.into_iter().map(|x| x.into()).collect_vec(),
                    dtype.nullability(),
                );
                builder.append_value(elem.as_list())?
            } else {
                builder.append_null()
            }
        }
        Ok(builder.finish())
    }
}

#[cfg(test)]
mod tests;