1mod cast;
2mod mask;
3mod to_arrow;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructEncoding;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13 FilterKernel, FilterKernelAdapter, IsConstantFn, IsConstantOpts, MinMaxFn, MinMaxResult,
14 ScalarAtFn, SliceFn, TakeFn, ToArrowFn, UncompressedSizeFn, filter, is_constant_opts,
15 scalar_at, slice, take, uncompressed_size,
16};
17use crate::vtable::ComputeVTable;
18use crate::{Array, ArrayRef, ArrayVisitor, register_kernel};
19
20impl ComputeVTable for StructEncoding {
21 fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
22 Some(self)
23 }
24
25 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
26 Some(self)
27 }
28
29 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
30 Some(self)
31 }
32
33 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
34 Some(self)
35 }
36
37 fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<&dyn Array>> {
38 Some(self)
39 }
40
41 fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
42 Some(self)
43 }
44
45 fn uncompressed_size_fn(&self) -> Option<&dyn UncompressedSizeFn<&dyn Array>> {
46 Some(self)
47 }
48}
49
50impl ScalarAtFn<&StructArray> for StructEncoding {
51 fn scalar_at(&self, array: &StructArray, index: usize) -> VortexResult<Scalar> {
52 Ok(Scalar::struct_(
53 array.dtype().clone(),
54 array
55 .fields()
56 .iter()
57 .map(|field| scalar_at(field, index))
58 .try_collect()?,
59 ))
60 }
61}
62
63impl TakeFn<&StructArray> for StructEncoding {
64 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
65 StructArray::try_new_with_dtype(
66 array
67 .fields()
68 .iter()
69 .map(|field| take(field, indices))
70 .try_collect()?,
71 array.struct_dtype().clone(),
72 indices.len(),
73 array.validity().take(indices)?,
74 )
75 .map(|a| a.into_array())
76 }
77}
78
79impl SliceFn<&StructArray> for StructEncoding {
80 fn slice(&self, array: &StructArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
81 let fields = array
82 .fields()
83 .iter()
84 .map(|field| slice(field, start, stop))
85 .try_collect()?;
86 StructArray::try_new_with_dtype(
87 fields,
88 array.struct_dtype().clone(),
89 stop - start,
90 array.validity().slice(start, stop)?,
91 )
92 .map(|a| a.into_array())
93 }
94}
95
96impl FilterKernel for StructEncoding {
97 fn filter(&self, array: &StructArray, mask: &Mask) -> VortexResult<ArrayRef> {
98 let validity = array.validity().filter(mask)?;
99
100 let fields: Vec<ArrayRef> = array
101 .fields()
102 .iter()
103 .map(|field| filter(field, mask))
104 .try_collect()?;
105 let length = fields
106 .first()
107 .map(|a| a.len())
108 .unwrap_or_else(|| mask.true_count());
109
110 StructArray::try_new_with_dtype(fields, array.struct_dtype().clone(), length, validity)
111 .map(|a| a.into_array())
112 }
113}
114
115register_kernel!(FilterKernelAdapter(StructEncoding).lift());
116
117impl MinMaxFn<&StructArray> for StructEncoding {
118 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
119 Ok(None)
121 }
122}
123
124impl IsConstantFn<&StructArray> for StructEncoding {
125 fn is_constant(
126 &self,
127 array: &StructArray,
128 opts: &IsConstantOpts,
129 ) -> VortexResult<Option<bool>> {
130 let children = array.children();
131 if children.is_empty() {
132 return Ok(None);
133 }
134
135 for child in children.iter() {
136 if !is_constant_opts(child, opts)? {
137 return Ok(Some(false));
138 }
139 }
140
141 Ok(Some(true))
142 }
143}
144
145impl UncompressedSizeFn<&StructArray> for StructEncoding {
146 fn uncompressed_size(&self, array: &StructArray) -> VortexResult<usize> {
147 let mut sum = array.validity().uncompressed_size();
148 for child in array.children().into_iter() {
149 sum += uncompressed_size(child.as_ref())?;
150 }
151
152 Ok(sum)
153 }
154}
155
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a struct that has no fields must still yield as many rows as the
    /// mask selects. With no field to measure, the length presumably comes from
    /// the mask's true count.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        // Every other row is selected: 5 of the 10 rows survive.
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(&struct_arr, &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// Edge case: a zero-length, zero-field struct filtered by an empty mask
    /// stays empty.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(&struct_arr, &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Runs the shared `mask` conformance suite on a 5-row struct with no fields.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(&StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable).unwrap());
    }

    /// Runs the shared `mask` conformance suite on a nested struct whose fields
    /// mix a non-nullable inner struct, a nullable utf8 column and a nullable
    /// bool column (both containing nulls).
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            &StructArray::try_new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    // Inner struct reuses the same i64 column for both of its fields.
                    StructArray::try_new(
                        ["left".into(), "right".into()].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap(),
        );
    }

    /// A field-less struct can be cast to both the non-nullable and the
    /// nullable empty struct dtype, and the result reports the requested dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            Arc::from(StructDType::new([].into(), vec![])),
            Nullability::NonNullable,
        );
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype = DType::Struct(
            Arc::from(StructDType::new([].into(), vec![])),
            Nullability::Nullable,
        );
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Casting must not silently reorder fields. The target dtype has the same
    /// names and types but `xs`/`ys` are swapped, and the cast is expected to
    /// fail with the exact message asserted below.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);

        let result = cast(
            &array,
            &DType::Struct(
                Arc::from(StructDType::new(
                    FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                )),
                Nullability::NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{:?}",
            result
        );
    }

    /// Casts a fully nullable nested struct to dtypes that tighten nullability
    /// at different levels. None of the source columns actually contain nulls,
    /// so each cast is expected to succeed and report the requested dtype.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                StructArray::try_new(
                    ["left".into(), "right".into()].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // 1. Only the outermost struct becomes non-nullable.
        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // 2. A nested leaf becomes non-nullable.
        // NOTE(review): despite the variable name, this dtype makes `xs.left`
        // non-nullable; `xs.right` stays nullable.
        let non_null_xs_right = DType::Struct(
            Arc::from(StructDType::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructDType::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::NonNullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::Nullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_right);

        // 3. The inner `xs` struct itself becomes non-nullable while its own
        //    fields stay nullable.
        let non_null_xs = DType::Struct(
            Arc::from(StructDType::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructDType::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::Nullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::NonNullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}