vortex_array/arrays/list/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod compute;
5mod serde;
6
7use std::sync::Arc;
8
9use itertools::Itertools;
10use num_traits::{AsPrimitive, PrimInt};
11use vortex_dtype::{DType, NativePType, match_each_native_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
13use vortex_scalar::Scalar;
14
15use crate::arrays::PrimitiveVTable;
16#[cfg(feature = "test-harness")]
17use crate::builders::{ArrayBuilder, ListBuilder};
18use crate::stats::{ArrayStats, StatsSetRef};
19use crate::validity::Validity;
20use crate::vtable::{
21    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityHelper,
22    ValidityVTableFromValidityHelper,
23};
24use crate::{Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, vtable};
25
26vtable!(List);
27
28impl VTable for ListVTable {
29    type Array = ListArray;
30    type Encoding = ListEncoding;
31
32    type ArrayVTable = Self;
33    type CanonicalVTable = Self;
34    type OperationsVTable = Self;
35    type ValidityVTable = ValidityVTableFromValidityHelper;
36    type VisitorVTable = Self;
37    type ComputeVTable = NotSupported;
38    type EncodeVTable = NotSupported;
39    type SerdeVTable = Self;
40
41    fn id(_encoding: &Self::Encoding) -> EncodingId {
42        EncodingId::new_ref("vortex.list")
43    }
44
45    fn encoding(_array: &Self::Array) -> EncodingRef {
46        EncodingRef::new_ref(ListEncoding.as_ref())
47    }
48}
49
50/// The canonical array for List type data. This is modeled similarly to a ListArray.
51#[derive(Clone, Debug)]
52pub struct ListArray {
53    dtype: DType,
54    elements: ArrayRef,
55    offsets: ArrayRef,
56    validity: Validity,
57    stats_set: ArrayStats,
58}
59
/// Unit type identifying the list encoding.
#[derive(Debug, Clone)]
pub struct ListEncoding;
62
63pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
64
65impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
66
// A list array is valid if:
// - the offsets start at a value in elements
// - the offsets are sorted
// - the final offset points to an element in the elements list, pointing to zero
//   if elements are empty
// - final_offset >= start_offset
// - the length of the validity is one less than the length of the offsets array
74
75impl ListArray {
76    pub fn try_new(
77        elements: ArrayRef,
78        offsets: ArrayRef,
79        validity: Validity,
80    ) -> VortexResult<Self> {
81        let nullability = validity.nullability();
82
83        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
84            vortex_bail!(
85                "Expected offsets to be an non-nullable integer type, got {:?}",
86                offsets.dtype()
87            );
88        }
89
90        if offsets.is_empty() {
91            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
92        }
93
94        Ok(Self {
95            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
96            elements,
97            offsets,
98            validity,
99            stats_set: Default::default(),
100        })
101    }
102
103    // TODO: merge logic with varbin
104    // TODO(ngates): should return a result if it requires canonicalizing offsets
105    pub fn offset_at(&self, index: usize) -> usize {
106        self.offsets()
107            .as_opt::<PrimitiveVTable>()
108            .map(|p| match_each_native_ptype!(p.ptype(), |P| { p.as_slice::<P>()[index].as_() }))
109            .unwrap_or_else(|| {
110                self.offsets()
111                    .scalar_at(index)
112                    .unwrap_or_else(|err| {
113                        vortex_panic!(err, "Failed to get offset at index: {}", index)
114                    })
115                    .as_ref()
116                    .try_into()
117                    .vortex_expect("Failed to convert offset to usize")
118            })
119    }
120
121    // TODO: fetches the elements at index
122    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
123        let start = self.offset_at(index);
124        let end = self.offset_at(index + 1);
125        self.elements().slice(start, end)
126    }
127
128    // TODO: fetches the offsets of the array ignoring validity
129    pub fn offsets(&self) -> &ArrayRef {
130        &self.offsets
131    }
132
133    // TODO: fetches the elements of the array ignoring validity
134    pub fn elements(&self) -> &ArrayRef {
135        &self.elements
136    }
137}
138
139impl ArrayVTable<ListVTable> for ListVTable {
140    fn len(array: &ListArray) -> usize {
141        array.offsets.len().saturating_sub(1)
142    }
143
144    fn dtype(array: &ListArray) -> &DType {
145        &array.dtype
146    }
147
148    fn stats(array: &ListArray) -> StatsSetRef<'_> {
149        array.stats_set.to_ref(array.as_ref())
150    }
151}
152
153impl OperationsVTable<ListVTable> for ListVTable {
154    fn slice(array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
155        Ok(ListArray::try_new(
156            array.elements().clone(),
157            array.offsets().slice(start, stop + 1)?,
158            array.validity().slice(start, stop)?,
159        )?
160        .into_array())
161    }
162
163    fn scalar_at(array: &ListArray, index: usize) -> VortexResult<Scalar> {
164        let elem = array.elements_at(index)?;
165        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| elem.scalar_at(i)).try_collect()?;
166
167        Ok(Scalar::list(
168            Arc::new(elem.dtype().clone()),
169            scalars,
170            array.dtype().nullability(),
171        ))
172    }
173}
174
175impl CanonicalVTable<ListVTable> for ListVTable {
176    fn canonicalize(array: &ListArray) -> VortexResult<Canonical> {
177        Ok(Canonical::List(array.clone()))
178    }
179}
180
181impl ValidityHelper for ListArray {
182    fn validity(&self) -> &Validity {
183        &self.validity
184    }
185}
186
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Convenience constructor building a non-nullable list array from an iterator of
    /// iterators, with `dtype` as the element dtype and `O` as the offset type.
    ///
    /// This is slow: every element is first converted to a scalar, and each list is then
    /// appended to a builder as a list scalar.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Pre-size the builder with the iterator's lower size bound.
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            capacity,
        );

        for values in lists {
            let children: Vec<Scalar> = values.into_iter().map(Into::into).collect();
            let list = Scalar::list(dtype.clone(), children, dtype.nullability());
            builder.append_value(list.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Like `from_iter_slow`, but builds a nullable list array: `None` items are appended
    /// as null lists.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        // Pre-size the builder with the iterator's lower size bound.
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            capacity,
        );

        for maybe_values in lists {
            match maybe_values {
                Some(values) => {
                    let children: Vec<Scalar> = values.into_iter().map(Into::into).collect();
                    let list = Scalar::list(dtype.clone(), children, dtype.nullability());
                    builder.append_value(list.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }
}
248
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::filter;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// A single `[0]` offset with no elements yields a list array of length zero.
    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable).into_array();
        let offsets = PrimitiveArray::from_iter([0]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid).unwrap();

        assert_eq!(0, list.len());
    }

    /// `scalar_at` returns the elements delimited by consecutive offsets.
    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid).unwrap();

        let expected_lists: [Vec<Scalar>; 3] = [
            vec![1i32.into(), 2i32.into()],
            vec![3i32.into(), 4i32.into()],
            vec![5i32.into()],
        ];
        for (idx, children) in expected_lists.into_iter().enumerate() {
            assert_eq!(
                Scalar::list(Arc::new(I32.into()), children, Nullability::Nullable),
                list.scalar_at(idx).unwrap()
            );
        }
    }

    /// A list built with `from_iter_slow` matches one assembled from explicit children.
    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 3]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::NonNullable).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        for idx in [0, 1] {
            assert_eq!(
                list.scalar_at(idx).unwrap(),
                list_from_iter.scalar_at(idx).unwrap()
            );
        }
    }

    /// Filtering a list array whose elements contain a null succeeds.
    #[test]
    fn test_simple_list_filter() {
        let elements =
            PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)])
                .into_array();
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]).into_array();

        let list = ListArray::try_new(elements, offsets, Validity::AllValid)
            .unwrap()
            .into_array();

        let mask = Mask::from(BooleanBuffer::from(vec![false, true, true]));

        assert!(filter(&list, &mask).is_ok())
    }
}