1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, ArrayVariants};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23 values: Vec<ExprRef>,
24}
25
26impl Merge {
27 pub fn new_expr(values: Vec<ExprRef>) -> Arc<Self> {
28 Arc::new(Merge { values })
29 }
30}
31
32pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
33 let values = elements.into_iter().map(|value| value.into()).collect_vec();
34 Merge::new_expr(values)
35}
36
37impl Display for Merge {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.write_str("{")?;
40 self.values
41 .iter()
42 .format_with(", ", |expr, f| f(expr))
43 .fmt(f)?;
44 f.write_str("}")
45 }
46}
47
48impl VortexExpr for Merge {
49 fn as_any(&self) -> &dyn Any {
50 self
51 }
52
53 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
54 let len = batch.len();
55 let value_arrays = self
56 .values
57 .iter()
58 .map(|value_expr| value_expr.evaluate(batch))
59 .process_results(|it| it.collect::<Vec<_>>())?;
60
61 let mut field_names = Vec::new();
63 let mut arrays = Vec::new();
64
65 for value_array in value_arrays.iter() {
66 if value_array.dtype().is_nullable() {
68 todo!("merge nullable structs");
69 }
70 if !value_array.dtype().is_struct() {
71 vortex_bail!("merge expects non-nullable struct input");
72 }
73
74 let struct_array = value_array
75 .as_struct_typed()
76 .vortex_expect("merge expects struct input");
77
78 for (i, field_name) in struct_array.names().iter().enumerate() {
79 let array = struct_array
80 .maybe_null_field_by_idx(i)
81 .vortex_expect("struct field not found");
82
83 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
85 arrays[idx] = array;
86 } else {
87 field_names.push(field_name.clone());
88 arrays.push(array);
89 }
90 }
91 }
92
93 Ok(StructArray::try_new(
94 FieldNames::from(field_names),
95 arrays,
96 len,
97 Validity::NonNullable,
98 )?
99 .into_array())
100 }
101
102 fn children(&self) -> Vec<&ExprRef> {
103 self.values.iter().collect()
104 }
105
106 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
107 Self::new_expr(children)
108 }
109
110 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
111 let mut field_names = Vec::new();
112 let mut arrays = Vec::new();
113
114 for value in self.values.iter() {
115 let dtype = value.return_dtype(scope_dtype)?;
116 if !dtype.is_struct() {
117 vortex_bail!("merge expects non-nullable struct input");
118 }
119
120 let struct_dtype = dtype
121 .as_struct()
122 .vortex_expect("merge expects struct input");
123
124 for i in 0..struct_dtype.nfields() {
125 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
126 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
127 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
128 arrays[idx] = field_dtype;
129 } else {
130 field_names.push(field_name.clone());
131 arrays.push(field_dtype);
132 }
133 }
134 }
135
136 Ok(DType::Struct(
137 Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
138 Nullability::NonNullable,
139 ))
140 }
141}
142
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Follows `field_path` through nested struct arrays starting at `array` and returns the
    /// leaf field canonicalized to a `PrimitiveArray`.
    ///
    /// Errors if the path is empty, if an intermediate array is not a struct, if a field is
    /// missing, or if the leaf cannot be canonicalized to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array
            .as_struct_typed()
            .ok_or_else(|| vortex_err!("expected a struct"))?
            .maybe_null_field_by_name(first)?;

        for field in rest {
            array = array
                .as_struct_typed()
                .ok_or_else(|| vortex_err!("expected a struct"))?
                .maybe_null_field_by_name(field)?;
        }

        // Propagate canonicalization failures instead of panicking inside a Result-returning helper.
        array.to_primitive()
    }

    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
            GetItem::new_expr("2", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        // "b" appears in children 0 and 1; the later child's values win.
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().unwrap().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Merge is shallow: the later child's struct field "a" replaces the earlier one
        // wholesale, so only "x" survives rather than both "x" and "y".
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array
                .as_struct_typed()
                .unwrap()
                .maybe_null_field_by_name("a")
                .unwrap()
                .as_struct_typed()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        // Fields keep child order, then field order within each child; they are not sorted.
        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    #[test]
    pub fn test_merge_rejects_non_struct_input() {
        // Child "a" is a primitive array, not a struct, so both planning and evaluation must
        // fail with an error (not a panic).
        let expr = Merge::new_expr(vec![GetItem::new_expr("a", Identity::new_expr())]);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();

        assert!(expr.return_dtype(test_array.dtype()).is_err());
        assert!(expr.evaluate(&test_array).is_err());
    }
}