1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_utils::aliases::hash_set::HashSet;
18use vortex_vector::Datum;
19
20use crate::Array;
21use crate::ArrayRef;
22use crate::IntoArray as _;
23use crate::ToCanonical;
24use crate::arrays::StructArray;
25use crate::expr::Arity;
26use crate::expr::ChildName;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Expression;
30use crate::expr::GetItem;
31use crate::expr::Pack;
32use crate::expr::PackOptions;
33use crate::expr::ReduceCtx;
34use crate::expr::ReduceNode;
35use crate::expr::ReduceNodeRef;
36use crate::expr::VTable;
37use crate::expr::VTableExt;
38use crate::validity::Validity;
39
/// Expression that combines the fields of several struct-typed, non-nullable child expressions
/// into a single non-nullable struct. Field names that appear in more than one child are
/// resolved according to [`DuplicateHandling`].
pub struct Merge;
47
48impl VTable for Merge {
49 type Options = DuplicateHandling;
50
51 fn id(&self) -> ExprId {
52 ExprId::new_ref("vortex.merge")
53 }
54
55 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56 Ok(Some(match instance {
57 DuplicateHandling::RightMost => vec![0x00],
58 DuplicateHandling::Error => vec![0x01],
59 }))
60 }
61
62 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
63 let instance = match metadata {
64 [0x00] => DuplicateHandling::RightMost,
65 [0x01] => DuplicateHandling::Error,
66 _ => {
67 vortex_bail!("invalid metadata for Merge expression");
68 }
69 };
70 Ok(instance)
71 }
72
73 fn arity(&self, _options: &Self::Options) -> Arity {
74 Arity::Variadic { min: 0, max: None }
75 }
76
77 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
78 ChildName::from(Arc::from(format!("{}", child_idx)))
79 }
80
81 fn fmt_sql(
82 &self,
83 _options: &Self::Options,
84 expr: &Expression,
85 f: &mut Formatter<'_>,
86 ) -> std::fmt::Result {
87 write!(f, "merge(")?;
88 for (i, child) in expr.children().iter().enumerate() {
89 child.fmt_sql(f)?;
90 if i + 1 < expr.children().len() {
91 write!(f, ", ")?;
92 }
93 }
94 write!(f, ")")
95 }
96
97 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
98 let mut field_names = Vec::new();
99 let mut arrays = Vec::new();
100 let mut merge_nullability = Nullability::NonNullable;
101 let mut duplicate_names = HashSet::<_>::new();
102
103 for dtype in arg_dtypes {
104 let Some(fields) = dtype.as_struct_fields_opt() else {
105 vortex_bail!("merge expects struct input");
106 };
107 if dtype.is_nullable() {
108 vortex_bail!("merge expects non-nullable input");
109 }
110
111 merge_nullability |= dtype.nullability();
112
113 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
114 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
115 duplicate_names.insert(field_name.clone());
116 arrays[idx] = field_dtype;
117 } else {
118 field_names.push(field_name.clone());
119 arrays.push(field_dtype);
120 }
121 }
122 }
123
124 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
125 vortex_bail!(
126 "merge: duplicate fields in children: {}",
127 duplicate_names.into_iter().format(", ")
128 )
129 }
130
131 Ok(DType::Struct(
132 StructFields::new(FieldNames::from(field_names), arrays),
133 merge_nullability,
134 ))
135 }
136
137 fn evaluate(
138 &self,
139 options: &Self::Options,
140 expr: &Expression,
141 scope: &ArrayRef,
142 ) -> VortexResult<ArrayRef> {
143 let mut field_names = Vec::new();
145 let mut arrays = Vec::new();
146 let mut duplicate_names = HashSet::<_>::new();
147
148 for child in expr.children().iter() {
149 let array = child.evaluate(scope)?;
151 if array.dtype().is_nullable() {
152 vortex_bail!("merge expects non-nullable input");
153 }
154 if !array.dtype().is_struct() {
155 vortex_bail!("merge expects struct input");
156 }
157 let array = array.to_struct();
158
159 for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
160 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
162 duplicate_names.insert(field_name.clone());
163 arrays[idx] = array;
164 } else {
165 field_names.push(field_name.clone());
166 arrays.push(array);
167 }
168 }
169 }
170
171 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
172 vortex_bail!(
173 "merge: duplicate fields in children: {}",
174 duplicate_names.into_iter().format(", ")
175 )
176 }
177
178 let validity = Validity::NonNullable;
180 let len = scope.len();
181 Ok(
182 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
183 .into_array(),
184 )
185 }
186
187 fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult<Datum> {
188 todo!()
189 }
190
191 fn reduce(
192 &self,
193 options: &Self::Options,
194 node: &dyn ReduceNode,
195 ctx: &dyn ReduceCtx,
196 ) -> VortexResult<Option<ReduceNodeRef>> {
197 let merge_dtype = node.node_dtype()?;
198 let mut names = Vec::with_capacity(node.child_count() * 2);
199 let mut children = Vec::with_capacity(node.child_count() * 2);
200 let mut duplicate_names = HashSet::<_>::new();
201
202 for child in node.children() {
203 let child_dtype = child.node_dtype()?;
204 if !child_dtype.is_struct() {
205 vortex_bail!(
206 "Merge child must return a non-nullable struct dtype, got {}",
207 child_dtype
208 )
209 }
210
211 let child_dtype = child_dtype
212 .as_struct_fields_opt()
213 .vortex_expect("expected struct");
214
215 for name in child_dtype.names().iter() {
216 if let Some(idx) = names.iter().position(|n| n == name) {
217 duplicate_names.insert(name.clone());
218 children[idx] = child.clone();
219 } else {
220 names.push(name.clone());
221 children.push(child.clone());
222 }
223 }
224
225 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
226 vortex_bail!(
227 "merge: duplicate fields in children: {}",
228 duplicate_names.into_iter().format(", ")
229 )
230 }
231 }
232
233 let pack_children: Vec<_> = names
234 .iter()
235 .zip(children)
236 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237 .try_collect()?;
238
239 let pack_expr = ctx.new_node(
240 Pack.bind(PackOptions {
241 names: FieldNames::from(names),
242 nullability: merge_dtype.nullability(),
243 }),
244 &pack_children,
245 )?;
246
247 Ok(Some(pack_expr))
248 }
249
250 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
251 true
252 }
253
254 fn is_fallible(&self, instance: &Self::Options) -> bool {
255 matches!(instance, DuplicateHandling::Error)
256 }
257}
258
/// How [`Merge`] resolves a field name that is defined by more than one of its inputs.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// The repeated field keeps the position where its name first appeared, but takes the
    /// value (and dtype) of the right-most input that defines it.
    RightMost,
    /// Merging fails with an error listing the duplicated field names. This is the default.
    #[default]
    Error,
}
268
269impl Display for DuplicateHandling {
270 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
271 match self {
272 DuplicateHandling::RightMost => write!(f, "RightMost"),
273 DuplicateHandling::Error => write!(f, "Error"),
274 }
275 }
276}
277
278pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
289 let values = elements.into_iter().map(|value| value.into()).collect_vec();
290 Merge.new_expr(DuplicateHandling::default(), values)
291}
292
293pub fn merge_opts(
294 elements: impl IntoIterator<Item = impl Into<Expression>>,
295 duplicate_handling: DuplicateHandling,
296) -> Expression {
297 let values = elements.into_iter().map(|value| value.into()).collect_vec();
298 Merge.new_expr(duplicate_handling, values)
299}
300
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PType::I64;
    use vortex_dtype::PType::U32;
    use vortex_dtype::PType::U64;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::merge;
    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::Pack;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Follows `field_path` through nested struct arrays, starting at `array`, and returns the
    /// final field canonicalized to a primitive array.
    ///
    /// Bails on an empty path and propagates any `field_by_name` error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// `b` is defined by children "0" and "1". RightMost keeps `b` in its first position
    /// but takes child "1"'s values; all other fields pass through unchanged.
    #[test]
    pub fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        // Right-most definition of `b` (child "1") wins.
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    /// In Error mode, a field name (`b`) shared by two children must be rejected when
    /// computing the return dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// Same scenario as above, but the duplicate must also be rejected during evaluation.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.evaluate(&test_array).unwrap();
    }

    /// A merge with no children yields a zero-field struct whose length matches the scope.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Merging is shallow. The right-most struct-valued `a` replaces the left one entirely,
    /// so only `x` survives and `y` is not carried over.
    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap().to_struct();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output field order is the order names are first seen across children (child order,
    /// then field order within each child), not sorted.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// SQL rendering lists the children comma-separated inside `merge(...)`.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing rewrites the merge into a `Pack`. The resulting dtype keeps `b` in its first
    /// position but with the right-most child's type (U32, not I64).
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}