vortex_array/arrays/chunked/vtable/
canonical.rs1use vortex_buffer::BufferMut;
5use vortex_dtype::{DType, Nullability, PType, StructFields};
6use vortex_error::VortexExpect;
7
8use crate::arrays::{ChunkedArray, ChunkedVTable, ListViewArray, PrimitiveArray, StructArray};
9use crate::builders::{ArrayBuilder, builder_with_capacity};
10use crate::compute::cast;
11use crate::validity::Validity;
12use crate::vtable::CanonicalVTable;
13use crate::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
14
15impl CanonicalVTable<ChunkedVTable> for ChunkedVTable {
16 fn canonicalize(array: &ChunkedArray) -> Canonical {
17 if array.nchunks() == 0 {
18 return Canonical::empty(array.dtype());
19 }
20 if array.nchunks() == 1 {
21 return array.chunks()[0].to_canonical();
22 }
23
24 match array.dtype() {
25 DType::Struct(struct_dtype, _) => {
26 let struct_array = pack_struct_chunks(
27 array.chunks(),
28 Validity::copy_from_array(array.as_ref()),
29 struct_dtype,
30 );
31 Canonical::Struct(struct_array)
32 }
33 DType::List(elem_dtype, _) => Canonical::List(swizzle_list_chunks(
34 array.chunks(),
35 Validity::copy_from_array(array.as_ref()),
36 elem_dtype,
37 )),
38 _ => {
39 let mut builder = builder_with_capacity(array.dtype(), array.len());
40 array.append_to_builder(builder.as_mut());
41 builder.finish_into_canonical()
42 }
43 }
44 }
45
46 fn append_to_builder(array: &ChunkedArray, builder: &mut dyn ArrayBuilder) {
47 for chunk in array.chunks() {
48 chunk.append_to_builder(builder);
49 }
50 }
51}
52
53fn pack_struct_chunks(
58 chunks: &[ArrayRef],
59 validity: Validity,
60 struct_dtype: &StructFields,
61) -> StructArray {
62 let len = chunks.iter().map(|chunk| chunk.len()).sum();
63 let mut field_arrays = Vec::new();
64
65 for (field_idx, field_dtype) in struct_dtype.fields().enumerate() {
66 let field_chunks = chunks
67 .iter()
68 .map(|c| {
69 c.to_struct()
70 .fields()
71 .get(field_idx)
72 .vortex_expect("Invalid field index")
73 .to_array()
74 })
75 .collect::<Vec<_>>();
76
77 let field_array = unsafe { ChunkedArray::new_unchecked(field_chunks, field_dtype.clone()) };
80 field_arrays.push(field_array.into_array());
81 }
82
83 unsafe { StructArray::new_unchecked(field_arrays, struct_dtype.clone(), len, validity) }
86}
87
88fn swizzle_list_chunks(
94 chunks: &[ArrayRef],
95 validity: Validity,
96 elem_dtype: &DType,
97) -> ListViewArray {
98 let len: usize = chunks.iter().map(|c| c.len()).sum();
99
100 assert_eq!(
101 chunks[0]
102 .dtype()
103 .as_list_element_opt()
104 .vortex_expect("DType was somehow not a list")
105 .as_ref(),
106 elem_dtype
107 );
108
109 let mut list_elements_chunks = Vec::with_capacity(chunks.len());
113 let mut num_elements = 0;
114
115 let mut offsets = BufferMut::<u64>::with_capacity(len);
120 let mut sizes = BufferMut::<u64>::with_capacity(len);
121
122 for chunk in chunks {
123 let chunk_array = chunk.to_listview();
124
125 list_elements_chunks.push(chunk_array.elements().clone());
127
128 let offsets_arr = cast(
130 chunk_array.offsets(),
131 &DType::Primitive(PType::U64, Nullability::NonNullable),
132 )
133 .vortex_expect("Must be able to fit array offsets in u64")
134 .to_primitive();
135
136 let sizes_arr = cast(
137 chunk_array.sizes(),
138 &DType::Primitive(PType::U64, Nullability::NonNullable),
139 )
140 .vortex_expect("Must be able to fit array offsets in u64")
141 .to_primitive();
142
143 let offsets_slice = offsets_arr.as_slice::<u64>();
144 let sizes_slice = sizes_arr.as_slice::<u64>();
145
146 offsets.extend(offsets_slice.iter().map(|o| o + num_elements));
148 sizes.extend(sizes_slice);
149
150 num_elements += chunk_array.elements().len() as u64;
151 }
152
153 let chunked_elements =
155 unsafe { ChunkedArray::new_unchecked(list_elements_chunks, elem_dtype.clone()) }
156 .into_array();
157
158 let offsets = PrimitiveArray::new(offsets.freeze(), Validity::NonNullable).into_array();
159 let sizes = PrimitiveArray::new(sizes.freeze(), Validity::NonNullable).into_array();
160
161 unsafe { ListViewArray::new_unchecked(chunked_elements, offsets, sizes, validity) }
167}
168
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::DType::{List, Primitive};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;

    use crate::accessor::ArrayAccessor;
    use crate::arrays::{ChunkedArray, ListArray, StructArray, VarBinViewArray};
    use crate::validity::Validity;
    use crate::{IntoArray, ToCanonical};

    /// A chunked array wrapping another chunked array of structs must canonicalize to a
    /// struct whose string field holds the same values as the original.
    #[test]
    pub fn pack_nested_structs() {
        let names = VarBinViewArray::from_iter_str(["foo", "bar", "baz", "quak"]).into_array();
        let source =
            StructArray::try_new(["a"].into(), vec![names], 4, Validity::NonNullable).unwrap();
        let struct_dtype = source.dtype().clone();

        // Two levels of chunking around the single struct array.
        let inner = ChunkedArray::try_new(vec![source.to_array()], struct_dtype.clone())
            .unwrap()
            .into_array();
        let outer = ChunkedArray::try_new(vec![inner], struct_dtype)
            .unwrap()
            .into_array();

        let flattened = outer.to_struct();
        let flattened_field = flattened.fields()[0].to_varbinview();
        let source_field = source.fields()[0].to_varbinview();

        let expected = source_field
            .with_iterator(|it| it.map(|a| a.map(|v| v.to_vec())).collect::<Vec<_>>())
            .unwrap();
        let actual = flattened_field
            .with_iterator(|it| it.map(|a| a.map(|v| v.to_vec())).collect::<Vec<_>>())
            .unwrap();

        assert_eq!(expected, actual);
    }

    /// Two single-list chunks must canonicalize to a list view whose entries match the
    /// first list of each input chunk, in order.
    #[test]
    pub fn pack_nested_lists() {
        let first = ListArray::try_new(
            buffer![1, 2, 3, 4].into_array(),
            buffer![0, 3].into_array(),
            Validity::NonNullable,
        )
        .unwrap();
        let second = ListArray::try_new(
            buffer![5, 6].into_array(),
            buffer![0, 2].into_array(),
            Validity::NonNullable,
        )
        .unwrap();

        let list_dtype = List(Arc::new(Primitive(I32, NonNullable)), NonNullable);
        let chunks = vec![first.clone().into_array(), second.clone().into_array()];
        let chunked = ChunkedArray::try_new(chunks, list_dtype).unwrap();

        let canonical = chunked.to_listview();

        assert_eq!(first.scalar_at(0), canonical.scalar_at(0));
        assert_eq!(second.scalar_at(0), canonical.scalar_at(1));
    }
}