vortex_array/arrays/list/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod compute;
5mod serde;
6
7use std::ops::Range;
8use std::sync::Arc;
9
10#[cfg(feature = "test-harness")]
11use itertools::Itertools;
12use num_traits::{AsPrimitive, PrimInt};
13use vortex_dtype::{DType, NativePType, match_each_integer_ptype, match_each_native_ptype};
14use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure};
15use vortex_scalar::Scalar;
16
17use crate::arrays::PrimitiveVTable;
18#[cfg(feature = "test-harness")]
19use crate::builders::{ArrayBuilder, ListBuilder};
20use crate::compute::{min_max, sub_scalar};
21use crate::stats::{ArrayStats, StatsSetRef};
22use crate::validity::Validity;
23use crate::vtable::{
24    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
25    ValidityVTableFromValidityHelper,
26};
27use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
28
29vtable!(List);
30
31impl VTable for ListVTable {
32    type Array = ListArray;
33    type Encoding = ListEncoding;
34
35    type ArrayVTable = Self;
36    type CanonicalVTable = Self;
37    type OperationsVTable = Self;
38    type ValidityVTable = ValidityVTableFromValidityHelper;
39    type VisitorVTable = Self;
40    type ComputeVTable = NotSupported;
41    type EncodeVTable = NotSupported;
42    type PipelineVTable = NotSupported;
43    type SerdeVTable = Self;
44
45    fn id(_encoding: &Self::Encoding) -> EncodingId {
46        EncodingId::new_ref("vortex.list")
47    }
48
49    fn encoding(_array: &Self::Array) -> EncodingRef {
50        EncodingRef::new_ref(ListEncoding.as_ref())
51    }
52}
53
54/// A list array that stores variable-length lists of elements, similar to `Vec<Vec<T>>`.
55///
56/// This mirrors the Apache Arrow List array encoding and provides efficient storage
57/// for nested data where each row contains a list of elements of the same type.
58///
59/// ## Data Layout
60///
61/// The list array uses an offset-based encoding:
62/// - **Elements array**: A flat array containing all list elements concatenated together
63/// - **Offsets array**: Integer array where `offsets[i]` is an (inclusive) start index into
64///   the **elements** and `offsets[i+1]` is the (exclusive) stop index for the `i`th list.
65/// - **Validity**: Optional mask indicating which lists are null
66///
67/// This allows for excellent cascading compression of the elements and offsets, as similar values
68/// are clustered together and the offsets have a predictable pattern and small deltas between
69/// consecutive elements.
70///
71/// ## Offset Semantics
72///
73/// - Offsets must be non-nullable integers (i32, i64, etc.)
74/// - Offsets array has length `n+1` where `n` is the number of lists
75/// - List `i` contains elements from `elements[offsets[i]..offsets[i+1]]`
76/// - Offsets must be monotonically increasing
77///
78/// # Examples
79///
80/// ```
81/// use vortex_array::arrays::{ListArray, PrimitiveArray};
82/// use vortex_array::validity::Validity;
83/// use vortex_array::IntoArray;
84/// use std::sync::Arc;
85///
86/// // Create a list array representing [[1, 2], [3, 4, 5], []]
87/// let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
88/// let offsets = PrimitiveArray::from_iter([0u32, 2, 5, 5]); // 3 lists
89///
90/// let list_array = ListArray::try_new(
91///     elements.into_array(),
92///     offsets.into_array(),
93///     Validity::NonNullable,
94/// ).unwrap();
95///
96/// assert_eq!(list_array.len(), 3);
97///
98/// // Access individual lists
99/// let first_list = list_array.elements_at(0);
100/// assert_eq!(first_list.len(), 2); // [1, 2]
101///
102/// let third_list = list_array.elements_at(2);
103/// assert!(third_list.is_empty()); // []
104/// ```
105#[derive(Clone, Debug)]
106pub struct ListArray {
107    dtype: DType,
108    elements: ArrayRef,
109    offsets: ArrayRef,
110    validity: Validity,
111    stats_set: ArrayStats,
112}
113
/// Stateless marker type identifying the list encoding (`"vortex.list"`).
#[derive(Debug, Clone)]
pub struct ListEncoding;
116
117pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
118
119impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
120
121impl ListArray {
122    pub fn new(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
123        Self::try_new(elements, offsets, validity).vortex_expect("ListArray new")
124    }
125
126    pub fn try_new(
127        elements: ArrayRef,
128        offsets: ArrayRef,
129        validity: Validity,
130    ) -> VortexResult<Self> {
131        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
132            vortex_bail!(
133                "Expected offsets to be an non-nullable integer type, got {:?}",
134                offsets.dtype()
135            );
136        }
137
138        if offsets.is_empty() {
139            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
140        }
141
142        Self::validate(&elements, &offsets, &validity)?;
143
144        Ok(unsafe { Self::new_unchecked(elements, offsets, validity) })
145    }
146
147    /// # Safety
148    /// This method is safe too call if the arguments passed wouldn't cause an error when calling [`Self::try_new`]
149    pub unsafe fn new_unchecked(elements: ArrayRef, offsets: ArrayRef, validity: Validity) -> Self {
150        Self {
151            dtype: DType::List(Arc::new(elements.dtype().clone()), validity.nullability()),
152            elements,
153            offsets,
154            validity,
155            stats_set: Default::default(),
156        }
157    }
158
159    /// Returns the offset at the given index from the list array.
160    ///
161    /// Panics if the index is out of bounds.
162    pub fn offset_at(&self, index: usize) -> usize {
163        assert!(
164            index <= self.len(),
165            "Index {index} out of bounds 0..={}",
166            self.len()
167        );
168
169        self.offsets()
170            .as_opt::<PrimitiveVTable>()
171            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
172            .unwrap_or_else(|| {
173                self.offsets()
174                    .scalar_at(index)
175                    .as_primitive()
176                    .as_::<usize>()
177                    .vortex_expect("index must fit in usize")
178            })
179    }
180
181    /// Returns the elements at the given index from the list array.
182    pub fn elements_at(&self, index: usize) -> ArrayRef {
183        let start = self.offset_at(index);
184        let end = self.offset_at(index + 1);
185        self.elements().slice(start..end)
186    }
187
188    /// Returns elements of the list array referenced by the offsets array
189    pub fn sliced_elements(&self) -> ArrayRef {
190        let start = self.offset_at(0);
191        let end = self.offset_at(self.len());
192        self.elements().slice(start..end)
193    }
194
195    /// Returns the offsets array.
196    pub fn offsets(&self) -> &ArrayRef {
197        &self.offsets
198    }
199
200    /// Returns the elements array.
201    pub fn elements(&self) -> &ArrayRef {
202        &self.elements
203    }
204
205    /// Create a copy of this array by adjusting offsets to start at 0 and removing elements not referenced by the offsets
206    pub fn reset_offsets(&self) -> VortexResult<Self> {
207        let elements = self.sliced_elements();
208        let offsets = self.offsets();
209        let first_offset = offsets.scalar_at(0);
210        let adjusted_offsets = sub_scalar(offsets, first_offset)?;
211
212        Self::try_new(elements, adjusted_offsets, self.validity.clone())
213    }
214
215    /// A list is valid if the:
216    /// - offsets start at a value in elements
217    /// - offsets are sorted
218    /// - the final offset points to an element in the elements list, pointing to zero
219    ///   if elements are empty.
220    /// - final_offset >= start_offset
221    /// - The size of the validity is the size-1 of the offset array
222    fn validate(
223        elements: &dyn Array,
224        offsets: &dyn Array,
225        validity: &Validity,
226    ) -> VortexResult<()> {
227        // Offsets must be of integer type, and cannot go lower than 0.
228        vortex_ensure!(
229            offsets.dtype().is_int() && !offsets.dtype().is_nullable(),
230            "offsets have invalid type {}",
231            offsets.dtype()
232        );
233
234        // We can safely unwrap the DType as primitive now
235        let offsets_ptype = offsets.dtype().as_ptype();
236
237        // Offsets must be sorted (but not strictly sorted, zero-length lists are allowed)
238        if let Some(is_sorted) = offsets.statistics().compute_is_sorted() {
239            vortex_ensure!(is_sorted, "offsets must be sorted");
240        } else {
241            vortex_bail!("offsets must report is_sorted statistic");
242        }
243
244        // Validate that offsets min is non-negative, and max does not exceed the length of
245        // the elements array.
246        if let Some(min_max) = min_max(offsets)? {
247            match_each_integer_ptype!(offsets_ptype, |P| {
248                let max_offset = P::try_from(offsets.scalar_at(offsets.len() - 1))
249                    .vortex_expect("Offsets type must fit offsets values");
250
251                #[allow(clippy::absurd_extreme_comparisons, unused_comparisons)]
252                {
253                    if let Some(min) = min_max.min.as_primitive().as_::<P>() {
254                        vortex_ensure!(
255                            min >= 0 && min <= max_offset,
256                            "offsets minimum {min} outside valid range [0, {max_offset}]"
257                        );
258                    }
259
260                    if let Some(max) = min_max.max.as_primitive().as_::<P>() {
261                        vortex_ensure!(
262                            max >= 0 && max <= max_offset,
263                            "offsets maximum {max} outside valid range [0, {max_offset}]"
264                        )
265                    }
266                }
267
268                vortex_ensure!(
269                    max_offset
270                        <= P::try_from(elements.len())
271                            .vortex_expect("Offsets type must be able to fit elements length"),
272                    "Max offset {max_offset} is beyond the length of the elements array {}",
273                    elements.len()
274                );
275            })
276        } else {
277            // TODO(aduffy): fallback to slower validation pathway?
278            vortex_bail!(
279                "offsets array with encoding {} must support min_max compute function",
280                offsets.encoding_id()
281            );
282        };
283
284        // If a validity array is present, it must be the same length as the ListArray
285        if let Some(validity_len) = validity.maybe_len() {
286            vortex_ensure!(
287                validity_len == offsets.len() - 1,
288                "validity with size {validity_len} does not match array size {}",
289                offsets.len() - 1
290            );
291        }
292
293        Ok(())
294    }
295}
296
297impl ArrayVTable<ListVTable> for ListVTable {
298    fn len(array: &ListArray) -> usize {
299        array.offsets.len().saturating_sub(1)
300    }
301
302    fn dtype(array: &ListArray) -> &DType {
303        &array.dtype
304    }
305
306    fn stats(array: &ListArray) -> StatsSetRef<'_> {
307        array.stats_set.to_ref(array.as_ref())
308    }
309}
310
311impl OperationsVTable<ListVTable> for ListVTable {
312    fn slice(array: &ListArray, range: Range<usize>) -> ArrayRef {
313        ListArray::new(
314            array.elements().clone(),
315            array.offsets().slice(range.start..range.end + 1),
316            array.validity().slice(range),
317        )
318        .into_array()
319    }
320
321    fn scalar_at(array: &ListArray, index: usize) -> Scalar {
322        let elem = array.elements_at(index);
323        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).collect();
324
325        Scalar::list(
326            Arc::new(elem.dtype().clone()),
327            scalars,
328            array.dtype().nullability(),
329        )
330    }
331}
332
333impl CanonicalVTable<ListVTable> for ListVTable {
334    fn canonicalize(array: &ListArray) -> Canonical {
335        Canonical::List(array.clone())
336    }
337}
338
339impl ValidityHelper for ListArray {
340    fn validity(&self) -> &Validity {
341        &self.validity
342    }
343}
344
345#[cfg(feature = "test-harness")]
346impl ListArray {
347    /// This is a convenience method to create a list array from an iterator of iterators.
348    /// This method is slow however since each element is first converted to a scalar and then
349    /// appended to the array.
350    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
351        iter: I,
352        dtype: Arc<DType>,
353    ) -> VortexResult<ArrayRef>
354    where
355        I::Item: IntoIterator,
356        <I::Item as IntoIterator>::Item: Into<Scalar>,
357    {
358        let iter = iter.into_iter();
359        let mut builder = ListBuilder::<O>::with_capacity(
360            dtype.clone(),
361            vortex_dtype::Nullability::NonNullable,
362            iter.size_hint().0,
363        );
364
365        for v in iter {
366            let elem = Scalar::list(
367                dtype.clone(),
368                v.into_iter().map(|x| x.into()).collect_vec(),
369                dtype.nullability(),
370            );
371            builder.append_value(elem.as_list())?
372        }
373        Ok(builder.finish())
374    }
375
376    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
377        iter: I,
378        dtype: Arc<DType>,
379    ) -> VortexResult<ArrayRef>
380    where
381        T: IntoIterator,
382        T::Item: Into<Scalar>,
383    {
384        let iter = iter.into_iter();
385        let mut builder = ListBuilder::<O>::with_capacity(
386            dtype.clone(),
387            vortex_dtype::Nullability::Nullable,
388            iter.size_hint().0,
389        );
390
391        for v in iter {
392            if let Some(v) = v {
393                let elem = Scalar::list(
394                    dtype.clone(),
395                    v.into_iter().map(|x| x.into()).collect_vec(),
396                    dtype.nullability(),
397                );
398                builder.append_value(elem.as_list())?
399            } else {
400                builder.append_null()
401            }
402        }
403        Ok(builder.finish())
404    }
405}
406
407#[cfg(test)]
408mod tests;