1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, ArrayVariants};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23 values: Vec<ExprRef>,
24}
25
26impl Merge {
27 pub fn new_expr(values: Vec<ExprRef>) -> Arc<Self> {
28 Arc::new(Merge { values })
29 }
30}
31
32pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
33 let values = elements.into_iter().map(|value| value.into()).collect_vec();
34 Merge::new_expr(values)
35}
36
37impl Display for Merge {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.write_str("{")?;
40 self.values
41 .iter()
42 .format_with(", ", |expr, f| f(expr))
43 .fmt(f)?;
44 f.write_str("}")
45 }
46}
47
/// Protobuf (de)serialization hooks for [`Merge`].
///
/// Neither direction is implemented yet: both entry points return a `NotImplemented` error
/// that names the failing function and the expression id.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Serde marker for the `merge` expression kind.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        fn id(&self) -> &'static str {
            "merge"
        }
    }

    impl ExprDeserialize for MergeSerde {
        /// Always fails: deserializing `merge` is not supported yet.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the unimplemented function so the error message is actionable
            // (previously an empty string was reported).
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Merge {
        fn id(&self) -> &'static str {
            MergeSerde.id()
        }

        /// Always fails: serializing `merge` is not supported yet.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
79
80impl VortexExpr for Merge {
81 fn as_any(&self) -> &dyn Any {
82 self
83 }
84
85 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
86 let len = batch.len();
87 let value_arrays = self
88 .values
89 .iter()
90 .map(|value_expr| value_expr.evaluate(batch))
91 .process_results(|it| it.collect::<Vec<_>>())?;
92
93 let mut field_names = Vec::new();
95 let mut arrays = Vec::new();
96
97 for value_array in value_arrays.iter() {
98 if value_array.dtype().is_nullable() {
100 todo!("merge nullable structs");
101 }
102 if !value_array.dtype().is_struct() {
103 vortex_bail!("merge expects non-nullable struct input");
104 }
105
106 let struct_array = value_array
107 .as_struct_typed()
108 .vortex_expect("merge expects struct input");
109
110 for (i, field_name) in struct_array.names().iter().enumerate() {
111 let array = struct_array
112 .maybe_null_field_by_idx(i)
113 .vortex_expect("struct field not found");
114
115 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
117 arrays[idx] = array;
118 } else {
119 field_names.push(field_name.clone());
120 arrays.push(array);
121 }
122 }
123 }
124
125 Ok(StructArray::try_new(
126 FieldNames::from(field_names),
127 arrays,
128 len,
129 Validity::NonNullable,
130 )?
131 .into_array())
132 }
133
134 fn children(&self) -> Vec<&ExprRef> {
135 self.values.iter().collect()
136 }
137
138 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
139 Self::new_expr(children)
140 }
141
142 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
143 let mut field_names = Vec::new();
144 let mut arrays = Vec::new();
145
146 for value in self.values.iter() {
147 let dtype = value.return_dtype(scope_dtype)?;
148 if !dtype.is_struct() {
149 vortex_bail!("merge expects non-nullable struct input");
150 }
151
152 let struct_dtype = dtype
153 .as_struct()
154 .vortex_expect("merge expects struct input");
155
156 for i in 0..struct_dtype.nfields() {
157 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
158 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
159 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
160 arrays[idx] = field_dtype;
161 } else {
162 field_names.push(field_name.clone());
163 arrays.push(field_dtype);
164 }
165 }
166 }
167
168 Ok(DType::Struct(
169 Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
170 Nullability::NonNullable,
171 ))
172 }
173}
174
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Follows `field_path` through nested struct arrays, starting at `array`, and returns the
    /// leaf canonicalized as a [`PrimitiveArray`].
    ///
    /// # Errors
    ///
    /// Fails if the path is empty, if a step is not a struct or lacks the named field, or if
    /// the leaf cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array
            .as_struct_typed()
            .ok_or_else(|| vortex_err!("expected a struct"))?
            .maybe_null_field_by_name(field)?;

        for field in field_path {
            array = array
                .as_struct_typed()
                .ok_or_else(|| vortex_err!("expected a struct"))?
                .maybe_null_field_by_name(field)?;
        }
        // Propagate conversion failures instead of panicking inside a fallible helper.
        array.to_primitive()
    }

    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
            GetItem::new_expr("2", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        // "b" is present in both "0" and "1"; the later child wins.
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().unwrap().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Merging is shallow: the later "a" replaces the earlier one wholesale.
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array
                .as_struct_typed()
                .unwrap()
                .maybe_null_field_by_name("a")
                .unwrap()
                .as_struct_typed()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    #[test]
    pub fn test_merge_rejects_non_struct_input() {
        // `GetItem("a")` yields a primitive array, which merge must reject both when
        // planning (`return_dtype`) and when evaluating.
        let expr = Merge::new_expr(vec![GetItem::new_expr("a", Identity::new_expr())]);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();

        assert!(expr.evaluate(&test_array).is_err());
        assert!(expr.return_dtype(test_array.dtype()).is_err());
    }
}