1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_dtype::Nullability::NonNullable;
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructVTable;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
14 MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
15};
16use crate::validity::Validity;
17use crate::vtable::ValidityHelper;
18use crate::{Array, ArrayRef, IntoArray, register_kernel};
19
20impl TakeKernel for StructVTable {
21 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 if array.is_empty() {
25 return StructArray::try_new_with_dtype(
26 array.fields().to_vec(),
27 array.struct_fields().clone(),
28 indices.len(),
29 Validity::AllInvalid,
30 )
31 .map(StructArray::into_array);
32 }
33 let inner_indices = &fill_null(
35 indices,
36 &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
37 )?;
38 StructArray::try_new_with_dtype(
39 array
40 .fields()
41 .iter()
42 .map(|field| take(field, inner_indices))
43 .try_collect()?,
44 array.struct_fields().clone(),
45 indices.len(),
46 array.validity().take(indices)?,
47 )
48 .map(|a| a.into_array())
49 }
50}
51
52register_kernel!(TakeKernelAdapter(StructVTable).lift());
53
54impl MinMaxKernel for StructVTable {
55 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
56 Ok(None)
58 }
59}
60
61register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
62
63impl IsConstantKernel for StructVTable {
64 fn is_constant(
65 &self,
66 array: &StructArray,
67 opts: &IsConstantOpts,
68 ) -> VortexResult<Option<bool>> {
69 let children = array.children();
70 if children.is_empty() {
71 return Ok(None);
72 }
73
74 for child in children.iter() {
75 match is_constant_opts(child, opts)? {
76 None => return Ok(None),
78 Some(false) => return Ok(Some(false)),
79 Some(true) => {}
80 }
81 }
82
83 Ok(Some(true))
84 }
85}
86
87register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
88
#[cfg(test)]
mod tests {

    use Nullability::{NonNullable, Nullable};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with no fields and `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(FieldNames::default(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// The dtype of a struct with no fields and the given nullability.
    fn fieldless_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(StructFields::new(FieldNames::default(), vec![]), nullability)
    }

    /// Dtype of `{xs: {left: i64, right: i64}, ys: utf8, zs: bool}`.
    ///
    /// `left`, `right` and `xs` control the nullability of `xs.left`, `xs.right` and `xs`
    /// itself. `ys`, `zs` and the outer struct are always nullable.
    fn complex_struct_dtype(left: Nullability, right: Nullability, xs: Nullability) -> DType {
        let xs_dtype = DType::Struct(
            StructFields::new(
                ["left", "right"].into(),
                vec![
                    DType::Primitive(PType::I64, left),
                    DType::Primitive(PType::I64, right),
                ],
            ),
            xs,
        );
        DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![xs_dtype, DType::Utf8(Nullable), DType::Bool(Nullable)],
            ),
            Nullable,
        )
    }

    #[test]
    fn filter_empty_struct() {
        let struct_arr = fieldless_struct(10);
        // Keep every odd row: 5 of the 10.
        let mask = Mask::from_iter((0..10).map(|row| row % 2 == 1));
        let filtered = filter(struct_arr.as_ref(), &mask).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn take_empty_struct() {
        let struct_arr = fieldless_struct(10);
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        // A null index makes the result nullable.
        let expected_dtype = fieldless_struct_dtype(Nullable);
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(expected_dtype.clone(), vec![])
        );
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null(expected_dtype));
    }

    #[test]
    fn take_field_struct() {
        let struct_arr =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        let nullable_dtype = struct_arr.dtype().union_nullability(Nullable);
        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(nullable_dtype.clone(), vec![Scalar::primitive(1, NonNullable)])
        );
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null(nullable_dtype));
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr = fieldless_struct(0);
        let empty_mask = Mask::from_iter::<[bool; 0]>([]);
        let filtered = filter(struct_arr.as_ref(), &empty_mask).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let pair = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let outer = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![pair, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();
        test_mask(outer.as_ref());
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = fieldless_struct(5).into_array();

        for nullability in [NonNullable, Nullable] {
            let target = fieldless_struct_dtype(nullability);
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let single_one = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![single_one(), single_one(), single_one()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);
        let reordered = DType::Struct(
            StructFields::new(FieldNames::from(["ys", "xs", "zs"]), vec![tu8; 3]),
            NonNullable,
        );

        let result = cast(array.as_ref(), &reordered);
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let pair = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();
        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![pair, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let targets = [
            // Only the top-level struct becomes non-nullable.
            fully_nullable_array.dtype().as_nonnullable(),
            // `xs.left` (the first child of `xs`) becomes non-nullable.
            complex_struct_dtype(NonNullable, Nullable, Nullable),
            // `xs` itself becomes non-nullable.
            complex_struct_dtype(Nullable, Nullable, NonNullable),
        ];
        for target in &targets {
            let casted = cast(&fully_nullable_array, target).unwrap();
            assert_eq!(casted.dtype(), target);
        }
    }
}