1mod to_arrow;
2
3use itertools::Itertools;
4use vortex_dtype::DType;
5use vortex_error::{VortexExpect, VortexResult, vortex_bail};
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::StructEncoding;
10use crate::arrays::struct_::StructArray;
11use crate::compute::{
12 CastFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsConstantOpts, KernelRef, MaskFn,
13 MinMaxFn, MinMaxResult, ScalarAtFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter,
14 is_constant_opts, scalar_at, slice, take, try_cast, uncompressed_size,
15};
16use crate::vtable::ComputeVTable;
17use crate::{Array, ArrayComputeImpl, ArrayRef, ArrayVisitor};
18
impl ArrayComputeImpl for StructArray {
    // The only kernel registered through this impl is the filter kernel, obtained by
    // wrapping `StructEncoding` (see its `FilterKernel` impl) in `FilterKernelAdapter`.
    const FILTER: Option<KernelRef> = FilterKernelAdapter(StructEncoding).some();
}
22
23impl ComputeVTable for StructEncoding {
24 fn cast_fn(&self) -> Option<&dyn CastFn<&dyn Array>> {
25 Some(self)
26 }
27
28 fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
29 Some(self)
30 }
31
32 fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
33 Some(self)
34 }
35
36 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
37 Some(self)
38 }
39
40 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
41 Some(self)
42 }
43
44 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
45 Some(self)
46 }
47
48 fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
49 Some(self)
50 }
51
52 fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
53 Some(self)
54 }
55
56 fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
57 Some(self)
58 }
59}
60
61impl CastFn<&StructArray> for StructEncoding {
62 fn cast(&self, array: &StructArray, dtype: &DType) -> VortexResult<ArrayRef> {
63 let Some(target_sdtype) = dtype.as_struct() else {
64 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
65 };
66
67 let source_sdtype = array
68 .dtype()
69 .as_struct()
70 .vortex_expect("struct array must have struct dtype");
71
72 if target_sdtype.names() != source_sdtype.names() {
73 vortex_bail!("cannot cast {} to {}", array.dtype(), dtype);
74 }
75
76 let validity = array
77 .validity()
78 .clone()
79 .cast_nullability(dtype.nullability())?;
80
81 StructArray::try_new(
82 target_sdtype.names().clone(),
83 array
84 .fields()
85 .iter()
86 .zip_eq(target_sdtype.fields())
87 .map(|(field, dtype)| try_cast(field, &dtype))
88 .try_collect()?,
89 array.len(),
90 validity,
91 )
92 .map(|a| a.into_array())
93 }
94}
95
96impl ScalarAtFn<&StructArray> for StructEncoding {
97 fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
98 Ok(Scalar::struct_(
99 array.dtype().clone(),
100 array
101 .fields()
102 .iter()
103 .map(|field| scalar_at(field, index))
104 .try_collect()?,
105 ))
106 }
107}
108
109impl TakeFn<&StructArray> for StructEncoding {
110 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
111 StructArray::try_new_with_dtype(
112 array
113 .fields()
114 .iter()
115 .map(|field| take(field, indices))
116 .try_collect()?,
117 array.struct_dtype().clone(),
118 indices.len(),
119 array.validity().take(indices)?,
120 )
121 .map(|a| a.into_array())
122 }
123}
124
125impl SliceFn<&StructArray> for StructEncoding {
126 fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
127 let fields = array
128 .fields()
129 .iter()
130 .map(|field| slice(field, start, stop))
131 .try_collect()?;
132 StructArray::try_new_with_dtype(
133 fields,
134 array.struct_dtype().clone(),
135 stop - start,
136 array.validity().slice(start, stop)?,
137 )
138 .map(|a| a.into_array())
139 }
140}
141
142impl FilterKernel for StructEncoding {
143 fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
144 let validity = array.validity().filter(mask)?;
145
146 let fields: Vec<ArrayRef> = array
147 .fields()
148 .iter()
149 .map(|field| filter(field, mask))
150 .try_collect()?;
151 let length = fields
152 .first()
153 .map(|a| a.len())
154 .unwrap_or_else(|| mask.true_count());
155
156 StructArray::try_new_with_dtype(fields, array.struct_dtype().clone(), length, validity)
157 .map(|a| a.into_array())
158 }
159}
160
161impl MaskFn<&StructArray> for StructEncoding {
162 fn mask(&self, array: &StructArray, filter_mask: Mask) -> VortexResult<ArrayRef> {
163 let validity = array.validity().mask(&filter_mask)?;
164
165 StructArray::try_new_with_dtype(
166 array.fields().to_vec(),
167 array.struct_dtype().clone(),
168 array.len(),
169 validity,
170 )
171 .map(|a| a.into_array())
172 }
173}
174
impl MinMaxFn<&StructArray> for StructEncoding {
    /// Always returns `Ok(None)`: no minimum or maximum is computed for struct arrays, and
    /// the input array is not inspected.
    // NOTE(review): presumably because no ordering is defined for struct values — confirm.
    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
        Ok(None)
    }
}
181
182impl IsConstantFn<&StructArray> for StructEncoding {
183 fn is_constant(
184 &self,
185 array: &StructArray,
186 opts: &IsConstantOpts,
187 ) -> VortexResult<Option<bool>> {
188 let children = array.children();
189 if children.is_empty() {
190 return Ok(None);
191 }
192
193 for child in children.iter() {
194 if !is_constant_opts(child, opts)? {
195 return Ok(Some(false));
196 }
197 }
198
199 Ok(Some(true))
200 }
201}
202
203impl UncompressedSizeFn<&StructArray> for StructEncoding {
204 fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
205 let mut sum = array.validity().uncompressed_size();
206 for child in array.children().into_iter() {
207 sum += uncompressed_size(child.as_ref())?;
208 }
209
210 Ok(sum)
211 }
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{filter, try_cast};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with no fields and `len` rows.
    fn empty_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Assembles a struct `DType` from field names, field dtypes and a nullability.
    fn struct_dtype(names: FieldNames, fields: Vec<DType>, nullability: Nullability) -> DType {
        DType::Struct(Arc::from(StructDType::new(names, fields)), nullability)
    }

    #[test]
    fn filter_empty_struct() {
        let struct_arr = empty_struct(10);
        // Keep every odd position: false, true, false, true, ...
        let keep_odd = Mask::from_iter((0..10).map(|position| position % 2 == 1));

        let filtered = filter(&struct_arr, &keep_odd).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr = empty_struct(0);
        let no_rows = Mask::from_iter::<[bool; 0]>([]);

        let filtered = filter(&struct_arr, &no_rows).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        let struct_arr = empty_struct(5);
        test_mask(&struct_arr);
    }

    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        // The `xs` column is itself a struct holding two copies of the same integers.
        let nested_xs = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let complex = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested_xs, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(&complex);
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = empty_struct(5).into_array();

        // The cast must succeed for both nullabilities of the target struct dtype.
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = struct_dtype([].into(), vec![], nullability);
            let casted = try_cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let single_u8 = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![single_u8(), single_u8(), single_u8()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        // Same fields and dtypes, but with `xs` and `ys` swapped.
        let reordered = struct_dtype(
            FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
            vec![tu8.clone(), tu8.clone(), tu8],
            Nullability::NonNullable,
        );

        let result = try_cast(&array, &reordered);
        let expected_message = "cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}";
        assert!(
            matches!(&result, Err(err) if err.to_string().contains(expected_message)),
            "{:?}",
            result
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let nested_xs = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![nested_xs, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Target dtype with the same layout as the source; only the nullability of
        // `xs.left` and of `xs` itself varies, everything else stays nullable.
        let target_dtype = |left_nullability: Nullability, xs_nullability: Nullability| {
            struct_dtype(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    struct_dtype(
                        ["left".into(), "right".into()].into(),
                        vec![
                            DType::Primitive(PType::I64, left_nullability),
                            DType::Primitive(PType::I64, Nullability::Nullable),
                        ],
                        xs_nullability,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
                Nullability::Nullable,
            )
        };

        let targets = [
            // Only the top-level struct becomes non-nullable.
            fully_nullable_array.dtype().as_nonnullable(),
            // Only `xs.left` becomes non-nullable.
            target_dtype(Nullability::NonNullable, Nullability::Nullable),
            // Only the nested `xs` struct becomes non-nullable.
            target_dtype(Nullability::Nullable, Nullability::NonNullable),
        ];

        for target in targets {
            let casted = try_cast(&fully_nullable_array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }
}