vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9use vortex_dtype::{DType, NativePType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
11use vortex_scalar::Scalar;
12
13use crate::arrays::PrimitiveVTable;
14#[cfg(feature = "test-harness")]
15use crate::builders::{ArrayBuilder, ListBuilder};
16use crate::stats::{ArrayStats, StatsSetRef};
17use crate::validity::Validity;
18use crate::vtable::{
19    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
20    ValidityVTableFromValidityHelper,
21};
22use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
23
24vtable!(List);
25
26impl VTable for ListVTable {
27    type Array = ListArray;
28    type Encoding = ListEncoding;
29
30    type ArrayVTable = Self;
31    type CanonicalVTable = Self;
32    type OperationsVTable = Self;
33    type ValidityVTable = ValidityVTableFromValidityHelper;
34    type VisitorVTable = Self;
35    type ComputeVTable = NotSupported;
36    type EncodeVTable = NotSupported;
37    type SerdeVTable = Self;
38
39    fn id(_encoding: &Self::Encoding) -> EncodingId {
40        EncodingId::new_ref("vortex.list")
41    }
42
43    fn encoding(_array: &Self::Array) -> EncodingRef {
44        EncodingRef::new_ref(ListEncoding.as_ref())
45    }
46}
47
48#[derive(Clone, Debug)]
49pub struct ListArray {
50    dtype: DType,
51    elements: ArrayRef,
52    offsets: ArrayRef,
53    validity: Validity,
54    stats_set: ArrayStats,
55}
56
/// Zero-sized marker type for the list encoding.
#[derive(Debug, Clone)]
pub struct ListEncoding;
59
60pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
61
62impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
63
// A list array is well-formed if:
// - the first offset points to a position within `elements`
// - the offsets are sorted
// - the final offset points to a position within `elements` (it is zero if `elements`
//   is empty)
// - final_offset >= start_offset
// - the validity has one entry fewer than the offsets array (`offsets.len() - 1`)
//
// NOTE: not all of these invariants are checked when a `ListArray` is constructed.
71
72impl ListArray {
73    pub fn try_new(
74        elements: ArrayRef,
75        offsets: ArrayRef,
76        validity: Validity,
77    ) -> VortexResult<Self> {
78        let nullability = validity.nullability();
79
80        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
81            vortex_bail!(
82                "Expected offsets to be an non-nullable integer type, got {:?}",
83                offsets.dtype()
84            );
85        }
86
87        if offsets.is_empty() {
88            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
89        }
90
91        Ok(Self {
92            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
93            elements,
94            offsets,
95            validity,
96            stats_set: Default::default(),
97        })
98    }
99
100    // TODO: merge logic with varbin
101    // TODO(ngates): should return a result if it requires canonicalizing offsets
102    pub fn offset_at(&self, index: usize) -> usize {
103        self.offsets()
104            .as_opt::<PrimitiveVTable>()
105            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
106            .unwrap_or_else(|| {
107                self.offsets()
108                    .scalar_at(index)
109                    .unwrap_or_else(|err| {
110                        vortex_panic!(err, "Failed to get offset at index: {}", index)
111                    })
112                    .as_ref()
113                    .try_into()
114                    .vortex_expect("Failed to convert offset to usize")
115            })
116    }
117
118    // TODO: fetches the elements at index
119    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
120        let start = self.offset_at(index);
121        let end = self.offset_at(index + 1);
122        self.elements().slice(start, end)
123    }
124
125    // TODO: fetches the offsets of the array ignoring validity
126    pub fn offsets(&self) -> &ArrayRef {
127        &self.offsets
128    }
129
130    // TODO: fetches the elements of the array ignoring validity
131    pub fn elements(&self) -> &ArrayRef {
132        &self.elements
133    }
134}
135
136impl ArrayVTable<ListVTable> for ListVTable {
137    fn len(array: &ListArray) -> usize {
138        array.offsets.len().saturating_sub(1)
139    }
140
141    fn dtype(array: &ListArray) -> &DType {
142        &array.dtype
143    }
144
145    fn stats(array: &ListArray) -> StatsSetRef<'_> {
146        array.stats_set.to_ref(array.as_ref())
147    }
148}
149
150impl OperationsVTable<ListVTable> for ListVTable {
151    fn slice(array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
152        Ok(ListArray::try_new(
153            array.elements().clone(),
154            array.offsets().slice(start, stop + 1)?,
155            array.validity().slice(start, stop)?,
156        )?
157        .into_array())
158    }
159
160    fn scalar_at(array: &ListArray, index: usize) -> VortexResult<Scalar> {
161        let elem = array.elements_at(index)?;
162        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).try_collect()?;
163
164        Ok(Scalar::list(
165            Arc::new(elem.dtype().clone()),
166            scalars,
167            array.dtype().nullability(),
168        ))
169    }
170}
171
172impl CanonicalVTable<ListVTable> for ListVTable {
173    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
174        Ok(Canonical::List(array.clone()))
175    }
176}
177
178impl ValidityHelper for ListArray {
179    fn validity(&self) -> &Validity {
180        &self.validity
181    }
182}
183
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of iterators, using offsets of
    /// type `O` and elements of the given `dtype`.
    ///
    /// This is slow: every element is converted to a scalar, every inner iterator is
    /// gathered into a list scalar, and that scalar is then appended to a builder.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Reserve room for the lower bound of the number of lists.
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            Arc::clone(&dtype),
            vortex_dtype::Nullability::NonNullable,
            capacity,
        );

        for list in lists {
            let values = list.into_iter().map(Into::<Scalar>::into).collect_vec();
            let scalar = Scalar::list(Arc::clone(&dtype), values, dtype.nullability());
            builder.append_value(scalar.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional iterators, using offsets
    /// of type `O` and elements of the given `dtype`. `None` items become null lists.
    ///
    /// Like [`ListArray::from_iter_slow`], this goes through scalars and is slow.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Reserve room for the lower bound of the number of lists.
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            Arc::clone(&dtype),
            vortex_dtype::Nullability::Nullable,
            capacity,
        );

        for maybe_list in lists {
            match maybe_list {
                Some(list) => {
                    let values = list.into_iter().map(Into::<Scalar>::into).collect_vec();
                    let scalar = Scalar::list(Arc::clone(&dtype), values, dtype.nullability());
                    builder.append_value(scalar.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }
}
245
/// Unit tests for constructing, indexing and filtering `ListArray`.
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// A single `[0]` offset with no elements is a valid list array of length zero.
    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::from_iter([0]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(0, list.len());
    }

    /// An offsets array with no entries at all is rejected by the constructor.
    #[test]
    fn test_empty_offsets_rejected() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::empty::<u32>(Nullability::NonNullable);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::AllValid,
        );

        assert!(result.is_err());
    }

    /// Each list is read back as a list scalar holding exactly its own elements.
    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        // Four offsets delimit three lists.
        assert_eq!(3, list.len());
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 2.into()],
                Nullability::Nullable
            ),
            list.scalar_at(0).unwrap()
        );
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![3.into(), 4.into()],
                Nullability::Nullable
            ),
            list.scalar_at(1).unwrap()
        );
        assert_eq!(
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable),
            list.scalar_at(2).unwrap()
        );
    }

    /// `from_iter_slow` produces the same lists as building the array by hand.
    // `from_iter_slow` only exists with the `test-harness` feature, so the test is gated on
    // it too; otherwise the test module fails to compile when the feature is disabled.
    #[cfg(feature = "test-harness")]
    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]);
        let offsets = PrimitiveArray::from_iter([0, 2, 3]);
        let validity = Validity::NonNullable;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        assert_eq!(
            list.scalar_at(0).unwrap(),
            list_from_iter.scalar_at(0).unwrap()
        );
        assert_eq!(
            list.scalar_at(1).unwrap(),
            list_from_iter.scalar_at(1).unwrap()
        );
    }

    /// Filtering keeps exactly the lists selected by the mask.
    #[test]
    fn test_simple_list_filter() {
        let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity)
            .unwrap()
            .into_array();

        let filtered = filter(
            &list,
            &Mask::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .unwrap();

        // Two of the three mask entries are set.
        assert_eq!(2, filtered.len());
    }
}
346}