1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_session::VortexSession;
18use vortex_utils::aliases::hash_set::HashSet;
19
20use crate::ArrayRef;
21use crate::IntoArray as _;
22use crate::arrays::StructArray;
23use crate::expr::Arity;
24use crate::expr::ChildName;
25use crate::expr::ExecutionArgs;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::GetItem;
29use crate::expr::Pack;
30use crate::expr::PackOptions;
31use crate::expr::ReduceCtx;
32use crate::expr::ReduceNode;
33use crate::expr::ReduceNodeRef;
34use crate::expr::VTable;
35use crate::expr::VTableExt;
36use crate::expr::lit;
37use crate::validity::Validity;
38
/// Expression that merges the fields of several struct-typed children into one struct.
///
/// Children are processed left to right. A field keeps the position of its first occurrence;
/// when a later child provides a field with the same name, that later value replaces the
/// earlier one. Whether such a name collision is tolerated or reported as an error is
/// controlled by the [`DuplicateHandling`] option. Inputs must be non-nullable structs.
pub struct Merge;
46
47impl VTable for Merge {
48 type Options = DuplicateHandling;
49
50 fn id(&self) -> ExprId {
51 ExprId::new_ref("vortex.merge")
52 }
53
54 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
55 Ok(Some(match instance {
56 DuplicateHandling::RightMost => vec![0x00],
57 DuplicateHandling::Error => vec![0x01],
58 }))
59 }
60
61 fn deserialize(
62 &self,
63 _metadata: &[u8],
64 _session: &VortexSession,
65 ) -> VortexResult<Self::Options> {
66 let instance = match _metadata {
67 [0x00] => DuplicateHandling::RightMost,
68 [0x01] => DuplicateHandling::Error,
69 _ => {
70 vortex_bail!("invalid metadata for Merge expression");
71 }
72 };
73 Ok(instance)
74 }
75
76 fn arity(&self, _options: &Self::Options) -> Arity {
77 Arity::Variadic { min: 0, max: None }
78 }
79
80 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81 ChildName::from(Arc::from(format!("{}", child_idx)))
82 }
83
84 fn fmt_sql(
85 &self,
86 _options: &Self::Options,
87 expr: &Expression,
88 f: &mut Formatter<'_>,
89 ) -> std::fmt::Result {
90 write!(f, "merge(")?;
91 for (i, child) in expr.children().iter().enumerate() {
92 child.fmt_sql(f)?;
93 if i + 1 < expr.children().len() {
94 write!(f, ", ")?;
95 }
96 }
97 write!(f, ")")
98 }
99
100 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
101 let mut field_names = Vec::new();
102 let mut arrays = Vec::new();
103 let mut merge_nullability = Nullability::NonNullable;
104 let mut duplicate_names = HashSet::<_>::new();
105
106 for dtype in arg_dtypes {
107 let Some(fields) = dtype.as_struct_fields_opt() else {
108 vortex_bail!("merge expects struct input");
109 };
110 if dtype.is_nullable() {
111 vortex_bail!("merge expects non-nullable input");
112 }
113
114 merge_nullability |= dtype.nullability();
115
116 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
117 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
118 duplicate_names.insert(field_name.clone());
119 arrays[idx] = field_dtype;
120 } else {
121 field_names.push(field_name.clone());
122 arrays.push(field_dtype);
123 }
124 }
125 }
126
127 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
128 vortex_bail!(
129 "merge: duplicate fields in children: {}",
130 duplicate_names.into_iter().format(", ")
131 )
132 }
133
134 Ok(DType::Struct(
135 StructFields::new(FieldNames::from(field_names), arrays),
136 merge_nullability,
137 ))
138 }
139
140 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
141 let mut field_names = Vec::new();
143 let mut arrays = Vec::new();
144 let mut duplicate_names = HashSet::<_>::new();
145
146 for input in args.inputs {
147 let array = input.execute::<StructArray>(args.ctx)?;
148 if array.dtype().is_nullable() {
149 vortex_bail!("merge expects non-nullable input");
150 }
151
152 for (field_name, field_array) in array
153 .names()
154 .iter()
155 .zip_eq(array.unmasked_fields().iter().cloned())
156 {
157 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
159 duplicate_names.insert(field_name.clone());
160 arrays[idx] = field_array;
161 } else {
162 field_names.push(field_name.clone());
163 arrays.push(field_array);
164 }
165 }
166 }
167
168 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
169 vortex_bail!(
170 "merge: duplicate fields in children: {}",
171 duplicate_names.into_iter().format(", ")
172 )
173 }
174
175 let validity = Validity::NonNullable;
177 let len = args.row_count;
178 Ok(
179 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
180 .into_array(),
181 )
182 }
183
184 fn reduce(
185 &self,
186 options: &Self::Options,
187 node: &dyn ReduceNode,
188 ctx: &dyn ReduceCtx,
189 ) -> VortexResult<Option<ReduceNodeRef>> {
190 let mut names = Vec::with_capacity(node.child_count() * 2);
191 let mut children = Vec::with_capacity(node.child_count() * 2);
192 let mut duplicate_names = HashSet::<_>::new();
193
194 for child in (0..node.child_count()).map(|i| node.child(i)) {
195 let child_dtype = child.node_dtype()?;
196 if !child_dtype.is_struct() {
197 vortex_bail!(
198 "Merge child must return a non-nullable struct dtype, got {}",
199 child_dtype
200 )
201 }
202
203 let child_dtype = child_dtype
204 .as_struct_fields_opt()
205 .vortex_expect("expected struct");
206
207 for name in child_dtype.names().iter() {
208 if let Some(idx) = names.iter().position(|n| n == name) {
209 duplicate_names.insert(name.clone());
210 children[idx] = child.clone();
211 } else {
212 names.push(name.clone());
213 children.push(child.clone());
214 }
215 }
216
217 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
218 vortex_bail!(
219 "merge: duplicate fields in children: {}",
220 duplicate_names.into_iter().format(", ")
221 )
222 }
223 }
224
225 let pack_children: Vec<_> = names
226 .iter()
227 .zip(children)
228 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
229 .try_collect()?;
230
231 let pack_expr = ctx.new_node(
232 Pack.bind(PackOptions {
233 names: FieldNames::from(names),
234 nullability: node.node_dtype()?.nullability(),
235 }),
236 &pack_children,
237 )?;
238
239 Ok(Some(pack_expr))
240 }
241
242 fn validity(
243 &self,
244 _options: &Self::Options,
245 _expression: &Expression,
246 ) -> VortexResult<Option<Expression>> {
247 Ok(Some(lit(true)))
248 }
249
250 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
251 true
252 }
253
254 fn is_fallible(&self, instance: &Self::Options) -> bool {
255 matches!(instance, DuplicateHandling::Error)
256 }
257}
258
/// How [`Merge`] treats a field name that appears in more than one child.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Accept duplicates: the value from the right-most child providing the field is used,
    /// while the field stays at the position of its first occurrence.
    RightMost,
    /// Reject duplicates: merging fails with an error naming the duplicated fields.
    /// This is the default.
    #[default]
    Error,
}
268
269impl Display for DuplicateHandling {
270 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
271 match self {
272 DuplicateHandling::RightMost => write!(f, "RightMost"),
273 DuplicateHandling::Error => write!(f, "Error"),
274 }
275 }
276}
277
278pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
289 let values = elements.into_iter().map(|value| value.into()).collect_vec();
290 Merge.new_expr(DuplicateHandling::default(), values)
291}
292
293pub fn merge_opts(
294 elements: impl IntoIterator<Item = impl Into<Expression>>,
295 duplicate_handling: DuplicateHandling,
296) -> Expression {
297 let values = elements.into_iter().map(|value| value.into()).collect_vec();
298 Merge.new_expr(duplicate_handling, values)
299}
300
// Unit tests for the merge expression: duplicate handling, field ordering, empty input,
// SQL-style display, and the rewrite into `Pack` performed by `optimize`.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PType::I64;
    use vortex_dtype::PType::U32;
    use vortex_dtype::PType::U64;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::merge;
    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::expr::Expression;
    use crate::expr::Pack;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Follows `field_path` through nested struct fields starting at `array` and returns the
    /// final field as a primitive array. Fails when `field_path` is empty or a field lookup
    /// fails.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().unmasked_field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// With `RightMost`, a field present in two children ("b") takes the value of the later
    /// child while keeping the position of its first occurrence.
    #[test]
    pub fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_arrays_eq!(
            primitive_field(&actual_array, &["a"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 0, 0])
        );
        // "b" exists in children "0" and "1"; the value from "1" is expected.
        assert_arrays_eq!(
            primitive_field(&actual_array, &["b"]).unwrap(),
            PrimitiveArray::from_iter([2i32, 2, 2])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["c"]).unwrap(),
            PrimitiveArray::from_iter([3i32, 3, 3])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["d"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 4, 4])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["e"]).unwrap(),
            PrimitiveArray::from_iter([5i32, 5, 5])
        );
    }

    /// With `Error`, computing the return dtype fails when both children contain field "b".
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, applying the expression fails when both children contain field "b".
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        test_array.apply(&expr).unwrap();
    }

    /// A merge with no children yields a struct with zero fields and the input's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Merging is shallow: a duplicated struct-typed field "a" is replaced wholesale by the
    /// right-most value rather than merged field by field, so only "x" remains inside "a".
    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap().to_struct();

        assert_eq!(
            actual_array
                .unmasked_field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order, then field order within each child; they are not
    /// sorted by name.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The expression displays as `merge(...)` with comma-separated children.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// `optimize` rewrites a merge into a `Pack` whose dtype reflects right-most duplicate
    /// resolution ("b" takes the `U32` type from child "1").
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}