vortex_array/arrays/list/
mod.rs1mod compute;
2mod serde;
3
4use std::sync::Arc;
5
6#[cfg(feature = "test-harness")]
7use itertools::Itertools;
8use num_traits::{AsPrimitive, PrimInt};
9#[cfg(feature = "test-harness")]
10use vortex_dtype::Nullability::{NonNullable, Nullable};
11use vortex_dtype::{DType, NativePType, match_each_native_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_panic};
13use vortex_mask::Mask;
14use vortex_scalar::Scalar;
15
16use crate::arrays::PrimitiveArray;
17use crate::arrays::list::serde::ListMetadata;
18#[cfg(feature = "test-harness")]
19use crate::builders::{ArrayBuilder, ListBuilder};
20use crate::compute::{scalar_at, slice};
21use crate::stats::{ArrayStats, StatsSetRef};
22use crate::validity::Validity;
23use crate::variants::{ListArrayTrait, PrimitiveArrayTrait};
24use crate::vtable::{EncodingVTable, StatisticsVTable, VTableRef};
25use crate::{
26 Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
27 ArrayVariantsImpl, Canonical, Encoding, EncodingId, RkyvMetadata, TryFromArrayRef,
28};
29
/// An array of variable-length lists.
///
/// All list elements are stored back to back in `elements`; list `i` is the slice
/// `elements[offsets[i]..offsets[i + 1]]`, so `offsets` has one more entry than the array has
/// lists.
#[derive(Clone, Debug)]
pub struct ListArray {
    // Logical dtype: `DType::List` of the element dtype, with the nullability of `validity`.
    dtype: DType,
    // Flat storage for the elements of every list.
    elements: ArrayRef,
    // Non-nullable integer offsets into `elements` (length = number of lists + 1).
    offsets: ArrayRef,
    // Per-list (not per-element) validity.
    validity: Validity,
    stats_set: ArrayStats,
}
38
39pub struct ListEncoding;
40impl Encoding for ListEncoding {
41 type Array = ListArray;
42 type Metadata = RkyvMetadata<ListMetadata>;
43}
44
45impl EncodingVTable for ListEncoding {
46 fn id(&self) -> EncodingId {
47 EncodingId::new_ref("vortex.list")
48 }
49}
50
51pub trait OffsetPType: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
52
53impl<T> OffsetPType for T where T: NativePType + PrimInt + AsPrimitive<usize> + Into<Scalar> {}
54
55impl ListArray {
64 pub fn try_new(
65 elements: ArrayRef,
66 offsets: ArrayRef,
67 validity: Validity,
68 ) -> VortexResult<Self> {
69 let nullability = validity.nullability();
70
71 if !offsets.dtype().is_int() || offsets.dtype().is_nullable() {
72 vortex_bail!(
73 "Expected offsets to be an non-nullable integer type, got {:?}",
74 offsets.dtype()
75 );
76 }
77
78 if offsets.is_empty() {
79 vortex_bail!("Offsets must have at least one element, [0] for an empty list");
80 }
81
82 Ok(Self {
83 dtype: DType::List(Arc::new(elements.dtype().clone()), nullability),
84 elements,
85 offsets,
86 validity,
87 stats_set: Default::default(),
88 })
89 }
90
91 pub fn validity(&self) -> &Validity {
92 &self.validity
93 }
94
95 pub fn offset_at(&self, index: usize) -> usize {
98 PrimitiveArray::try_from_array(self.offsets().clone())
99 .ok()
100 .map(|p| {
101 match_each_native_ptype!(p.ptype(), |$P| {
102 p.as_slice::<$P>()[index].as_()
103 })
104 })
105 .unwrap_or_else(|| {
106 scalar_at(self.offsets(), index)
107 .unwrap_or_else(|err| {
108 vortex_panic!(err, "Failed to get offset at index: {}", index)
109 })
110 .as_ref()
111 .try_into()
112 .vortex_expect("Failed to convert offset to usize")
113 })
114 }
115
116 pub fn elements_at(&self, index: usize) -> VortexResult<ArrayRef> {
118 let start = self.offset_at(index);
119 let end = self.offset_at(index + 1);
120 slice(self.elements(), start, end)
121 }
122
123 pub fn offsets(&self) -> &ArrayRef {
125 &self.offsets
126 }
127
128 pub fn elements(&self) -> &ArrayRef {
130 &self.elements
131 }
132}
133
134impl ArrayImpl for ListArray {
135 type Encoding = ListEncoding;
136
137 fn _len(&self) -> usize {
138 self.offsets.len().saturating_sub(1)
139 }
140
141 fn _dtype(&self) -> &DType {
142 &self.dtype
143 }
144
145 fn _vtable(&self) -> VTableRef {
146 VTableRef::new_ref(&ListEncoding)
147 }
148}
149
150impl ArrayStatisticsImpl for ListArray {
151 fn _stats_ref(&self) -> StatsSetRef<'_> {
152 self.stats_set.to_ref(self)
153 }
154}
155
156impl ArrayVariantsImpl for ListArray {
157 fn _as_list_typed(&self) -> Option<&dyn ListArrayTrait> {
158 Some(self)
159 }
160}
161
// Marker impl: `ListArray` relies entirely on the trait's provided methods (if any).
impl ListArrayTrait for ListArray {}
163
164impl ArrayCanonicalImpl for ListArray {
165 fn _to_canonical(&self) -> VortexResult<Canonical> {
166 Ok(Canonical::List(self.clone()))
167 }
168}
169
170impl ArrayValidityImpl for ListArray {
171 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
172 self.validity.is_valid(index)
173 }
174
175 fn _all_valid(&self) -> VortexResult<bool> {
176 self.validity.all_valid()
177 }
178
179 fn _all_invalid(&self) -> VortexResult<bool> {
180 self.validity.all_invalid()
181 }
182
183 fn _validity_mask(&self) -> VortexResult<Mask> {
184 self.validity.to_logical(self.len())
185 }
186}
187
// Uses the trait's default statistics behavior; no list-specific statistics are computed here.
impl StatisticsVTable<&ListArray> for ListEncoding {}
189
#[cfg(feature = "test-harness")]
impl ListArray {
    /// Builds a non-nullable list array from an iterator of lists.
    ///
    /// Each item of `iter` is one list whose values convert into [`Scalar`]s of element dtype
    /// `dtype`. `O` selects the offset type. Slow: every list goes through a list [`Scalar`].
    pub fn from_iter_slow<O: OffsetPType, I: IntoIterator>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        I::Item: IntoIterator,
        <I::Item as IntoIterator>::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(dtype.clone(), NonNullable, capacity);

        for values in lists {
            let scalar = Self::list_scalar_from_iter(&dtype, values);
            builder.append_value(scalar.as_list())?;
        }

        Ok(builder.finish())
    }

    /// Builds a nullable list array from an iterator of optional lists.
    ///
    /// `Some(values)` appends a list of `values` (element dtype `dtype`); `None` appends a null
    /// list. `O` selects the offset type. Slow: every list goes through a list [`Scalar`].
    pub fn from_iter_opt_slow<O: OffsetPType, I: IntoIterator<Item = Option<T>>, T>(
        iter: I,
        dtype: Arc<DType>,
    ) -> VortexResult<ArrayRef>
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        let lists = iter.into_iter();
        let (capacity, _) = lists.size_hint();
        let mut builder = ListBuilder::<O>::with_capacity(dtype.clone(), Nullable, capacity);

        for maybe_values in lists {
            match maybe_values {
                Some(values) => {
                    let scalar = Self::list_scalar_from_iter(&dtype, values);
                    builder.append_value(scalar.as_list())?;
                }
                None => builder.append_null(),
            }
        }

        Ok(builder.finish())
    }

    /// Converts `values` into a list [`Scalar`] with element dtype `dtype`.
    ///
    /// NOTE(review): the scalar's nullability is copied from the element dtype rather than
    /// from the target list array; presumably `append_value` ignores it. Confirm.
    fn list_scalar_from_iter<T>(dtype: &Arc<DType>, values: T) -> Scalar
    where
        T: IntoIterator,
        T::Item: Into<Scalar>,
    {
        Scalar::list(
            dtype.clone(),
            values.into_iter().map(Into::<Scalar>::into).collect_vec(),
            dtype.nullability(),
        )
    }
}
245
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::list::ListArray;
    use crate::compute::{filter, scalar_at};
    use crate::validity::Validity;

    #[test]
    fn test_empty_list_array() {
        let elements = PrimitiveArray::empty::<u32>(NonNullable);
        let offsets = PrimitiveArray::from_iter([0]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(0, list.len());
    }

    #[test]
    fn test_try_new_rejects_empty_offsets() {
        let elements = PrimitiveArray::empty::<u32>(NonNullable);
        let offsets = PrimitiveArray::empty::<u32>(NonNullable);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_try_new_rejects_non_integer_offsets() {
        let elements = PrimitiveArray::from_iter([1i32]);
        let offsets = PrimitiveArray::from_iter([0.0f32, 1.0]);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_try_new_rejects_nullable_offsets() {
        let elements = PrimitiveArray::from_iter([1i32]);
        let offsets = PrimitiveArray::from_option_iter([Some(0u32), Some(1)]);

        let result = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_offsets_and_elements_at() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);

        let list = ListArray::try_new(
            elements.into_array(),
            offsets.into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        assert_eq!(list.offset_at(0), 0);
        assert_eq!(list.offset_at(2), 4);
        assert_eq!(list.offset_at(3), 5);
        assert_eq!(list.elements_at(0).unwrap().len(), 2);
        assert_eq!(list.elements_at(2).unwrap().len(), 1);
    }

    #[test]
    fn test_simple_list_array() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![1.into(), 2.into()],
                Nullability::Nullable
            ),
            scalar_at(&list, 0).unwrap()
        );
        assert_eq!(
            Scalar::list(
                Arc::new(I32.into()),
                vec![3.into(), 4.into()],
                Nullability::Nullable
            ),
            scalar_at(&list, 1).unwrap()
        );
        assert_eq!(
            Scalar::list(Arc::new(I32.into()), vec![5.into()], Nullability::Nullable),
            scalar_at(&list, 2).unwrap()
        );
    }

    #[test]
    fn test_simple_list_array_from_iter() {
        let elements = PrimitiveArray::from_iter([1i32, 2, 3]);
        let offsets = PrimitiveArray::from_iter([0, 2, 3]);
        let validity = Validity::NonNullable;

        let list =
            ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap();

        let list_from_iter =
            ListArray::from_iter_slow::<u32, _>(vec![vec![1i32, 2], vec![3]], Arc::new(I32.into()))
                .unwrap();

        assert_eq!(list.len(), list_from_iter.len());
        assert_eq!(
            scalar_at(&list, 0).unwrap(),
            scalar_at(&list_from_iter, 0).unwrap()
        );
        assert_eq!(
            scalar_at(&list, 1).unwrap(),
            scalar_at(&list_from_iter, 1).unwrap()
        );
    }

    #[test]
    fn test_simple_list_filter() {
        let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]);
        let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]);
        let validity = Validity::AllValid;

        let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity)
            .unwrap()
            .into_array();

        let filtered = filter(
            &list,
            &Mask::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .unwrap();

        // The mask keeps two of the three lists.
        assert_eq!(filtered.len(), 2);
    }
}