1mod cast;
2mod filter;
3mod mask;
4
5use itertools::Itertools;
6use vortex_error::VortexResult;
7
8use crate::arrays::StructVTable;
9use crate::arrays::struct_::StructArray;
10use crate::compute::{
11 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
12 MinMaxResult, TakeKernel, TakeKernelAdapter, is_constant_opts, take,
13};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, register_kernel};
16
17impl TakeKernel for StructVTable {
18 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 StructArray::try_new_with_dtype(
20 array
21 .fields()
22 .iter()
23 .map(|field| take(field, indices))
24 .try_collect()?,
25 array.struct_fields().clone(),
26 indices.len(),
27 array.validity().take(indices)?,
28 )
29 .map(|a| a.into_array())
30 }
31}
32
33register_kernel!(TakeKernelAdapter(StructVTable).lift());
34
35impl MinMaxKernel for StructVTable {
36 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
37 Ok(None)
39 }
40}
41
42register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
43
44impl IsConstantKernel for StructVTable {
45 fn is_constant(
46 &self,
47 array: &StructArray,
48 opts: &IsConstantOpts,
49 ) -> VortexResult<Option<bool>> {
50 let children = array.children();
51 if children.is_empty() {
52 return Ok(None);
53 }
54
55 for child in children.iter() {
56 match is_constant_opts(child, opts)? {
57 None => return Ok(None),
59 Some(false) => return Ok(Some(false)),
60 Some(true) => {}
61 }
62 }
63
64 Ok(Some(true))
65 }
66}
67
68register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
69
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_mask::Mask;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::mask::test_mask;
    use crate::compute::{cast, filter, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a field-less struct keeps exactly as many rows as the mask selects.
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// Filtering a zero-length, field-less struct with an empty mask yields an empty array.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Taking rows from a struct yields one row per index and keeps the struct dtype.
    #[test]
    fn take_struct_preserves_len_and_dtype() {
        let struct_arr = StructArray::try_new(
            ["xs".into()].into(),
            vec![buffer![10i32, 20, 30].into_array()],
            3,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let indices = buffer![2u64, 0].into_array();

        let taken = take(&struct_arr, &indices).unwrap();
        assert_eq!(taken.len(), 2);
        assert_eq!(taken.dtype(), struct_arr.dtype());
    }

    /// Runs the shared mask conformance suite on a field-less struct.
    #[test]
    fn test_mask_empty_struct() {
        test_mask(
            StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the shared mask conformance suite on a nested struct with nullable fields.
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask(
            StructArray::try_new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    StructArray::try_new(
                        ["left".into(), "right".into()].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// A field-less struct can be cast to both the non-nullable and nullable empty struct dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            Arc::from(StructFields::new([].into(), vec![])),
            Nullability::NonNullable,
        );
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype = DType::Struct(
            Arc::from(StructFields::new([].into(), vec![])),
            Nullability::Nullable,
        );
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Casting must not reorder fields: a permuted field-name list is rejected with an error.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, Nullability::NonNullable);

        let result = cast(
            array.as_ref(),
            &DType::Struct(
                Arc::from(StructFields::new(
                    FieldNames::from(["ys".into(), "xs".into(), "zs".into()]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                )),
                Nullability::NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    /// Nullability can be tightened at the top level and inside nested fields when the data
    /// contains no nulls.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(
            vec!["a", "b", "c", "d", "e"],
            DType::Utf8(Nullability::Nullable),
        );
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs".into(), "ys".into(), "zs".into()].into(),
            vec![
                StructArray::try_new(
                    ["left".into(), "right".into()].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        // Only `xs.left` becomes non-nullable here; `xs.right` stays nullable.
        let non_null_xs_left = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::NonNullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::Nullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_left).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_left);

        // Only the `xs` struct itself becomes non-nullable; its children stay nullable.
        let non_null_xs = DType::Struct(
            Arc::from(StructFields::new(
                ["xs".into(), "ys".into(), "zs".into()].into(),
                vec![
                    DType::Struct(
                        Arc::from(StructFields::new(
                            ["left".into(), "right".into()].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullability::Nullable),
                                DType::Primitive(PType::I64, Nullability::Nullable),
                            ],
                        )),
                        Nullability::NonNullable,
                    ),
                    DType::Utf8(Nullability::Nullable),
                    DType::Bool(Nullability::Nullable),
                ],
            )),
            Nullability::Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }
}