1use std::fmt::Formatter;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_utils::aliases::hash_set::HashSet;
12
13use crate::arrays::StructArray;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, VTable, VTableExt};
15use crate::validity::Validity;
16use crate::{Array, ArrayRef, IntoArray as _, ToCanonical};
17
/// Expression vtable that merges the struct results of its children into a single struct.
///
/// Fields keep the position of their first occurrence. When a later child repeats a field
/// name, the later value replaces the earlier one as a whole (nested structs are not merged
/// recursively); whether such a repeat is tolerated is controlled by [`DuplicateHandling`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Merge;
25
26impl VTable for Merge {
27 type Instance = DuplicateHandling;
28
29 fn id(&self) -> ExprId {
30 ExprId::new_ref("vortex.merge")
31 }
32
33 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
34 Ok(Some(match instance {
35 DuplicateHandling::RightMost => vec![0x00],
36 DuplicateHandling::Error => vec![0x01],
37 }))
38 }
39
40 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
41 let instance = match metadata {
42 [0x00] => DuplicateHandling::RightMost,
43 [0x01] => DuplicateHandling::Error,
44 _ => {
45 vortex_bail!("invalid metadata for Merge expression");
46 }
47 };
48 Ok(Some(instance))
49 }
50
51 fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
52 Ok(())
53 }
54
55 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
56 ChildName::from(Arc::from(format!("{}", child_idx)))
57 }
58
59 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
60 write!(f, "merge(")?;
61 for (i, child) in expr.children().iter().enumerate() {
62 child.fmt_sql(f)?;
63 if i + 1 < expr.children().len() {
64 write!(f, ", ")?;
65 }
66 }
67 write!(f, ")")
68 }
69
70 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
71 let mut field_names = Vec::new();
72 let mut arrays = Vec::new();
73 let mut merge_nullability = Nullability::NonNullable;
74 let mut duplicate_names = HashSet::<_>::new();
75
76 for child in expr.children().iter() {
77 let dtype = child.return_dtype(scope)?;
78 let Some(fields) = dtype.as_struct_fields_opt() else {
79 vortex_bail!("merge expects struct input");
80 };
81 if dtype.is_nullable() {
82 vortex_bail!("merge expects non-nullable input");
83 }
84
85 merge_nullability |= dtype.nullability();
86
87 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
88 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
89 duplicate_names.insert(field_name.clone());
90 arrays[idx] = field_dtype;
91 } else {
92 field_names.push(field_name.clone());
93 arrays.push(field_dtype);
94 }
95 }
96 }
97
98 if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
99 vortex_bail!(
100 "merge: duplicate fields in children: {}",
101 duplicate_names.into_iter().format(", ")
102 )
103 }
104
105 Ok(DType::Struct(
106 StructFields::new(FieldNames::from(field_names), arrays),
107 merge_nullability,
108 ))
109 }
110
111 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
112 let mut field_names = Vec::new();
114 let mut arrays = Vec::new();
115 let mut duplicate_names = HashSet::<_>::new();
116
117 for child in expr.children().iter() {
118 let array = child.evaluate(scope)?;
120 if array.dtype().is_nullable() {
121 vortex_bail!("merge expects non-nullable input");
122 }
123 if !array.dtype().is_struct() {
124 vortex_bail!("merge expects struct input");
125 }
126 let array = array.to_struct();
127
128 for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
129 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
131 duplicate_names.insert(field_name.clone());
132 arrays[idx] = array;
133 } else {
134 field_names.push(field_name.clone());
135 arrays.push(array);
136 }
137 }
138 }
139
140 if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
141 vortex_bail!(
142 "merge: duplicate fields in children: {}",
143 duplicate_names.into_iter().format(", ")
144 )
145 }
146
147 let validity = Validity::NonNullable;
149 let len = scope.len();
150 Ok(
151 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
152 .into_array(),
153 )
154 }
155}
156
/// Policy applied when two merged structs contain a field with the same name.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most (latest) struct that defines the field.
    RightMost,
    /// Treat a repeated field name as an error. This is the default policy.
    #[default]
    Error,
}
166
167pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
178 let values = elements.into_iter().map(|value| value.into()).collect_vec();
179 Merge.new_expr(DuplicateHandling::default(), values)
180}
181
182pub fn merge_opts(
183 elements: impl IntoIterator<Item = impl Into<Expression>>,
184 duplicate_handling: DuplicateHandling,
185) -> Expression {
186 let values = elements.into_iter().map(|value| value.into()).collect_vec();
187 Merge.new_expr(duplicate_handling, values)
188}
189
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail};

    use super::merge;
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::{DuplicateHandling, merge_opts};
    use crate::expr::exprs::root::root;
    use crate::{Array, IntoArray, ToCanonical};

    /// Follows `field_path` through nested struct arrays and returns the leaf field as a
    /// primitive array. Fails on an empty path or a missing field.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// With `RightMost`, a repeated field keeps its first position but takes the value of
    /// the last child that defines it.
    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // "b" appears in children 0 and 1; the value from child 1 must win.
        let expected_fields = [
            ("a", [0, 0, 0]),
            ("b", [2, 2, 2]),
            ("c", [3, 3, 3]),
            ("d", [4, 4, 4]),
            ("e", [5, 5, 5]),
        ];
        for (name, expected) in expected_fields {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                expected
            );
        }
    }

    /// With `Error`, a repeated field name is rejected while computing the return dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, a repeated field name is rejected during evaluation.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.evaluate(&test_array).unwrap();
    }

    /// A merge with no children yields a field-less struct with the scope's length.
    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced as a whole rather than merged recursively: the later
    /// child's `a` (only `x`) replaces the earlier `a` (`x` and `y`).
    #[test]
    fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order, then field order within each child.
    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The display form lists the children comma-separated inside `merge(...)`.
    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }
}