// vortex_array/arrays/list/compute/mod.rs

1mod to_arrow;
2
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8use vortex_scalar::Scalar;
9
10use crate::arrays::{ListArray, ListEncoding};
11use crate::compute::{
12    IsConstantFn, IsConstantOpts, IsSortedFn, MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn,
13    ToArrowFn, UncompressedSizeFn, scalar_at, slice, uncompressed_size,
14};
15use crate::vtable::ComputeVTable;
16use crate::{Array, ArrayComputeImpl, ArrayRef};
17
18impl ArrayComputeImpl for ListArray {}
19
20impl ComputeVTable for ListEncoding {
21    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
30        Some(self)
31    }
32
33    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
34        Some(self)
35    }
36
37    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
38        Some(self)
39    }
40
41    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
42        Some(self)
43    }
44
45    fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
46        Some(self)
47    }
48
49    fn is_sorted_fn(&self) -> Option<&dyn IsSortedFn<&dyn Array>> {
50        Some(self)
51    }
52}
53
54impl ScalarAtFn<&ListArray> for ListEncoding {
55    fn scalar_at(&self, array: &ListArray, index: usize) -> VortexResult<Scalar> {
56        let elem = array.elements_at(index)?;
57        let scalars: Vec<Scalar> = (0..elem.len()).map(|i| scalar_at(&elem, i)).try_collect()?;
58
59        Ok(Scalar::list(
60            Arc::new(elem.dtype().clone()),
61            scalars,
62            array.dtype().nullability(),
63        ))
64    }
65}
66
67impl SliceFn<&ListArray> for ListEncoding {
68    fn slice(&self, array: &ListArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
69        Ok(ListArray::try_new(
70            array.elements().clone(),
71            slice(array.offsets(), start, stop + 1)?,
72            array.validity().slice(start, stop)?,
73        )?
74        .into_array())
75    }
76}
77
78impl MaskFn<&ListArray> for ListEncoding {
79    fn mask(&self, array: &ListArray, mask: Mask) -> VortexResult<ArrayRef> {
80        ListArray::try_new(
81            array.elements().clone(),
82            array.offsets().clone(),
83            array.validity().mask(&mask)?,
84        )
85        .map(|a| a.into_array())
86    }
87}
88
89impl MinMaxFn<&ListArray> for ListEncoding {
90    fn min_max(&self, _array: &ListArray) -> VortexResult<Option<MinMaxResult>> {
91        // TODO(joe): Implement list min max
92        Ok(None)
93    }
94}
95
96impl IsConstantFn<&ListArray> for ListEncoding {
97    fn is_constant(
98        &self,
99        _array: &ListArray,
100        _opts: &IsConstantOpts,
101    ) -> VortexResult<Option<bool>> {
102        // TODO(adam): Do we want to fallback to arrow here?
103        Ok(None)
104    }
105}
106
107impl UncompressedSizeFn<&ListArray> for ListEncoding {
108    fn uncompressed_size(&self, array: &ListArray) -> VortexResult<usize> {
109        let size = uncompressed_size(array.elements())? + uncompressed_size(array.offsets())?;
110        Ok(size + array.validity().uncompressed_size())
111    }
112}
113
114impl IsSortedFn<&ListArray> for ListEncoding {
115    fn is_sorted(&self, _array: &ListArray) -> VortexResult<bool> {
116        Ok(false)
117    }
118
119    fn is_strict_sorted(&self, _array: &ListArray) -> VortexResult<bool> {
120        Ok(false)
121    }
122}
123
#[cfg(test)]
mod test {
    use crate::array::Array;
    use crate::arrays::{ListArray, PrimitiveArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::validity::Validity;

    /// Runs the shared mask conformance checks against a five-row, all-valid list array.
    /// The rows hold 5, 6, 7, 8 and 9 consecutive integers drawn from `0..35`.
    #[test]
    fn test_mask_list() {
        let values = PrimitiveArray::from_iter(0..35).into_array();
        let row_offsets = PrimitiveArray::from_iter([0, 5, 11, 18, 26, 35]).into_array();

        let list = ListArray::try_new(values, row_offsets, Validity::AllValid).unwrap();

        test_mask(&list);
    }
}
141}