vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9use serde::ListMetadata;
10use vortex_dtype::{DType, NativePType, match_each_native_ptype};
11use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
12use vortex_mask::Mask;
13use vortex_scalar::Scalar;
14
15use crate::arrays::PrimitiveArray;
16#[cfg(feature = "test-harness")]
17use crate::builders::{ArrayBuilder, ListBuilder};
18use crate::compute::{scalar_at, slice};
19use crate::stats::{ArrayStats, StatsSetRef};
20use crate::validity::Validity;
21use crate::variants::{ListArrayTrait, PrimitiveArrayTrait};
22use crate::vtable::VTableRef;
23use crate::{
24    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
25    ArrayVariantsImpl, Canonical, Encoding, ProstMetadata, TryFromArrayRef,
26};
27
28#[derive(Clone, Debug)]
29pub struct ListArray {
30    dtype: DType,
31    elements: ArrayRef,
32    offsets: ArrayRef,
33    validity: Validity,
34    stats_set: ArrayStats,
35}
36
37#[derive(Debug)]
38pub struct ListEncoding;
39impl Encoding for ListEncoding {
40    type Array = ListArray;
41    type Metadata = ProstMetadata<ListMetadata>;
42}
43
44pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
45
46impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
47
// A list array is valid if:
// - the first offset is a valid position in `elements`
// - the offsets are sorted
// - the final offset is a valid position in `elements`, or zero if `elements`
//   is empty
// - final_offset >= start_offset
// - the length of the validity is one less than the length of the offsets array
55
56impl ListArray {
57    pub fn try_new(
58        elements: ArrayRef,
59        offsets: ArrayRef,
60        validity: Validity,
61    ) -> VortexResult<Self> {
62        let nullability = validity.nullability();
63
64        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
65            vortex_bail!(
66                "Expected offsets to be an non-nullable integer type, got {:?}",
67                offsets.dtype()
68            );
69        }
70
71        if offsets.is_empty() {
72            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
73        }
74
75        Ok(Self {
76            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
77            elements,
78            offsets,
79            validity,
80            stats_set: Default::default(),
81        })
82    }
83
84    pub fn validity(&self) -> &Validity {
85        &self.validity
86    }
87
88    // TODO: merge logic with varbin
89    // TODO(ngates): should return a result if it requires canonicalizing offsets
90    pub fn offset_at(&self, index: usize) -> usize {
91        PrimitiveArray::try_from_array(self.offsets().clone())
92            .ok()
93            .map(|p| {
94                match_each_native_ptype!(p.ptype(), |$P| {
95                    p.as_slice::<$P>()[index].as_()
96                })
97            })
98            .unwrap_or_else(|| {
99                scalar_at(self.offsets(), index)
100                    .unwrap_or_else(|err| {
101                        vortex_panic!(err, "Failed to get offset at index: {}", index)
102                    })
103                    .as_ref()
104                    .try_into()
105                    .vortex_expect("Failed to convert offset to usize")
106            })
107    }
108
109    // TODO: fetches the elements at index
110    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
111        let start = self.offset_at(index);
112        let end = self.offset_at(index + 1);
113        slice(self.elements(), start, end)
114    }
115
116    // TODO: fetches the offsets of the array ignoring validity
117    pub fn offsets(&self) -> &ArrayRef {
118        &self.offsets
119    }
120
121    // TODO: fetches the elements of the array ignoring validity
122    pub fn elements(&self) -> &ArrayRef {
123        &self.elements
124    }
125}
126
127impl ArrayImpl for ListArray {
128    type Encoding = ListEncoding;
129
130    fn _len(&self) -> usize {
131        self.offsets.len().saturating_sub(1)
132    }
133
134    fn _dtype(&self) -> &DType {
135        &self.dtype
136    }
137
138    fn _vtable(&self) -> VTableRef {
139        VTableRef::new_ref(&ListEncoding)
140    }
141
142    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
143        let elements = children[0].clone();
144        let offsets = children[1].clone();
145        let validity = if self.validity().is_array() {
146            Validity::Array(children[2].clone())
147        } else {
148            self.validity().clone()
149        };
150
151        Self::try_new(elements, offsets, validity)
152    }
153}
154
155impl ArrayStatisticsImpl for ListArray {
156    fn _stats_ref(&self) -> StatsSetRef<'_> {
157        self.stats_set.to_ref(self)
158    }
159}
160
161impl ArrayVariantsImpl for ListArray {
162    fn _as_list_typed(&self) -> Option<&dyn ListArrayTrait> {
163        Some(self)
164    }
165}
166
167impl ListArrayTrait for ListArray {}
168
169impl ArrayCanonicalImpl for ListArray {
170    fn _to_canonical(&self) -> VortexResult<Canonical> {
171        Ok(Canonical::List(self.clone()))
172    }
173}
174
175impl ArrayValidityImpl for ListArray {
176    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
177        self.validity.is_valid(index)
178    }
179
180    fn _all_valid(&self) -> VortexResult<bool> {
181        self.validity.all_valid()
182    }
183
184    fn _all_invalid(&self) -> VortexResult<bool> {
185        self.validity.all_invalid()
186    }
187
188    fn _validity_mask(&self) -> VortexResult<Mask> {
189        self.validity.to_mask(self.len())
190    }
191}
192
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Convenience constructor that builds a non-nullable list array from an iterator of
    /// iterators, where `dtype` is the element dtype.
    ///
    /// This is slow: every element is first converted into a [`Scalar`], each inner
    /// iterator is collected into a list scalar, and that scalar is then appended to a
    /// [`ListBuilder`].
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let rows = iter.into_iter();
        // Use the iterator's lower bound as the initial builder capacity.
        let (capacity, _) = rows.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::NonNullable,
            capacity,
        );

        for row in rows {
            let values: Vec<Scalar> = row.into_iter().map(Into::into).collect_vec();
            let list = Scalar::list(dtype.clone(), values, dtype.nullability());
            builder.append_value(list.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Convenience constructor that builds a nullable list array from an iterator of
    /// optional iterators, where `dtype` is the element dtype and `None` becomes a null
    /// list.
    ///
    /// Like [`ListArray::from_iter_slow`], this converts every element into a [`Scalar`]
    /// first and is therefore slow.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let rows = iter.into_iter();
        // Use the iterator's lower bound as the initial builder capacity.
        let (capacity, _) = rows.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(
            dtype.clone(),
            vortex_dtype::Nullability::Nullable,
            capacity,
        );

        for row in rows {
            match row {
                Some(items) => {
                    let values: Vec<Scalar> = items.into_iter().map(Into::into).collect_vec();
                    let list = Scalar::list(dtype.clone(), values, dtype.nullability());
                    builder.append_value(list.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }
}
254
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::{filter, scalar_at};
    use crate::validity::Validity;

    /// Builds the nullable `i32` list scalar the assertions compare against.
    fn nullable_i32_list(values: &[i32]) -> Scalar {
        Scalar::list(
            Arc::new(I32.into()),
            values.iter().map(|&v| v.into()).collect(),
            Nullability::Nullable,
        )
    }

    /// A single `[0]` offset with no elements yields an array of zero lists.
    #[test]
    fn test_empty_list_array() {
        let list = ListArray::try_new(
            PrimitiveArray::empty::<u32>(Nullability::NonNullable).into_array(),
            PrimitiveArray::from_iter([0]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        assert_eq!(0, list.len());
    }

    /// Each list scalar covers the elements between consecutive offsets.
    #[test]
    fn test_simple_list_array() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
            PrimitiveArray::from_iter([0, 2, 4, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap();

        let expected: [&[i32]; 3] = [&[1, 2], &[3, 4], &[5]];
        for (idx, values) in expected.iter().enumerate() {
            assert_eq!(nullable_i32_list(values), scalar_at(&list, idx).unwrap());
        }
    }

    /// `from_iter_slow` produces the same lists as constructing from explicit parts.
    #[test]
    fn test_simple_list_array_from_iter() {
        let list = ListArray::try_new(
            PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
            PrimitiveArray::from_iter([0, 2, 3]).into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        for idx in [0, 1] {
            assert_eq!(
                scalar_at(&list, idx).unwrap(),
                scalar_at(&list_from_iter, idx).unwrap()
            );
        }
    }

    /// Filtering a list array with nullable elements succeeds.
    #[test]
    fn test_simple_list_filter() {
        let list = ListArray::try_new(
            PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)])
                .into_array(),
            PrimitiveArray::from_iter([0, 2, 4, 5]).into_array(),
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let keep = Mask::from(BooleanBuffer::from(vec![false, true, true]));
        let filtered = filter(&list, &keep);

        assert!(filtered.is_ok())
    }
}
355}