vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9use vortex_dtype::{DType, NativePType, match_each_native_ptype};
10use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
11use vortex_scalar::Scalar;
12
13use crate::arrays::PrimitiveVTable;
14#[cfg(feature = "test-harness")]
15use crate::builders::{ArrayBuilder, ListBuilder};
16use crate::stats::{ArrayStats, StatsSetRef};
17use crate::validity::Validity;
18use crate::vtable::{
19    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
20    ValidityVTableFromValidityHelper,
21};
22use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
23
24vtable!(List);
25
26impl VTable for ListVTable {
27    type Array = ListArray;
28    type Encoding = ListEncoding;
29
30    type ArrayVTable = Self;
31    type CanonicalVTable = Self;
32    type OperationsVTable = Self;
33    type ValidityVTable = ValidityVTableFromValidityHelper;
34    type VisitorVTable = Self;
35    type ComputeVTable = NotSupported;
36    type EncodeVTable = NotSupported;
37    type SerdeVTable = Self;
38
39    fn id(_encoding: &Self::Encoding) -> EncodingId {
40        EncodingId::new_ref("vortex.list")
41    }
42
43    fn encoding(_array: &Self::Array) -> EncodingRef {
44        EncodingRef::new_ref(ListEncoding.as_ref())
45    }
46}
47
/// An array of variable-length lists.
///
/// All list values are stored back to back in a single flat `elements` array. An `offsets`
/// array marks where each list starts and ends.
#[derive(Clone, Debug)]
pub struct ListArray {
    /// `DType::List` over the element dtype, with the nullability taken from `validity`.
    dtype: DType,
    /// Flattened values of every list, stored contiguously.
    elements: ArrayRef,
    /// Non-nullable integer offsets into `elements`.
    /// List `i` spans `offsets[i]..offsets[i + 1]`, so there is one more offset than lists.
    offsets: ArrayRef,
    /// Per-list validity (one entry per list, not per element).
    validity: Validity,
    /// Statistics storage for this array; starts as `Default::default()` in `try_new`.
    stats_set: ArrayStats,
}
56
/// Marker type for the list encoding, identified as `vortex.list` by `ListVTable::id`.
#[derive(Clone, Debug)]
pub struct ListEncoding;
59
60pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
61
62impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
63
// A `ListArray` is valid if:
// - the first offset is a position within `elements`,
// - the offsets are sorted in non-decreasing order,
// - the final offset is an in-bounds end position for `elements` (at most
//   `elements.len()`, and zero if `elements` is empty),
// - the final offset is >= the first offset, and
// - the validity has one entry per list, i.e. its length is the offsets length minus one.
71
72impl ListArray {
73    pub fn try_new(
74        elements: ArrayRef,
75        offsets: ArrayRef,
76        validity: Validity,
77    ) -> VortexResult<Self> {
78        let nullability = validity.nullability();
79
80        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
81            vortex_bail!(
82                "Expected offsets to be an non-nullable integer type, got {:?}",
83                offsets.dtype()
84            );
85        }
86
87        if offsets.is_empty() {
88            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
89        }
90
91        Ok(Self {
92            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
93            elements,
94            offsets,
95            validity,
96            stats_set: Default::default(),
97        })
98    }
99
100    // TODO: merge logic with varbin
101    // TODO(ngates): should return a result if it requires canonicalizing offsets
102    pub fn offset_at(&self, index: usize) -> usize {
103        self.offsets()
104            .as_opt::<PrimitiveVTable>()
105            .map(|p| {
106                match_each_native_ptype!(p.ptype(), |$P| {
107                    p.as_slice::<$P>()[index].as_()
108                })
109            })
110            .unwrap_or_else(|| {
111                self.offsets()
112                    .scalar_at(index)
113                    .unwrap_or_else(|err| {
114                        vortex_panic!(err, "Failed to get offset at index: {}", index)
115                    })
116                    .as_ref()
117                    .try_into()
118                    .vortex_expect("Failed to convert offset to usize")
119            })
120    }
121
122    // TODO: fetches the elements at index
123    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
124        let start = self.offset_at(index);
125        let end = self.offset_at(index + 1);
126        self.elements().slice(start, end)
127    }
128
129    // TODO: fetches the offsets of the array ignoring validity
130    pub fn offsets(&self) -> &ArrayRef {
131        &self.offsets
132    }
133
134    // TODO: fetches the elements of the array ignoring validity
135    pub fn elements(&self) -> &ArrayRef {
136        &self.elements
137    }
138}
139
140impl ArrayVTable<ListVTable> for ListVTable {
141    fn len(array: &ListArray) -> usize {
142        array.offsets.len().saturating_sub(1)
143    }
144
145    fn dtype(array: &ListArray) -> &DType {
146        &array.dtype
147    }
148
149    fn stats(array: &ListArray) -> StatsSetRef<'_> {
150        array.stats_set.to_ref(array.as_ref())
151    }
152}
153
154impl OperationsVTable<ListVTable> for ListVTable {
155    fn slice(array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
156        Ok(ListArray::try_new(
157            array.elements().clone(),
158            array.offsets().slice(start, stop + 1)?,
159            array.validity().slice(start, stop)?,
160        )?
161        .into_array())
162    }
163
164    fn scalar_at(array: &ListArray, index: usize) -> VortexResult<Scalar> {
165        let elem = array.elements_at(index)?;
166        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).try_collect()?;
167
168        Ok(Scalar::list(
169            Arc::new(elem.dtype().clone()),
170            scalars,
171            array.dtype().nullability(),
172        ))
173    }
174}
175
176impl CanonicalVTable<ListVTable> for ListVTable {
177    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
178        Ok(Canonical::List(array.clone()))
179    }
180}
181
182impl ValidityHelper for ListArray {
183    fn validity(&self) -> &Validity {
184        &self.validity
185    }
186}
187
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of iterators.
    ///
    /// This is a test convenience and is slow. Every element is converted to a [`Scalar`],
    /// wrapped in a list scalar, and appended through a [`ListBuilder`].
    ///
    /// `dtype` is the element type; `O` selects the offset integer type.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        let capacity = lists.size_hint().0;
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            capacity,
        );

        for list in lists {
            let values: Vec<Scalar> = list.into_iter().map(Into::into).collect();
            let scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
            builder.append_value(scalar.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional iterators.
    ///
    /// `None` becomes a null list. Each `Some` list is converted element by element, as in
    /// [`Self::from_iter_slow`]. `dtype` is the element type; `O` selects the offset integer
    /// type.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        let capacity = lists.size_hint().0;
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            capacity,
        );

        for maybe_list in lists {
            match maybe_list {
                None => builder.append_null(),
                Some(list) => {
                    let values: Vec<Scalar> = list.into_iter().map(Into::into).collect();
                    let scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
                    builder.append_value(scalar.as_list())?;
                }
            }
        }

        Ok(builder.finish())
    }
}
249
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::from_iter([0]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(0, list.len());
    }

    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(3, list.len());
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 2.into()],
                Nullability::Nullable
            ),
            list.scalar_at(0).unwrap()
        );
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![3.into(), 4.into()],
                Nullability::Nullable
            ),
            list.scalar_at(1).unwrap()
        );
        assert_eq!(
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable),
            list.scalar_at(2).unwrap()
        );
    }

    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]);
        let offsets = PrimitiveArray::from_iter([0, 2, 3]);
        let validity = Validity::NonNullable;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        assert_eq!(
            list.scalar_at(0).unwrap(),
            list_from_iter.scalar_at(0).unwrap()
        );
        assert_eq!(
            list.scalar_at(1).unwrap(),
            list_from_iter.scalar_at(1).unwrap()
        );
    }

    #[test]
    fn test_simple_list_filter() {
        let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity)
            .unwrap()
            .into_array();

        let filtered = filter(
            &list,
            &Mask::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .unwrap();

        // Only lists 1 and 2 are kept, in their original order.
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.scalar_at(0).unwrap(), list.scalar_at(1).unwrap());
        assert_eq!(filtered.scalar_at(1).unwrap(), list.scalar_at(2).unwrap());
    }
}
348        assert!(filtered.is_ok())
349    }
350}