1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7
8use crate::arrays::StructVTable;
9use crate::arrays::struct_::StructArray;
10use crate::compute::{
11 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
12 MinMaxResult, TakeKernel, TakeKernelAdapter, is_constant_opts, take,
13};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, register_kernel};
16
17impl TakeKernel for StructVTable {
18 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 StructArray::try_new_with_dtype(
20 array
21 .fields()
22 .iter()
23 .map(|field| take(field, indices))
24 .try_collect()?,
25 array.struct_dtype().clone(),
26 indices.len(),
27 array.validity().take(indices)?,
28 )
29 .map(|a| a.into_array())
30 }
31}
32
33register_kernel!(TakeKernelAdapter(StructVTable).lift());
34
35impl MinMaxKernel for StructVTable {
36 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
37 Ok(None)
39 }
40}
41
42register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
43
44impl IsConstantKernel for StructVTable {
45 fn is_constant(
46 &self,
47 array: &StructArray,
48 opts: &IsConstantOpts,
49 ) -> VortexResult<Option<bool>> {
50 let children = array.children();
51 if children.is_empty() {
52 return Ok(None);
53 }
54
55 for child in children.iter() {
56 match is_constant_opts(child, opts)? {
57 None => return Ok(None),
59 Some(false) => return Ok(Some(false)),
60 Some(true) => {}
61 }
62 }
63
64 Ok(Some(true))
65 }
66}
67
68register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
69
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with no fields and `len` rows.
    fn empty_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Builds the dtype of a field-less struct with the given top-level nullability.
    fn empty_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(Arc::from(StructDType::new([].into(), vec![])), nullability)
    }

    /// Builds a nullable `{xs={left, right}, ys, zs}` struct dtype.
    ///
    /// `left` sets the nullability of `xs.left` and `xs` sets the nullability of
    /// the nested `xs` struct itself; every other dtype in the tree is nullable.
    fn nested_dtype(left: Nullability, xs: Nullability) -> DType {
        let i64_of = |nullability| DType::Primitive(PType::I64, nullability);
        let pair = StructDType::new(
            ["left".into(), "right".into()].into(),
            vec![i64_of(left), i64_of(Nullability::Nullable)],
        );
        let outer = StructDType::new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                DType::Struct(Arc::from(pair), xs),
                DType::Utf8(Nullability::Nullable),
                DType::Bool(Nullability::Nullable),
            ],
        );
        DType::Struct(Arc::from(outer), Nullability::Nullable)
    }

    /// Filtering a field-less struct keeps exactly the selected number of rows.
    #[test]
    fn filter_empty_struct() {
        let rows = empty_struct(10);
        // Keep the odd positions: false, true, false, true, ...
        let keep_odd: Vec<bool> = (0..10).map(|row| row % 2 == 1).collect();
        let kept = filter(rows.as_ref(), &Mask::from_iter(keep_odd)).unwrap();
        assert_eq!(kept.len(), 5);
    }

    /// Filtering a zero-row, field-less struct with an empty mask yields zero rows.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let rows = empty_struct(0);
        let nothing = Mask::from_iter::<[bool; 0]>([]);
        let kept = filter(rows.as_ref(), &nothing).unwrap();
        assert_eq!(kept.len(), 0);
    }

    /// Runs the mask conformance check against a field-less struct.
    #[test]
    fn test_mask_empty_struct() {
        let rows = empty_struct(5);
        test_mask(rows.as_ref());
    }

    /// Runs the mask conformance check against a struct that nests another struct
    /// next to nullable string and boolean fields.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let pair = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let nested = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![pair, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_mask(nested.as_ref());
    }

    /// A field-less struct can be cast to both the non-nullable and the nullable
    /// field-less struct dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = empty_struct(5).into_array();

        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let target = empty_struct_dtype(nullability);
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    /// Casting to a struct dtype whose field names are reordered must fail with a
    /// descriptive error.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let reordered = DType::Struct(
            Arc::from(StructDType::new(
                FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                vec![tu8.clone(), tu8.clone(), tu8],
            )),
            Nullability::NonNullable,
        );

        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let outcome = cast(array.as_ref(), &reordered);
        assert!(
            outcome.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{:?}",
            outcome
        );
    }

    /// A nested struct with nullable dtypes throughout and no actual nulls can be
    /// cast to dtypes that tighten nullability at the top level, on a nested
    /// field, or on the nested struct itself.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );

        let pair = StructArray::try_new(
            ["left".into(), "right".into()].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![pair, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let targets = [
            // Only the outermost struct becomes non-nullable.
            fully_nullable_array.dtype().as_nonnullable(),
            // `xs.left` becomes non-nullable.
            nested_dtype(Nullability::NonNullable, Nullability::Nullable),
            // The nested `xs` struct becomes non-nullable.
            nested_dtype(Nullability::Nullable, Nullability::NonNullable),
        ];

        for target in &targets {
            let casted = cast(&fully_nullable_array, target).unwrap();
            assert_eq!(casted.dtype(), target);
        }
    }
}