1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_dtype::Nullability::NonNullable;
7use vortex_error::VortexResult;
8use vortex_scalar::Scalar;
9
10use crate::arrays::StructVTable;
11use crate::arrays::struct_::StructArray;
12use crate::compute::{
13 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
14 MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
15};
16use crate::validity::Validity;
17use crate::vtable::ValidityHelper;
18use crate::{Array, ArrayRef, IntoArray, register_kernel};
19
20impl TakeKernel for StructVTable {
21 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 if array.is_empty() {
25 return StructArray::try_new_with_dtype(
26 array.fields().to_vec(),
27 array.struct_fields().clone(),
28 indices.len(),
29 Validity::AllInvalid,
30 )
31 .map(StructArray::into_array);
32 }
33 let inner_indices = &fill_null(
35 indices,
36 &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
37 )?;
38 StructArray::try_new_with_dtype(
39 array
40 .fields()
41 .iter()
42 .map(|field| take(field, inner_indices))
43 .try_collect()?,
44 array.struct_fields().clone(),
45 indices.len(),
46 array.validity().take(indices)?,
47 )
48 .map(|a| a.into_array())
49 }
50}
51
52register_kernel!(TakeKernelAdapter(StructVTable).lift());
53
54impl MinMaxKernel for StructVTable {
55 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
56 Ok(None)
58 }
59}
60
61register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
62
63impl IsConstantKernel for StructVTable {
64 fn is_constant(
65 &self,
66 array: &StructArray,
67 opts: &IsConstantOpts,
68 ) -> VortexResult<Option<bool>> {
69 let children = array.children();
70 if children.is_empty() {
71 return Ok(None);
72 }
73
74 for child in children.iter() {
75 match is_constant_opts(child, opts)? {
76 None => return Ok(None),
78 Some(false) => return Ok(Some(false)),
79 Some(true) => {}
80 }
81 }
82
83 Ok(Some(true))
84 }
85}
86
87register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
88
#[cfg(test)]
mod tests {
    //! Unit tests for the struct compute kernels: filter, take, mask, and cast.

    use std::sync::Arc;

    use Nullability::{NonNullable, Nullable};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a fieldless struct of length 10 with an alternating mask keeps 5 rows.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// Taking from a fieldless, non-nullable struct with indices `[Some(1), None]`:
    /// the valid index yields an empty struct scalar and the null index yields a null
    /// scalar, both of the nullable struct dtype.
    #[test]
    fn take_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(
                DType::Struct(Arc::new(StructFields::new([].into(), vec![])), Nullable),
                vec![]
            )
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::null(DType::Struct(
                Arc::new(StructFields::new([].into(), vec![])),
                Nullable
            ))
        );
    }

    /// Taking from a struct with a single primitive field `a = 0..10` using `[Some(1), None]`:
    /// row 0 holds field value 1 and row 1 is null. Both use the input dtype made nullable.
    #[test]
    fn take_field_struct() {
        let struct_arr =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0).unwrap(),
            Scalar::struct_(
                struct_arr.dtype().union_nullability(Nullable),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(
            taken.scalar_at(1).unwrap(),
            Scalar::null(struct_arr.dtype().union_nullability(Nullable),)
        );
    }

    /// Filtering a zero-length fieldless struct with an empty mask yields zero rows.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Runs the shared `test_mask` conformance check on a fieldless struct of length 5.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(
            StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the shared `test_mask` conformance check on a nested struct.
    ///
    /// The struct contains an inner `{left, right}` i64 struct, a nullable Utf8 column, and a
    /// bool column that includes nulls.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            StructArray::try_new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    StructArray::try_new(
                        ["left".into(), "right".into()].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// A fieldless, non-nullable struct can be cast to both the non-nullable and the nullable
    /// fieldless struct dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype =
            DType::Struct(Arc::from(StructFields::new([].into(), vec![])), NonNullable);
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype =
            DType::Struct(Arc::from(StructFields::new([].into(), vec![])), Nullable);
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Casting to a struct dtype whose field names are reordered must fail.
    /// The error message must name both layouts.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);

        let result = cast(
            array.as_ref(),
            &DType::Struct(
                Arc::from(StructFields::new(
                    FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                )),
                NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    /// Casting an all-valid nested struct to dtypes with stricter nullability at several
    /// levels: the top level, an inner field, and the inner struct itself.
    #[test]
    fn test_cast_complex_struct() {
        // `from_option_iter` presumably yields a nullable i64 array, which matches the
        // `Nullable` i64 field dtypes used below.
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                StructArray::try_new(
                    ["left".into(), "right".into()].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Only the outermost struct becomes non-nullable.
        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // NOTE(review): despite the name, this dtype makes `xs.left` non-nullable and leaves
        // `xs.right` nullable.
        let non_null_xs_right = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, NonNullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        )),
                        Nullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            )),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_right);

        // The inner `xs` struct becomes non-nullable while its own fields stay nullable.
        let non_null_xs = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        )),
                        NonNullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            )),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}