1mod to_arrow;
2
3use itertools::Itertools;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail};
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::StructEncoding;
10use crate::arrays::struct_::StructArray;
11use crate::compute::{
12 CastFn, FilterFn, IsConstantFn, IsConstantOpts, MaskFn, MinMaxFn, MinMaxResult, ScalarAtFn,
13 SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter, is_constant_opts, scalar_at, slice,
14 take, try_cast, uncompressed_size,
15};
16use crate::variants::StructArrayTrait;
17use crate::vtable::ComputeVTable;
18use crate::{Array, ArrayRef, ArrayVisitor};
19
20impl ComputeVTable for StructEncoding {
21 fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
22 Some(self)
23 }
24
25 fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
26 Some(self)
27 }
28
29 fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
30 Some(self)
31 }
32
33 fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
34 Some(self)
35 }
36
37 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
38 Some(self)
39 }
40
41 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
42 Some(self)
43 }
44
45 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
46 Some(self)
47 }
48
49 fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
50 Some(self)
51 }
52
53 fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
54 Some(self)
55 }
56
57 fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
58 Some(self)
59 }
60}
61
62impl CastFn<&StructArray> for StructEncoding {
63 fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<ArrayRef> {
64 let Some(target_sdtype) = dtype.as_struct() else {
65 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
66 };
67
68 let source_sdtype = array
69 .dtype()
70 .as_struct()
71 .vortex_expect("struct array must have struct dtype");
72
73 if target_sdtype.names() != source_sdtype.names() {
74 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
75 }
76
77 let validity = array
78 .validity()
79 .clone()
80 .cast_nullability(dtype.nullability())?;
81
82 StructArray::try_new(
83 target_sdtype.names().clone(),
84 array
85 .fields()
86 .iter()
87 .zip_eq(target_sdtype.fields())
88 .map(|(field, dtype)| try_cast(field, &dtype))
89 .try_collect()?,
90 array.len(),
91 validity,
92 )
93 .map(|a| a.into_array())
94 }
95}
96
97impl ScalarAtFn<&StructArray> for StructEncoding {
98 fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
99 Ok(Scalar::struct_(
100 array.dtype().clone(),
101 array
102 .fields()
103 .iter()
104 .map(|field| scalar_at(field, index))
105 .try_collect()?,
106 ))
107 }
108}
109
110impl TakeFn<&StructArray> for StructEncoding {
111 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
112 StructArray::try_new(
113 array.names().clone(),
114 array
115 .fields()
116 .iter()
117 .map(|field| take(field, indices))
118 .try_collect()?,
119 indices.len(),
120 array.validity().take(indices)?,
121 )
122 .map(|a| a.into_array())
123 }
124}
125
126impl SliceFn<&StructArray> for StructEncoding {
127 fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
128 let fields = array
129 .fields()
130 .iter()
131 .map(|field| slice(field, start, stop))
132 .try_collect()?;
133 StructArray::try_new(
134 array.names().clone(),
135 fields,
136 stop - start,
137 array.validity().slice(start, stop)?,
138 )
139 .map(|a| a.into_array())
140 }
141}
142
143impl FilterFn<&StructArray> for StructEncoding {
144 fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
145 let validity = array.validity().filter(mask)?;
146
147 let fields: Vec<ArrayRef> = array
148 .fields()
149 .iter()
150 .map(|field| filter(field, mask))
151 .try_collect()?;
152 let length = fields
153 .first()
154 .map(|a| a.len())
155 .unwrap_or_else(|| mask.true_count());
156
157 StructArray::try_new(array.names().clone(), fields, length, validity)
158 .map(|a| a.into_array())
159 }
160}
161
162impl MaskFn<&StructArray> for StructEncoding {
163 fn mask(&self, array: &StructArray, filter_mask: Mask) -> VortexResult<ArrayRef> {
164 let validity = array.validity().mask(&filter_mask)?;
165
166 StructArray::try_new(
167 array.names().clone(),
168 array.fields().to_vec(),
169 array.len(),
170 validity,
171 )
172 .map(|a| a.into_array())
173 }
174}
175
176impl MinMaxFn<&StructArray> for StructEncoding {
177 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
178 Ok(None)
180 }
181}
182
183impl IsConstantFn<&StructArray> for StructEncoding {
184 fn is_constant(
185 &self,
186 array: &StructArray,
187 opts: &IsConstantOpts,
188 ) -> VortexResult<Option<bool>> {
189 let children = array.children();
190 if children.is_empty() {
191 return Ok(None);
192 }
193
194 for child in children.iter() {
195 if !is_constant_opts(child, opts)? {
196 return Ok(Some(false));
197 }
198 }
199
200 Ok(Some(true))
201 }
202}
203
204impl UncompressedSizeFn<&StructArray> for StructEncoding {
205 fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
206 let mut sum = array.validity().uncompressed_size();
207 for child in array.children().into_iter() {
208 sum += uncompressed_size(child.as_ref())?;
209 }
210
211 Ok(sum)
212 }
213}
214
215#[cfg(test)]
216mod tests {
217 use std::sync::Arc;
218
219 use vortex_buffer::buffer;
220 use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
221 use vortex_mask::Mask;
222
223 use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
224 use crate::compute::test_harness::test_mask;
225 use crate::compute::{filter, try_cast};
226 use crate::validity::Validity;
227 use crate::{Array, IntoArray as _};
228
229 #[test]
230 fn filter_empty_struct() {
231 let struct_arr =
232 StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
233 let mask = vec![
234 false, true, false, true, false, true, false, true, false, true,
235 ];
236 let filtered = filter(&struct_arr, &Mask::from_iter(mask)).unwrap();
237 assert_eq!(filtered.len(), 5);
238 }
239
240 #[test]
241 fn filter_empty_struct_with_empty_filter() {
242 let struct_arr =
243 StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
244 let filtered = filter(&struct_arr, &Mask::from_iter::<[bool; 0]>([])).unwrap();
245 assert_eq!(filtered.len(), 0);
246 }
247
248 #[test]
249 fn test_mask_empty_struct() {
250 test_mask(&StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable).unwrap());
251 }
252
253 #[test]
254 fn test_mask_complex_struct() {
255 let xs = buffer![0i64, 1, 2, 3, 4].into_array();
256 let ys = VarBinArray::from_iter(
257 [Some("a"), Some("b"), None, Some("d"), None],
258 DType::Utf8(Nullability::Nullable),
259 )
260 .into_array();
261 let zs =
262 BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();
263
264 test_mask(
265 &StructArray::try_new(
266 ["xs".into(), "ys".into(), "zs".into()].into(),
267 vec![
268 StructArray::try_new(
269 ["left".into(), "right".into()].into(),
270 vec![xs.clone(), xs],
271 5,
272 Validity::NonNullable,
273 )
274 .unwrap()
275 .into_array(),
276 ys,
277 zs,
278 ],
279 5,
280 Validity::NonNullable,
281 )
282 .unwrap(),
283 );
284 }
285
286 #[test]
287 fn test_cast_empty_struct() {
288 let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
289 .unwrap()
290 .into_array();
291 let non_nullable_dtype = DType::Struct(
292 Arc::from(StructDType::new([].into(), vec![])),
293 Nullability::NonNullable,
294 );
295 let casted = try_cast(&array, &non_nullable_dtype).unwrap();
296 assert_eq!(casted.dtype(), &non_nullable_dtype);
297
298 let nullable_dtype = DType::Struct(
299 Arc::from(StructDType::new([].into(), vec![])),
300 Nullability::Nullable,
301 );
302 let casted = try_cast(&array, &nullable_dtype).unwrap();
303 assert_eq!(casted.dtype(), &nullable_dtype);
304 }
305
306 #[test]
307 fn test_cast_cannot_change_name_order() {
308 let array = StructArray::try_new(
309 ["xs".into(), "ys".into(), "zs".into()].into(),
310 vec![
311 buffer![1u8].into_array(),
312 buffer![1u8].into_array(),
313 buffer![1u8].into_array(),
314 ],
315 1,
316 Validity::NonNullable,
317 )
318 .unwrap();
319
320 let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
321
322 let result = try_cast(
323 &array,
324 &DType::Struct(
325 Arc::from(StructDType::new(
326 FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
327 vec![tu8.clone(), tu8.clone(), tu8],
328 )),
329 Nullability::NonNullable,
330 ),
331 );
332 assert!(
333 result.as_ref().is_err_and(|err| {
334 err.to_string()
335 .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
336 }),
337 "{:?}",
338 result
339 );
340 }
341
342 #[test]
343 fn test_cast_complex_struct() {
344 let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
345 let ys = VarBinArray::from_vec(
346 vec!["a", "b", "c", "d", "e"],
347 DType::Utf8(Nullability::Nullable),
348 );
349 let zs = BoolArray::new(
350 BooleanBuffer::from_iter([true, true, false, false, true]),
351 Validity::AllValid,
352 );
353 let fully_nullable_array = StructArray::try_new(
354 ["xs".into(), "ys".into(), "zs".into()].into(),
355 vec![
356 StructArray::try_new(
357 ["left".into(), "right".into()].into(),
358 vec![xs.to_array(), xs.to_array()],
359 5,
360 Validity::AllValid,
361 )
362 .unwrap()
363 .into_array(),
364 ys.into_array(),
365 zs.into_array(),
366 ],
367 5,
368 Validity::AllValid,
369 )
370 .unwrap()
371 .into_array();
372
373 let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
374 let casted = try_cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
375 assert_eq!(casted.dtype(), &top_level_non_nullable);
376
377 let non_null_xs_right = DType::Struct(
378 Arc::from(StructDType::new(
379 ["xs".into(), "ys".into(), "zs".into()].into(),
380 vec![
381 DType::Struct(
382 Arc::from(StructDType::new(
383 ["left".into(), "right".into()].into(),
384 vec![
385 DType::Primitive(PType::I64, Nullability::NonNullable),
386 DType::Primitive(PType::I64, Nullability::Nullable),
387 ],
388 )),
389 Nullability::Nullable,
390 ),
391 DType::Utf8(Nullability::Nullable),
392 DType::Bool(Nullability::Nullable),
393 ],
394 )),
395 Nullability::Nullable,
396 );
397 let casted = try_cast(&fully_nullable_array, &non_null_xs_right).unwrap();
398 assert_eq!(casted.dtype(), &non_null_xs_right);
399
400 let non_null_xs = DType::Struct(
401 Arc::from(StructDType::new(
402 ["xs".into(), "ys".into(), "zs".into()].into(),
403 vec![
404 DType::Struct(
405 Arc::from(StructDType::new(
406 ["left".into(), "right".into()].into(),
407 vec![
408 DType::Primitive(PType::I64, Nullability::Nullable),
409 DType::Primitive(PType::I64, Nullability::Nullable),
410 ],
411 )),
412 Nullability::NonNullable,
413 ),
414 DType::Utf8(Nullability::Nullable),
415 DType::Bool(Nullability::Nullable),
416 ],
417 )),
418 Nullability::Nullable,
419 );
420 let casted = try_cast(&fully_nullable_array, &non_null_xs).unwrap();
421 assert_eq!(casted.dtype(), &non_null_xs);
422 }
423}