vortex_array/arrays/list/
mod.rs

1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9#[cfg(feature = "test-harness")]
10use vortex_dtype::Nullability::{NonNullable, Nullable};
11use vortex_dtype::{DType, NativePType, match_each_native_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
13use vortex_mask::Mask;
14use vortex_scalar::Scalar;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::list::serde::ListMetadata;
18#[cfg(feature = "test-harness")]
19use crate::builders::{ArrayBuilder, ListBuilder};
20use crate::compute::{scalar_at, slice};
21use crate::stats::{ArrayStats, StatsSetRef};
22use crate::validity::Validity;
23use crate::variants::{ListArrayTrait, PrimitiveArrayTrait};
24use crate::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
25use crate::{
26    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
27    ArrayVariantsImpl, Canonical, Encoding, EncodingId, RkyvMetadata, TryFromArrayRef,
28};
29
30#[derive(Clone, Debug)]
31pub struct ListArray {
32    dtype: DType,
33    elements: ArrayRef,
34    offsets: ArrayRef,
35    validity: Validity,
36    stats_set: ArrayStats,
37}
38
39pub struct ListEncoding;
40impl Encoding for ListEncoding {
41    type Array = ListArray;
42    type Metadata = RkyvMetadata<ListMetadata>;
43}
44
45impl EncodingVTable for ListEncoding {
46    fn id(&self) -> EncodingId {
47        EncodingId::new_ref("vortex.list")
48    }
49}
50
51pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
52
53impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
54
// A list array is valid if:
// - the first offset points to a position within `elements`
// - the offsets are sorted
// - the final offset points to a position within `elements` (it is zero if
//   `elements` is empty)
// - final_offset >= start_offset
// - the length of the validity is one less than the length of the offsets array
62
63impl ListArray {
64    pub fn try_new(
65        elements: ArrayRef,
66        offsets: ArrayRef,
67        validity: Validity,
68    ) -> VortexResult<Self> {
69        let nullability = validity.nullability();
70
71        if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
72            vortex_bail!(
73                "Expected offsets to be an non-nullable integer type, got {:?}",
74                offsets.dtype()
75            );
76        }
77
78        if offsets.is_empty() {
79            vortex_bail!("Offsets must have at least one element, [0] for an empty list");
80        }
81
82        Ok(Self {
83            dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
84            elements,
85            offsets,
86            validity,
87            stats_set: Default::default(),
88        })
89    }
90
91    pub fn validity(&self) -> &Validity {
92        &self.validity
93    }
94
95    // TODO: merge logic with varbin
96    // TODO(ngates): should return a result if it requires canonicalizing offsets
97    pub fn offset_at(&self, index: usize) -> usize {
98        PrimitiveArray::try_from_array(self.offsets().clone())
99            .ok()
100            .map(|p| {
101                match_each_native_ptype!(p.ptype(), |$P| {
102                    p.as_slice::<$P>()[index].as_()
103                })
104            })
105            .unwrap_or_else(|| {
106                scalar_at(self.offsets(), index)
107                    .unwrap_or_else(|err| {
108                        vortex_panic!(err, "Failed to get offset at index: {}", index)
109                    })
110                    .as_ref()
111                    .try_into()
112                    .vortex_expect("Failed to convert offset to usize")
113            })
114    }
115
116    // TODO: fetches the elements at index
117    pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
118        let start = self.offset_at(index);
119        let end = self.offset_at(index + 1);
120        slice(self.elements(), start, end)
121    }
122
123    // TODO: fetches the offsets of the array ignoring validity
124    pub fn offsets(&self) -> &ArrayRef {
125        &self.offsets
126    }
127
128    // TODO: fetches the elements of the array ignoring validity
129    pub fn elements(&self) -> &ArrayRef {
130        &self.elements
131    }
132}
133
134impl ArrayImpl for ListArray {
135    type Encoding = ListEncoding;
136
137    fn _len(&self) -> usize {
138        self.offsets.len().saturating_sub(1)
139    }
140
141    fn _dtype(&self) -> &DType {
142        &self.dtype
143    }
144
145    fn _vtable(&self) -> VTableRef {
146        VTableRef::new_ref(&ListEncoding)
147    }
148}
149
150impl ArrayStatisticsImpl for ListArray {
151    fn _stats_ref(&self) -> StatsSetRef<'_> {
152        self.stats_set.to_ref(self)
153    }
154}
155
156impl ArrayVariantsImpl for ListArray {
157    fn _as_list_typed(&self) -> Option<&dyn ListArrayTrait> {
158        Some(self)
159    }
160}
161
162impl ListArrayTrait for ListArray {}
163
164impl ArrayCanonicalImpl for ListArray {
165    fn _to_canonical(&self) -> VortexResult<Canonical> {
166        Ok(Canonical::List(self.clone()))
167    }
168}
169
170impl ArrayValidityImpl for ListArray {
171    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
172        self.validity.is_valid(index)
173    }
174
175    fn _all_valid(&self) -> VortexResult<bool> {
176        self.validity.all_valid()
177    }
178
179    fn _all_invalid(&self) -> VortexResult<bool> {
180        self.validity.all_invalid()
181    }
182
183    fn _validity_mask(&self) -> VortexResult<Mask> {
184        self.validity.to_logical(self.len())
185    }
186}
187
188impl StatisticsVTable<&ListArray> for ListEncoding {}
189
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Convenience constructor that builds a non-nullable list array from an iterator of
    /// iterators, where `dtype` is the element dtype.
    ///
    /// This is slow: every element is first converted into a scalar, and each row is
    /// appended to the builder as a list scalar.
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let rows = iter.into_iter();
        // Pre-size the builder using the iterator's lower bound.
        let (capacity, _) = rows.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(dtype.clone(), NonNullable, capacity);

        for row in rows {
            let values = row.into_iter().map(|value| value.into()).collect_vec();
            let list_scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
            builder.append_value(list_scalar.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Like [`ListArray::from_iter_slow`], but builds a nullable list array: a `None` row
    /// is appended as a null list.
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let rows = iter.into_iter();
        // Pre-size the builder using the iterator's lower bound.
        let (capacity, _) = rows.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(dtype.clone(), Nullable, capacity);

        for row in rows {
            match row {
                Some(row) => {
                    let values = row.into_iter().map(|value| value.into()).collect_vec();
                    let list_scalar = Scalar::list(dtype.clone(), values, dtype.nullability());
                    builder.append_value(list_scalar.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }
}
245
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::{filter, scalar_at};
    use crate::validity::Validity;

    /// Assembles a `ListArray` from primitive parts, panicking if construction fails.
    fn build_list(
        elements: PrimitiveArray,
        offsets: PrimitiveArray,
        validity: Validity,
    ) -> ListArray {
        ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap()
    }

    // A single `[0]` offset with no elements describes an empty list array.
    #[test]
    fn test_empty_list_array() {
        let list = build_list(
            PrimitiveArray::empty::<u32>(NonNullable),
            PrimitiveArray::from_iter([0]),
            Validity::AllValid,
        );

        assert_eq!(0, list.len());
    }

    // Each list is read back as a list scalar holding the expected elements.
    #[test]
    fn test_simple_list_array() {
        let list = build_list(
            PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]),
            PrimitiveArray::from_iter([0, 2, 4, 5]),
            Validity::AllValid,
        );

        let expected_first = Scalar::list(
            Arc::new(I32.into()),
            vec![1.into(), 2.into()],
            Nullability::Nullable,
        );
        assert_eq!(expected_first, scalar_at(&list, 0).unwrap());

        let expected_second = Scalar::list(
            Arc::new(I32.into()),
            vec![3.into(), 4.into()],
            Nullability::Nullable,
        );
        assert_eq!(expected_second, scalar_at(&list, 1).unwrap());

        let expected_third =
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable);
        assert_eq!(expected_third, scalar_at(&list, 2).unwrap());
    }

    // A hand-assembled array and one built with `from_iter_slow` agree row by row.
    #[test]
    fn test_simple_list_array_from_iter() {
        let list = build_list(
            PrimitiveArray::from_iter([1i32, 2, 3]),
            PrimitiveArray::from_iter([0, 2, 3]),
            Validity::NonNullable,
        );

        let rows = vec![vec![1i32, 2], vec![3]];
        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(rows, Arc::new(I32.into())).unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        for idx in 0..2 {
            assert_eq!(
                scalar_at(&list, idx).unwrap(),
                scalar_at(&list_from_iter, idx).unwrap()
            );
        }
    }

    // Filtering a list array whose elements contain a null succeeds.
    #[test]
    fn test_simple_list_filter() {
        let list = build_list(
            PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]),
            PrimitiveArray::from_iter([0, 2, 4, 5]),
            Validity::AllValid,
        )
        .into_array();

        let mask = Mask::from(BooleanBuffer::from(vec![false, true, true]));
        let filtered = filter(&list, &mask);

        assert!(filtered.is_ok());
    }
}
347}