vortex_array/arrays/list/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod compute;
5mod serde;
6
7use std::ops::Range;
8use std::sync::Arc;
9
10#[cfg(feature = "test-harness")]
11use itertools::Itertools;
12use num_traits::{AsPrimitive, PrimInt};
13use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure};
15use vortex_scalar::Scalar;
16
17use crate::arrays::PrimitiveVTable;
18#[cfg(feature = "test-harness")]
19use crate::builders::{ArrayBuilder, ListBuilder};
20use crate::compute::{min_max, sub_scalar};
21use crate::stats::{ArrayStats, StatsSetRef};
22use crate::validity::Validity;
23use crate::vtable::{
24    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
25    ValidityVTableFromValidityHelper,
26};
27use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
28
29vtable!(List);
30
31impl VTable for ListVTable {
32    type Array = ListArray;
33    type Encoding = ListEncoding;
34
35    type ArrayVTable = Self;
36    type CanonicalVTable = Self;
37    type OperationsVTable = Self;
38    type ValidityVTable = ValidityVTableFromValidityHelper;
39    type VisitorVTable = Self;
40    type ComputeVTable = NotSupported;
41    type EncodeVTable = NotSupported;
42    type PipelineVTable = NotSupported;
43    type SerdeVTable = Self;
44
45    fn id(_encoding: &Self::Encoding) -> EncodingId {
46        EncodingId::new_ref("vortex.list")
47    }
48
49    fn encoding(_array: &Self::Array) -> EncodingRef {
50        EncodingRef::new_ref(ListEncoding.as_ref())
51    }
52}
53
/// A list array that stores variable-length lists of elements, similar to `Vec<Vec<T>>`.
///
/// This mirrors the Apache Arrow List array encoding and provides efficient storage
/// for nested data where each row contains a list of elements of the same type.
///
/// ## Data Layout
///
/// The list array uses an offset-based encoding:
/// - **Elements array**: A flat array containing all list elements concatenated together
/// - **Offsets array**: Integer array where `offsets[i]` is an (inclusive) start index into
///   the **elements** and `offsets[i+1]` is the (exclusive) stop index for the `i`th list.
/// - **Validity**: Optional mask indicating which lists are null
///
/// This allows for excellent cascading compression of the elements and offsets, as similar values
/// are clustered together and the offsets have a predictable pattern and small deltas between
/// consecutive elements.
///
/// ## Offset Semantics
///
/// - Offsets must be non-nullable integers (i32, i64, etc.)
/// - Offsets array has length `n+1` where `n` is the number of lists
/// - List `i` contains elements from `elements[offsets[i]..offsets[i+1]]`
/// - Offsets must be sorted in non-decreasing order; equal consecutive offsets denote
///   zero-length lists
/// - Offsets must be non-negative, and the last offset must not exceed `elements.len()`
/// - A validity that has a length must have exactly `n` entries
///
/// # Examples
///
/// ```
/// use vortex_array::arrays::{ListArray, PrimitiveArray};
/// use vortex_array::validity::Validity;
/// use vortex_array::IntoArray;
/// use std::sync::Arc;
///
/// // Create a list array representing [[1, 2], [3, 4, 5], []]
/// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
/// let offsets = PrimitiveArray::from_iter([0u32, 2, 5, 5]); // 3 lists
///
/// let list_array = ListArray::try_new(
///     elements.into_array(),
///     offsets.into_array(),
///     Validity::NonNullable,
/// ).unwrap();
///
/// assert_eq!(list_array.len(), 3);
///
/// // Access individual lists
/// let first_list = list_array.elements_at(0);
/// assert_eq!(first_list.len(), 2); // [1, 2]
///
/// let third_list = list_array.elements_at(2);
/// assert!(third_list.is_empty()); // []
/// ```
#[derive(Clone, Debug)]
pub struct ListArray {
    /// `DType::List` of the elements' dtype, with the nullability taken from `validity`.
    dtype: DType,
    /// Flat array holding the elements of every list back to back.
    elements: ArrayRef,
    /// Non-nullable integer offsets into `elements`; one more entry than there are lists.
    offsets: ArrayRef,
    /// Which lists are null.
    validity: Validity,
    /// Statistics set exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}

/// Encoding marker type for [`ListArray`] (encoding id `vortex.list`).
#[derive(Clone, Debug)]
pub struct ListEncoding;
116
117pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
118
119impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
120
121impl ListArray {
122    pub fn new(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
123        Self::try_new(elements, offsets, validity).vortex_expect("ListArray new")
124    }
125
126    pub fn try_new(
127        elements: ArrayRef,
128        offsets: ArrayRef,
129        validity: Validity,
130    ) -> VortexResult<Self> {
131        let nullability = validity.nullability();
132
133        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
134            vortex_bail!(
135                "Expected offsets to be an non-nullable integer type, got {:?}",
136                offsets.dtype()
137            );
138        }
139
140        if offsets.is_empty() {
141            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
142        }
143
144        Self::validate(&elements, &offsets, &validity)?;
145
146        Ok(Self {
147            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
148            elements,
149            offsets,
150            validity,
151            stats_set: Default::default(),
152        })
153    }
154
155    /// Returns the offset at the given index from the list array.
156    ///
157    /// Panics if the index is out of bounds.
158    pub fn offset_at(&self, index: usize) -> usize {
159        assert!(
160            index <= self.len(),
161            "Index {index} out of bounds 0..={}",
162            self.len()
163        );
164
165        self.offsets()
166            .as_opt::<PrimitiveVTable>()
167            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
168            .unwrap_or_else(|| {
169                self.offsets()
170                    .scalar_at(index)
171                    .as_primitive()
172                    .as_::<usize>()
173                    .vortex_expect("index must fit in usize")
174            })
175    }
176
177    /// Returns the elements at the given index from the list array.
178    pub fn elements_at(&self, index: usize) -> ArrayRef {
179        let start = self.offset_at(index);
180        let end = self.offset_at(index + 1);
181        self.elements().slice(start..end)
182    }
183
184    /// Returns elements of the list array referenced by the offsets array
185    pub fn sliced_elements(&self) -> ArrayRef {
186        let start = self.offset_at(0);
187        let end = self.offset_at(self.len());
188        self.elements().slice(start..end)
189    }
190
191    /// Returns the offsets array.
192    pub fn offsets(&self) -> &ArrayRef {
193        &self.offsets
194    }
195
196    /// Returns the elements array.
197    pub fn elements(&self) -> &ArrayRef {
198        &self.elements
199    }
200
201    /// Create a copy of this array by adjusting offsets to start at 0 and removing elements not referenced by the offsets
202    pub fn reset_offsets(&self) -> VortexResult<Self> {
203        let elements = self.sliced_elements();
204        let offsets = self.offsets();
205        let first_offset = offsets.scalar_at(0);
206        let adjusted_offsets = sub_scalar(offsets, first_offset)?;
207
208        Self::try_new(elements, adjusted_offsets, self.validity.clone())
209    }
210
211    /// A list is valid if the:
212    /// - offsets start at a value in elements
213    /// - offsets are sorted
214    /// - the final offset points to an element in the elements list, pointing to zero
215    ///   if elements are empty.
216    /// - final_offset >= start_offset
217    /// - The size of the validity is the size-1 of the offset array
218    fn validate(
219        elements: &dyn Array,
220        offsets: &dyn Array,
221        validity: &Validity,
222    ) -> VortexResult<()> {
223        // Offsets must be of integer type, and cannot go lower than 0.
224        vortex_ensure!(
225            offsets.dtype().is_int() && !offsets.dtype().is_nullable(),
226            "offsets have invalid type {}",
227            offsets.dtype()
228        );
229
230        // We can safely unwrap the DType as primitive now
231        let offsets_ptype = offsets.dtype().as_ptype();
232
233        // Offsets must be sorted (but not strictly sorted, zero-length lists are allowed)
234        if let Some(is_sorted) = offsets.statistics().compute_is_sorted() {
235            vortex_ensure!(is_sorted, "offsets must be sorted");
236        } else {
237            vortex_bail!("offsets must report is_sorted statistic");
238        }
239
240        // Validate that offsets min is non-negative, and max does not exceed the length of
241        // the elements array.
242        if let Some(min_max) = min_max(offsets)? {
243            match_each_integer_ptype!(offsets_ptype, |P| {
244                let max_offset = P::try_from(offsets.scalar_at(offsets.len() - 1))
245                    .vortex_expect("Offsets type must fit offsets values");
246
247                #[allow(clippy::absurd_extreme_comparisons, unused_comparisons)]
248                {
249                    if let Some(min) = min_max.min.as_primitive().as_::<P>() {
250                        vortex_ensure!(
251                            min >= 0 && min <= max_offset,
252                            "offsets minimum {min} outside valid range [0, {max_offset}]"
253                        );
254                    }
255
256                    if let Some(max) = min_max.max.as_primitive().as_::<P>() {
257                        vortex_ensure!(
258                            max >= 0 && max <= max_offset,
259                            "offsets maximum {max} outside valid range [0, {max_offset}]"
260                        )
261                    }
262                }
263
264                vortex_ensure!(
265                    max_offset
266                        <= P::try_from(elements.len())
267                            .vortex_expect("Offsets type must be able to fit elements length"),
268                    "Max offset {max_offset} is beyond the length of the elements array {}",
269                    elements.len()
270                );
271            })
272        } else {
273            // TODO(aduffy): fallback to slower validation pathway?
274            vortex_bail!(
275                "offsets array with encoding {} must support min_max compute function",
276                offsets.encoding_id()
277            );
278        };
279
280        // If a validity array is present, it must be the same length as the ListArray
281        if let Some(validity_len) = validity.maybe_len() {
282            vortex_ensure!(
283                validity_len == offsets.len() - 1,
284                "validity with size {validity_len} does not match array size {}",
285                offsets.len() - 1
286            );
287        }
288
289        Ok(())
290    }
291}
292
293impl ArrayVTable<ListVTable> for ListVTable {
294    fn len(array: &ListArray) -> usize {
295        array.offsets.len().saturating_sub(1)
296    }
297
298    fn dtype(array: &ListArray) -> &DType {
299        &array.dtype
300    }
301
302    fn stats(array: &ListArray) -> StatsSetRef<'_> {
303        array.stats_set.to_ref(array.as_ref())
304    }
305}
306
307impl OperationsVTable<ListVTable> for ListVTable {
308    fn slice(array: &ListArray, range: Range<usize>) -> ArrayRef {
309        ListArray::new(
310            array.elements().clone(),
311            array.offsets().slice(range.start..range.end + 1),
312            array.validity().slice(range),
313        )
314        .into_array()
315    }
316
317    fn scalar_at(array: &ListArray, index: usize) -> Scalar {
318        let elem = array.elements_at(index);
319        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).collect();
320
321        Scalar::list(
322            Arc::new(elem.dtype().clone()),
323            scalars,
324            array.dtype().nullability(),
325        )
326    }
327}
328
329impl CanonicalVTable<ListVTable> for ListVTable {
330    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
331        Ok(Canonical::List(array.clone()))
332    }
333}
334
335impl ValidityHelper for ListArray {
336    fn validity(&self) -> &Validity {
337        &self.validity
338    }
339}
340
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of iterators.
    ///
    /// `dtype` is the element dtype and `O` the offset type. Every element is first
    /// converted into a [`Scalar`] and each list is appended individually, so this is slow
    /// and meant only as a testing convenience.
    ///
    /// # Errors
    ///
    /// Returns an error if appending any list to the builder fails.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let mut lists = iter.into_iter();
        let (lower_bound, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            Arc::clone(&dtype),
            vortex_dtype::Nullability::NonNullable,
            lower_bound,
        );

        lists.try_for_each(|list| {
            let children = list.into_iter().map(Into::<Scalar>::into).collect_vec();
            let scalar = Scalar::list(Arc::clone(&dtype), children, dtype.nullability());
            builder.append_value(scalar.as_list())
        })?;

        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional iterators.
    ///
    /// `None` entries become null lists; `Some` entries are converted element by element
    /// into [`Scalar`]s, exactly as in [`ListArray::from_iter_slow`].
    ///
    /// # Errors
    ///
    /// Returns an error if appending any list to the builder fails.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let mut lists = iter.into_iter();
        let (lower_bound, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            Arc::clone(&dtype),
            vortex_dtype::Nullability::Nullable,
            lower_bound,
        );

        lists.try_for_each(|maybe_list| match maybe_list {
            None => {
                builder.append_null();
                Ok(())
            }
            Some(list) => {
                let children = list.into_iter().map(Into::<Scalar>::into).collect_vec();
                let scalar = Scalar::list(Arc::clone(&dtype), children, dtype.nullability());
                builder.append_value(scalar.as_list())
            }
        })?;

        Ok(builder.finish())
    }
}
402
// Unit tests for this module live in a separate `tests` module file and are compiled only
// for test builds.
#[cfg(test)]
mod tests;