vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9use serde::ListMetadata;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
12use vortex_mask::Mask;
13use vortex_scalar::Scalar;
14
15use crate::arrays::PrimitiveArray;
16#[cfg(feature = "test-harness")]
17use crate::builders::{ArrayBuilder, ListBuilder};
18use crate::compute::{scalar_at, slice};
19use crate::stats::{ArrayStats, StatsSetRef};
20use crate::validity::Validity;
21use crate::variants::{ListArrayTrait, PrimitiveArrayTrait};
22use crate::vtable::VTableRef;
23use crate::{
24    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
25    ArrayVariantsImpl, Canonical, Encoding, RkyvMetadata, TryFromArrayRef,
26};
27
/// An array of variable-length lists.
///
/// List `i` is the slice `elements[offsets[i]..offsets[i + 1]]`, so `offsets` holds one more
/// entry than there are lists. `validity` records which *lists* are null; `try_new` sets
/// `dtype` to `DType::List(elements.dtype(), validity.nullability())`.
#[derive(Clone, Debug)]
pub struct ListArray {
    /// The list dtype: element dtype plus the list-level nullability.
    dtype: DType,
    /// The values of all lists, addressed through `offsets`.
    elements: ArrayRef,
    /// Non-nullable integer offsets into `elements`; never empty (`[0]` for an empty array).
    offsets: ArrayRef,
    /// List-level validity (which lists are null), independent of element validity.
    validity: Validity,
    /// Statistics storage, exposed through `_stats_ref`.
    stats_set: ArrayStats,
}
36
/// The encoding marker for [`ListArray`].
///
/// Its array type is [`ListArray`] and its metadata type is `RkyvMetadata<ListMetadata>`, where
/// [`ListMetadata`] comes from the `serde` submodule.
pub struct ListEncoding;
impl Encoding for ListEncoding {
    type Array = ListArray;
    type Metadata = RkyvMetadata<ListMetadata>;
}
42
43pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
44
45impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
46
// A list array is valid if:
// - the first offset lies within `elements` (at most `elements.len()`);
// - the offsets are sorted in non-decreasing order;
// - the final offset is at most `elements.len()` (offsets are end-exclusive slice bounds,
//   so the final offset is zero when `elements` is empty);
// - final_offset >= start_offset;
// - the validity has exactly `offsets.len() - 1` entries, one per list.
54
55impl ListArray {
56    pub fn try_new(
57        elements: ArrayRef,
58        offsets: ArrayRef,
59        validity: Validity,
60    ) -> VortexResult<Self> {
61        let nullability = validity.nullability();
62
63        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
64            vortex_bail!(
65                "Expected offsets to be an non-nullable integer type, got {:?}",
66                offsets.dtype()
67            );
68        }
69
70        if offsets.is_empty() {
71            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
72        }
73
74        Ok(Self {
75            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
76            elements,
77            offsets,
78            validity,
79            stats_set: Default::default(),
80        })
81    }
82
83    pub fn validity(&self) -> &Validity {
84        &self.validity
85    }
86
87    // TODO: merge logic with varbin
88    // TODO(ngates): should return a result if it requires canonicalizing offsets
89    pub fn offset_at(&self, index: usize) -> usize {
90        PrimitiveArray::try_from_array(self.offsets().clone())
91            .ok()
92            .map(|p| {
93                match_each_native_ptype!(p.ptype(), |$P| {
94                    p.as_slice::<$P>()[index].as_()
95                })
96            })
97            .unwrap_or_else(|| {
98                scalar_at(self.offsets(), index)
99                    .unwrap_or_else(|err| {
100                        vortex_panic!(err, "Failed to get offset at index: {}", index)
101                    })
102                    .as_ref()
103                    .try_into()
104                    .vortex_expect("Failed to convert offset to usize")
105            })
106    }
107
108    // TODO: fetches the elements at index
109    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
110        let start = self.offset_at(index);
111        let end = self.offset_at(index + 1);
112        slice(self.elements(), start, end)
113    }
114
115    // TODO: fetches the offsets of the array ignoring validity
116    pub fn offsets(&self) -> &ArrayRef {
117        &self.offsets
118    }
119
120    // TODO: fetches the elements of the array ignoring validity
121    pub fn elements(&self) -> &ArrayRef {
122        &self.elements
123    }
124}
125
126impl ArrayImpl for ListArray {
127    type Encoding = ListEncoding;
128
129    fn _len(&self) -> usize {
130        self.offsets.len().saturating_sub(1)
131    }
132
133    fn _dtype(&self) -> &DType {
134        &self.dtype
135    }
136
137    fn _vtable(&self) -> VTableRef {
138        VTableRef::new_ref(&ListEncoding)
139    }
140
141    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
142        let elements = children[0].clone();
143        let offsets = children[1].clone();
144        let validity = if self.validity().is_array() {
145            Validity::Array(children[2].clone())
146        } else {
147            self.validity().clone()
148        };
149
150        Self::try_new(elements, offsets, validity)
151    }
152}
153
154impl ArrayStatisticsImpl for ListArray {
155    fn _stats_ref(&self) -> StatsSetRef<'_> {
156        self.stats_set.to_ref(self)
157    }
158}
159
160impl ArrayVariantsImpl for ListArray {
161    fn _as_list_typed(&self) -> Option<&dyn ListArrayTrait> {
162        Some(self)
163    }
164}
165
166impl ListArrayTrait for ListArray {}
167
168impl ArrayCanonicalImpl for ListArray {
169    fn _to_canonical(&self) -> VortexResult<Canonical> {
170        Ok(Canonical::List(self.clone()))
171    }
172}
173
174impl ArrayValidityImpl for ListArray {
175    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
176        self.validity.is_valid(index)
177    }
178
179    fn _all_valid(&self) -> VortexResult<bool> {
180        self.validity.all_valid()
181    }
182
183    fn _all_invalid(&self) -> VortexResult<bool> {
184        self.validity.all_invalid()
185    }
186
187    fn _validity_mask(&self) -> VortexResult<Mask> {
188        self.validity.to_mask(self.len())
189    }
190}
191
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of iterators.
    ///
    /// `dtype` is the *element* dtype and `O` is the offset type. This is a test convenience
    /// and is slow: every list is first materialized as a [`Scalar`] and then appended to a
    /// [`ListBuilder`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned while appending a list to the builder.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            iter.size_hint().0,
        );

        for v in iter {
            // The list scalar must carry the builder's (list) nullability, not the element
            // dtype's nullability.
            let elem = Self::list_scalar(&dtype, v, vortex_dtype::Nullability::NonNullable);
            builder.append_value(elem.as_list())?;
        }
        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional iterators; `None` becomes a
    /// null list.
    ///
    /// `dtype` is the *element* dtype and `O` is the offset type. Like
    /// [`Self::from_iter_slow`], this is a slow test convenience.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while appending a list to the builder.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let iter = iter.into_iter();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            iter.size_hint().0,
        );

        for v in iter {
            match v {
                Some(v) => {
                    // Match the builder's nullable list dtype rather than the element's.
                    let elem = Self::list_scalar(&dtype, v, vortex_dtype::Nullability::Nullable);
                    builder.append_value(elem.as_list())?;
                }
                None => builder.append_null(),
            }
        }
        Ok(builder.finish())
    }

    /// Materializes `items` as a list scalar with element dtype `dtype` and the given list
    /// `nullability`.
    fn list_scalar<T>(
        dtype: &Arc<DType>,
        items: T,
        nullability: vortex_dtype::Nullability,
    ) -> Scalar
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        Scalar::list(
            dtype.clone(),
            items.into_iter().map(Into::into).collect_vec(),
            nullability,
        )
    }
}
253
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::{filter, scalar_at};
    use crate::validity::Validity;

    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::from_iter([0]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(0, list.len());
    }

    /// `try_new` requires at least one offset (`[0]` for an empty list array).
    #[test]
    fn test_try_new_rejects_empty_offsets() {
        let elements = PrimitiveArray::empty::<u32>(Nullability::NonNullable);
        let offsets = PrimitiveArray::empty::<u32>(Nullability::NonNullable);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        );

        assert!(result.is_err());
    }

    /// `try_new` requires the offsets to be a non-nullable integer array.
    #[test]
    fn test_try_new_rejects_nullable_offsets() {
        let elements = PrimitiveArray::from_iter([1i32, 2]);
        let offsets = PrimitiveArray::from_option_iter([Some(0i32), Some(2)]);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 2.into()],
                Nullability::Nullable
            ),
            scalar_at(&list, 0).unwrap()
        );
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![3.into(), 4.into()],
                Nullability::Nullable
            ),
            scalar_at(&list, 1).unwrap()
        );
        assert_eq!(
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable),
            scalar_at(&list, 2).unwrap()
        );

        // Offsets are end-exclusive bounds: the last one equals the number of elements.
        assert_eq!(list.offset_at(0), 0);
        assert_eq!(list.offset_at(3), 5);
        assert_eq!(list.elements_at(1).unwrap().len(), 2);
        assert_eq!(list.elements_at(2).unwrap().len(), 1);
    }

    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]);
        let offsets = PrimitiveArray::from_iter([0, 2, 3]);
        let validity = Validity::NonNullable;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        assert_eq!(
            scalar_at(&list, 0).unwrap(),
            scalar_at(&list_from_iter, 0).unwrap()
        );
        assert_eq!(
            scalar_at(&list, 1).unwrap(),
            scalar_at(&list_from_iter, 1).unwrap()
        );
    }

    #[test]
    fn test_simple_list_filter() {
        let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity)
            .unwrap()
            .into_array();

        let filtered = filter(
            &list,
            &Mask::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .unwrap();

        // Two of the three lists are selected by the mask.
        assert_eq!(filtered.len(), 2);
    }
}