1mod cast;
5mod filter;
6mod mask;
7
8use itertools::Itertools;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::VortexResult;
11use vortex_scalar::Scalar;
12
13use crate::arrays::StructVTable;
14use crate::arrays::struct_::StructArray;
15use crate::compute::{
16 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
17 MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
18};
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21use crate::{Array, ArrayRef, IntoArray, register_kernel};
22
23impl TakeKernel for StructVTable {
24 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25 if array.is_empty() {
28 return StructArray::try_new_with_dtype(
29 array.fields().to_vec(),
30 array.struct_fields().clone(),
31 indices.len(),
32 Validity::AllInvalid,
33 )
34 .map(StructArray::into_array);
35 }
36 let inner_indices = &fill_null(
38 indices,
39 &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40 )?;
41 StructArray::try_new_with_dtype(
42 array
43 .fields()
44 .iter()
45 .map(|field| take(field, inner_indices))
46 .try_collect()?,
47 array.struct_fields().clone(),
48 indices.len(),
49 array.validity().take(indices)?,
50 )
51 .map(|a| a.into_array())
52 }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
57impl MinMaxKernel for StructVTable {
58 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
59 Ok(None)
61 }
62}
63
64register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67 fn is_constant(
68 &self,
69 array: &StructArray,
70 opts: &IsConstantOpts,
71 ) -> VortexResult<Option<bool>> {
72 let children = array.children();
73 if children.is_empty() {
74 return Ok(Some(true));
75 }
76
77 for child in children.iter() {
78 match is_constant_opts(child, opts)? {
79 None => return Ok(None),
81 Some(false) => return Ok(Some(false)),
82 Some(true) => {}
83 }
84 }
85
86 Ok(Some(true))
87 }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
#[cfg(test)]
mod tests {
    use Nullability::{NonNullable, Nullable};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, is_constant, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with `len` rows and no fields.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// The dtype of a struct without fields, at the given nullability.
    fn fieldless_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(StructFields::new(FieldNames::default(), vec![]), nullability)
    }

    /// Target dtype for the `{xs: {left, right}, ys, zs}` struct used in
    /// `test_cast_complex_struct`. `left`, `right` and `xs` take the given nullabilities;
    /// `ys`, `zs` and the outer struct are always nullable.
    fn nested_cast_target(left: Nullability, right: Nullability, xs: Nullability) -> DType {
        DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![
                    DType::Struct(
                        StructFields::new(
                            ["left", "right"].into(),
                            vec![
                                DType::Primitive(PType::I64, left),
                                DType::Primitive(PType::I64, right),
                            ],
                        ),
                        xs,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            ),
            Nullable,
        )
    }

    #[test]
    fn filter_empty_struct() {
        let struct_arr = fieldless_struct(10);
        // Keep every odd row: five out of ten.
        let keep_odd = Mask::from_iter((0..10).map(|row| row % 2 == 1));
        let filtered = filter(struct_arr.as_ref(), &keep_odd).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn take_empty_struct() {
        let struct_arr = fieldless_struct(10);
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        // A null index turns the result nullable.
        let expected_dtype = fieldless_struct_dtype(Nullable);
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(expected_dtype.clone(), vec![])
        );
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null(expected_dtype));
    }

    #[test]
    fn take_field_struct() {
        let struct_arr =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        let expected_dtype = struct_arr.dtype().union_nullability(Nullable);
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(
                expected_dtype.clone(),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null(expected_dtype));
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr = fieldless_struct(0);
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let left_right = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let complex = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![left_right, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(complex.as_ref());
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Casting a field-less struct must succeed towards either nullability.
        for target in [
            fieldless_struct_dtype(NonNullable),
            fieldless_struct_dtype(Nullable),
        ] {
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let one_u8 = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![one_u8(), one_u8(), one_u8()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);
        let reordered = DType::Struct(
            StructFields::new(
                FieldNames::from(["ys", "xs", "zs"]),
                vec![tu8.clone(), tu8.clone(), tu8],
            ),
            NonNullable,
        );

        let result = cast(array.as_ref(), &reordered);
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let left_right = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![left_right, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let non_null_xs_left = nested_cast_target(NonNullable, Nullable, Nullable);
        let non_null_xs = nested_cast_target(Nullable, Nullable, NonNullable);

        for target in [top_level_non_nullable, non_null_xs_left, non_null_xs] {
            let casted = cast(&fully_nullable_array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_empty_struct_is_constant() {
        let array = StructArray::new_with_len(2);
        let is_constant = is_constant(array.as_ref()).vortex_unwrap();
        assert_eq!(is_constant, Some(true));
    }
}