vortex_array/arrays/list/compute/
mod.rs

1mod to_arrow;
2
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8use vortex_scalar::Scalar;
9
10use crate::arrays::{ListArray, ListEncoding};
11use crate::compute::{
12    IsConstantFn, IsConstantOpts, MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, ToArrowFn,
13    UncompressedSizeFn, scalar_at, slice, uncompressed_size,
14};
15use crate::vtable::ComputeVTable;
16use crate::{Array, ArrayRef};
17
18impl ComputeVTable for ListEncoding {
19    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
20        Some(self)
21    }
22
23    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
24        Some(self)
25    }
26
27    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
28        Some(self)
29    }
30
31    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
32        Some(self)
33    }
34
35    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
36        Some(self)
37    }
38
39    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
40        Some(self)
41    }
42
43    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
44        Some(self)
45    }
46}
47
48impl ScalarAtFn<&ListArray> for ListEncoding {
49    fn scalar_at(&self, array: &ListArray, index: usize) -> VortexResult<Scalar> {
50        let elem = array.elements_at(index)?;
51        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| scalar_at(&elem, i)).try_collect()?;
52
53        Ok(Scalar::list(
54            Arc::new(elem.dtype().clone()),
55            scalars,
56            array.dtype().nullability(),
57        ))
58    }
59}
60
61impl SliceFn<&ListArray> for ListEncoding {
62    fn slice(&self, array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
63        Ok(ListArray::try_new(
64            array.elements().clone(),
65            slice(array.offsets(), start, stop + 1)?,
66            array.validity().slice(start, stop)?,
67        )?
68        .into_array())
69    }
70}
71
72impl MaskFn<&ListArray> for ListEncoding {
73    fn mask(&self, array: &ListArray, mask: Mask) -> VortexResult<ArrayRef> {
74        ListArray::try_new(
75            array.elements().clone(),
76            array.offsets().clone(),
77            array.validity().mask(&mask)?,
78        )
79        .map(|a| a.into_array())
80    }
81}
82
83impl MinMaxFn<&ListArray> for ListEncoding {
84    fn min_max(&self, _array: &ListArray) -> VortexResult<Option<MinMaxResult>> {
85        // TODO(joe): Implement list min max
86        Ok(None)
87    }
88}
89
90impl IsConstantFn<&ListArray> for ListEncoding {
91    fn is_constant(
92        &self,
93        _array: &ListArray,
94        _opts: &IsConstantOpts,
95    ) -> VortexResult<Option<bool>> {
96        // TODO(adam): Do we want to fallback to arrow here?
97        Ok(None)
98    }
99}
100
101impl UncompressedSizeFn<&ListArray> for ListEncoding {
102    fn uncompressed_size(&self, array: &ListArray) -> VortexResult<usize> {
103        let size = uncompressed_size(array.elements())? + uncompressed_size(array.offsets())?;
104        Ok(size + array.validity().uncompressed_size())
105    }
106}
107
#[cfg(test)]
mod test {
    use crate::array::Array;
    use crate::arrays::{ListArray, PrimitiveArray};
    use crate::compute::test_harness::test_mask;
    use crate::validity::Validity;

    /// Runs the shared mask harness over a fully valid list array.
    #[test]
    fn test_mask_list() {
        // Five lists of lengths 5, 6, 7, 8 and 9 covering the values 0..35.
        let values = PrimitiveArray::from_iter(0..35).into_array();
        let boundaries = PrimitiveArray::from_iter([0, 5, 11, 18, 26, 35]).into_array();
        let lists = ListArray::try_new(values, boundaries, Validity::AllValid).unwrap();

        test_mask(&lists);
    }
}