vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6use itertools::Itertools;
7use num_traits::{AsPrimitive, PrimInt};
8use vortex_dtype::{DType, NativePType, match_each_native_ptype};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
10use vortex_scalar::Scalar;
11
12use crate::arrays::PrimitiveVTable;
13#[cfg(feature = "test-harness")]
14use crate::builders::{ArrayBuilder, ListBuilder};
15use crate::stats::{ArrayStats, StatsSetRef};
16use crate::validity::Validity;
17use crate::vtable::{
18    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
19    ValidityVTableFromValidityHelper,
20};
21use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
22
23vtable!(List);
24
25impl VTable for ListVTable {
26    type Array = ListArray;
27    type Encoding = ListEncoding;
28
29    type ArrayVTable = Self;
30    type CanonicalVTable = Self;
31    type OperationsVTable = Self;
32    type ValidityVTable = ValidityVTableFromValidityHelper;
33    type VisitorVTable = Self;
34    type ComputeVTable = NotSupported;
35    type EncodeVTable = NotSupported;
36    type SerdeVTable = Self;
37
38    fn id(_encoding: &Self::Encoding) -> EncodingId {
39        EncodingId::new_ref("vortex.list")
40    }
41
42    fn encoding(_array: &Self::Array) -> EncodingRef {
43        EncodingRef::new_ref(ListEncoding.as_ref())
44    }
45}
46
47#[derive(Clone, Debug)]
48pub struct ListArray {
49    dtype: DType,
50    elements: ArrayRef,
51    offsets: ArrayRef,
52    validity: Validity,
53    stats_set: ArrayStats,
54}
55
/// Encoding marker for [`ListArray`], identified by the id `vortex.list`.
#[derive(Debug, Clone)]
pub struct ListEncoding;
58
59pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
60
61impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
62
// A list array is valid if:
// - the first offset is a position within `elements` (at most `elements.len()`);
// - the offsets are sorted in non-decreasing order;
// - the final offset is at most `elements.len()`, and therefore zero when `elements` is empty;
// - the final offset is greater than or equal to the first offset;
// - the validity has one entry per list, i.e. one fewer than the number of offsets.
70
71impl ListArray {
72    pub fn try_new(
73        elements: ArrayRef,
74        offsets: ArrayRef,
75        validity: Validity,
76    ) -> VortexResult<Self> {
77        let nullability = validity.nullability();
78
79        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
80            vortex_bail!(
81                "Expected offsets to be an non-nullable integer type, got {:?}",
82                offsets.dtype()
83            );
84        }
85
86        if offsets.is_empty() {
87            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
88        }
89
90        Ok(Self {
91            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
92            elements,
93            offsets,
94            validity,
95            stats_set: Default::default(),
96        })
97    }
98
99    // TODO: merge logic with varbin
100    // TODO(ngates): should return a result if it requires canonicalizing offsets
101    pub fn offset_at(&self, index: usize) -> usize {
102        self.offsets()
103            .as_opt::<PrimitiveVTable>()
104            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
105            .unwrap_or_else(|| {
106                self.offsets()
107                    .scalar_at(index)
108                    .unwrap_or_else(|err| {
109                        vortex_panic!(err, "Failed to get offset at index: {}", index)
110                    })
111                    .as_ref()
112                    .try_into()
113                    .vortex_expect("Failed to convert offset to usize")
114            })
115    }
116
117    // TODO: fetches the elements at index
118    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
119        let start = self.offset_at(index);
120        let end = self.offset_at(index + 1);
121        self.elements().slice(start, end)
122    }
123
124    // TODO: fetches the offsets of the array ignoring validity
125    pub fn offsets(&self) -> &ArrayRef {
126        &self.offsets
127    }
128
129    // TODO: fetches the elements of the array ignoring validity
130    pub fn elements(&self) -> &ArrayRef {
131        &self.elements
132    }
133}
134
135impl ArrayVTable<ListVTable> for ListVTable {
136    fn len(array: &ListArray) -> usize {
137        array.offsets.len().saturating_sub(1)
138    }
139
140    fn dtype(array: &ListArray) -> &DType {
141        &array.dtype
142    }
143
144    fn stats(array: &ListArray) -> StatsSetRef<'_> {
145        array.stats_set.to_ref(array.as_ref())
146    }
147}
148
149impl OperationsVTable<ListVTable> for ListVTable {
150    fn slice(array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
151        Ok(ListArray::try_new(
152            array.elements().clone(),
153            array.offsets().slice(start, stop + 1)?,
154            array.validity().slice(start, stop)?,
155        )?
156        .into_array())
157    }
158
159    fn scalar_at(array: &ListArray, index: usize) -> VortexResult<Scalar> {
160        let elem = array.elements_at(index)?;
161        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).try_collect()?;
162
163        Ok(Scalar::list(
164            Arc::new(elem.dtype().clone()),
165            scalars,
166            array.dtype().nullability(),
167        ))
168    }
169}
170
171impl CanonicalVTable<ListVTable> for ListVTable {
172    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
173        Ok(Canonical::List(array.clone()))
174    }
175}
176
177impl ValidityHelper for ListArray {
178    fn validity(&self) -> &Validity {
179        &self.validity
180    }
181}
182
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of iterators.
    ///
    /// `dtype` is the element dtype of the lists and `O` is the integer type used for the
    /// offsets. This is slow because every element is first converted to a [`Scalar`] and then
    /// appended through a [`ListBuilder`].
    ///
    /// # Errors
    ///
    /// Returns an error if appending any list to the builder fails.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            iter.size_hint().0,
        );

        for v in iter {
            // The last argument is the nullability of the list scalar itself, not of its
            // elements. It must agree with the non-nullable builder above.
            let elem = Scalar::list(
                dtype.clone(),
                v.into_iter().map(|x| x.into()).collect_vec(),
                vortex_dtype::Nullability::NonNullable,
            );
            builder.append_value(elem.as_list())?;
        }
        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional iterators.
    ///
    /// `None` items become null lists, and `Some(values)` items become lists containing
    /// `values`. As with [`Self::from_iter_slow`], every element is routed through a [`Scalar`],
    /// so this is slow.
    ///
    /// # Errors
    ///
    /// Returns an error if appending any list to the builder fails.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            iter.size_hint().0,
        );

        for v in iter {
            match v {
                Some(values) => {
                    // List-level nullability matches the nullable builder above; the element
                    // nullability is carried separately by `dtype`.
                    let elem = Scalar::list(
                        dtype.clone(),
                        values.into_iter().map(|x| x.into()).collect_vec(),
                        vortex_dtype::Nullability::Nullable,
                    );
                    builder.append_value(elem.as_list())?;
                }
                None => builder.append_null(),
            }
        }
        Ok(builder.finish())
    }
}
244
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// A single `[0]` offset describes an array with no lists.
    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::from_iter([0]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(0, list.len());
    }

    /// Each list scalar contains exactly the elements between consecutive offsets.
    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 2.into()],
                Nullability::Nullable
            ),
            list.scalar_at(0).unwrap()
        );
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![3.into(), 4.into()],
                Nullability::Nullable
            ),
            list.scalar_at(1).unwrap()
        );
        assert_eq!(
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable),
            list.scalar_at(2).unwrap()
        );
    }

    /// `from_iter_slow` produces the same lists as building the array by hand.
    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]);
        let offsets = PrimitiveArray::from_iter([0, 2, 3]);
        let validity = Validity::NonNullable;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        assert_eq!(
            list.scalar_at(0).unwrap(),
            list_from_iter.scalar_at(0).unwrap()
        );
        assert_eq!(
            list.scalar_at(1).unwrap(),
            list_from_iter.scalar_at(1).unwrap()
        );
    }

    /// Filtering keeps exactly the selected lists, in order.
    #[test]
    fn test_simple_list_filter() {
        let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity)
            .unwrap()
            .into_array();

        let filtered = filter(
            &list,
            &Mask::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .unwrap();

        // The mask drops list 0 and keeps lists 1 and 2.
        assert_eq!(filtered.len(), 2);
        assert_eq!(filtered.scalar_at(0).unwrap(), list.scalar_at(1).unwrap());
        assert_eq!(filtered.scalar_at(1).unwrap(), list.scalar_at(2).unwrap());
    }
}