1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::StructArrayExt;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::get_item::GetItem;
37use crate::scalar_fn::fns::pack::Pack;
38use crate::scalar_fn::fns::pack::PackOptions;
39use crate::validity::Validity;
40
/// Scalar function that merges the fields of zero or more struct-typed inputs into a
/// single struct.
///
/// Output fields keep the order in which their names are first encountered across the
/// inputs. When a field name occurs more than once, the configured [`DuplicateHandling`]
/// decides whether the later field replaces the earlier one or the merge fails.
///
/// Fields are not merged recursively: a later field replaces an earlier field of the
/// same name wholesale, even when both are themselves structs.
#[derive(Clone)]
pub struct Merge;
49
50impl ScalarFnVTable for Merge {
51 type Options = DuplicateHandling;
52
53 fn id(&self) -> ScalarFnId {
54 ScalarFnId::from("vortex.merge")
55 }
56
57 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
58 Ok(Some(match instance {
59 DuplicateHandling::RightMost => vec![0x00],
60 DuplicateHandling::Error => vec![0x01],
61 }))
62 }
63
64 fn deserialize(
65 &self,
66 _metadata: &[u8],
67 _session: &VortexSession,
68 ) -> VortexResult<Self::Options> {
69 let instance = match _metadata {
70 [0x00] => DuplicateHandling::RightMost,
71 [0x01] => DuplicateHandling::Error,
72 _ => {
73 vortex_bail!("invalid metadata for Merge expression");
74 }
75 };
76 Ok(instance)
77 }
78
79 fn arity(&self, _options: &Self::Options) -> Arity {
80 Arity::Variadic { min: 0, max: None }
81 }
82
83 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
84 ChildName::from(Arc::from(format!("{}", child_idx)))
85 }
86
87 fn fmt_sql(
88 &self,
89 _options: &Self::Options,
90 expr: &Expression,
91 f: &mut Formatter<'_>,
92 ) -> std::fmt::Result {
93 write!(f, "merge(")?;
94 for (i, child) in expr.children().iter().enumerate() {
95 child.fmt_sql(f)?;
96 if i + 1 < expr.children().len() {
97 write!(f, ", ")?;
98 }
99 }
100 write!(f, ")")
101 }
102
103 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
104 let mut field_names = Vec::new();
105 let mut arrays = Vec::new();
106 let mut merge_nullability = Nullability::NonNullable;
107 let mut duplicate_names = HashSet::<_>::new();
108
109 for dtype in arg_dtypes {
110 let Some(fields) = dtype.as_struct_fields_opt() else {
111 vortex_bail!("merge expects struct input");
112 };
113 if dtype.is_nullable() {
114 vortex_bail!("merge expects non-nullable input");
115 }
116
117 merge_nullability |= dtype.nullability();
118
119 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
120 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
121 duplicate_names.insert(field_name.clone());
122 arrays[idx] = field_dtype;
123 } else {
124 field_names.push(field_name.clone());
125 arrays.push(field_dtype);
126 }
127 }
128 }
129
130 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
131 vortex_bail!(
132 "merge: duplicate fields in children: {}",
133 duplicate_names.into_iter().format(", ")
134 )
135 }
136
137 Ok(DType::Struct(
138 StructFields::new(FieldNames::from(field_names), arrays),
139 merge_nullability,
140 ))
141 }
142
143 fn execute(
144 &self,
145 options: &Self::Options,
146 args: &dyn ExecutionArgs,
147 ctx: &mut ExecutionCtx,
148 ) -> VortexResult<ArrayRef> {
149 let mut field_names = Vec::new();
151 let mut arrays = Vec::new();
152 let mut duplicate_names = HashSet::<_>::new();
153
154 for i in 0..args.num_inputs() {
155 let array = args.get(i)?.execute::<StructArray>(ctx)?;
156 if array.dtype().is_nullable() {
157 vortex_bail!("merge expects non-nullable input");
158 }
159
160 for (field_name, field_array) in array
161 .names()
162 .iter()
163 .zip_eq(array.iter_unmasked_fields().cloned())
164 {
165 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
167 duplicate_names.insert(field_name.clone());
168 arrays[idx] = field_array;
169 } else {
170 field_names.push(field_name.clone());
171 arrays.push(field_array);
172 }
173 }
174 }
175
176 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
177 vortex_bail!(
178 "merge: duplicate fields in children: {}",
179 duplicate_names.into_iter().format(", ")
180 )
181 }
182
183 let validity = Validity::NonNullable;
185 let len = args.row_count();
186 Ok(
187 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
188 .into_array(),
189 )
190 }
191
192 fn reduce(
193 &self,
194 options: &Self::Options,
195 node: &dyn ReduceNode,
196 ctx: &dyn ReduceCtx,
197 ) -> VortexResult<Option<ReduceNodeRef>> {
198 let mut names = Vec::with_capacity(node.child_count() * 2);
199 let mut children = Vec::with_capacity(node.child_count() * 2);
200 let mut duplicate_names = HashSet::<_>::new();
201
202 for child in (0..node.child_count()).map(|i| node.child(i)) {
203 let child_dtype = child.node_dtype()?;
204 if !child_dtype.is_struct() {
205 vortex_bail!(
206 "Merge child must return a non-nullable struct dtype, got {}",
207 child_dtype
208 )
209 }
210
211 let child_dtype = child_dtype
212 .as_struct_fields_opt()
213 .vortex_expect("expected struct");
214
215 for name in child_dtype.names().iter() {
216 if let Some(idx) = names.iter().position(|n| n == name) {
217 duplicate_names.insert(name.clone());
218 children[idx] = Arc::clone(&child);
219 } else {
220 names.push(name.clone());
221 children.push(Arc::clone(&child));
222 }
223 }
224
225 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
226 vortex_bail!(
227 "merge: duplicate fields in children: {}",
228 duplicate_names.into_iter().format(", ")
229 )
230 }
231 }
232
233 let pack_children: Vec<_> = names
234 .iter()
235 .zip(children)
236 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237 .try_collect()?;
238
239 let pack_expr = ctx.new_node(
240 Pack.bind(PackOptions {
241 names: FieldNames::from(names),
242 nullability: node.node_dtype()?.nullability(),
243 }),
244 &pack_children,
245 )?;
246
247 Ok(Some(pack_expr))
248 }
249
250 fn validity(
251 &self,
252 _options: &Self::Options,
253 _expression: &Expression,
254 ) -> VortexResult<Option<Expression>> {
255 Ok(Some(lit(true)))
256 }
257
258 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
259 true
260 }
261
262 fn is_fallible(&self, instance: &Self::Options) -> bool {
263 matches!(instance, DuplicateHandling::Error)
264 }
265}
266
/// Policy applied when a field name occurs more than once among the structs being merged.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field from the right-most (latest) struct that defines the name.
    RightMost,
    /// Fail the merge when a field name occurs more than once. This is the default.
    #[default]
    Error,
}
276
277impl Display for DuplicateHandling {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 match self {
280 DuplicateHandling::RightMost => write!(f, "RightMost"),
281 DuplicateHandling::Error => write!(f, "Error"),
282 }
283 }
284}
285
// Unit tests for the `merge` expression: duplicate handling, field ordering, display
// formatting and the rewrite into `Pack`.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::merge::StructArray;
    use crate::scalar_fn::fns::pack::Pack;

    /// Walks `field_path` through nested struct arrays, starting at `array`, and returns
    /// the field found at the end of the path as a primitive array.
    ///
    /// Fails if `field_path` is empty or if a field lookup along the path fails.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        // The first path segment is looked up on the input array itself.
        #[expect(deprecated)]
        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
        // Every remaining segment descends one struct level deeper.
        for field in field_path {
            #[expect(deprecated)]
            let next = array.to_struct().unmasked_field_by_name(field)?.clone();
            array = next;
        }
        #[expect(deprecated)]
        let result = array.to_primitive();
        Ok(result)
    }

    /// With `RightMost`, a duplicated field ("b") keeps its first-seen position but
    /// takes its values from the later struct.
    #[test]
    pub fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_arrays_eq!(
            primitive_field(&actual_array, &["a"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 0, 0])
        );
        // "b" is defined by both struct "0" and struct "1"; the values from "1" win.
        assert_arrays_eq!(
            primitive_field(&actual_array, &["b"]).unwrap(),
            PrimitiveArray::from_iter([2i32, 2, 2])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["c"]).unwrap(),
            PrimitiveArray::from_iter([3i32, 3, 3])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["d"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 4, 4])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["e"]).unwrap(),
            PrimitiveArray::from_iter([5i32, 5, 5])
        );
    }

    /// With `Error`, computing the return dtype fails when two inputs share field "b".
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, applying the expression fails when two inputs share field "b".
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        test_array.apply(&expr).unwrap();
    }

    /// A merge without children yields a struct with no fields and the input's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Merging is not recursive: when both inputs define a struct field "a", the later
    /// "a" replaces the earlier one wholesale.
    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        #[expect(deprecated)]
        let inner_struct = actual_array
            .unmasked_field_by_name("a")
            .unwrap()
            .to_struct();
        // Only "x" survives: field "y" of the earlier "a" is not carried over.
        assert_eq!(
            inner_struct
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow the order of the inputs, not alphabetical order.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The expression displays as `merge(...)` with comma-separated children.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing a merge replaces it with a `Pack` that has the merged dtype.
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        // "b" is duplicated; the dtype from the right-most struct (U32) is kept.
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}