1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::StructArrayExt;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::get_item::GetItem;
37use crate::scalar_fn::fns::pack::Pack;
38use crate::scalar_fn::fns::pack::PackOptions;
39use crate::validity::Validity;
40
/// Scalar function (`vortex.merge`) that combines the fields of several non-nullable struct
/// inputs into a single non-nullable struct.
///
/// Output fields keep the order in which their names first appear across the inputs; how a
/// name that appears more than once is treated is controlled by [`DuplicateHandling`].
#[derive(Clone)]
pub struct Merge;
49
50impl ScalarFnVTable for Merge {
51 type Options = DuplicateHandling;
52
53 fn id(&self) -> ScalarFnId {
54 ScalarFnId::from("vortex.merge")
55 }
56
57 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
58 Ok(Some(match instance {
59 DuplicateHandling::RightMost => vec![0x00],
60 DuplicateHandling::Error => vec![0x01],
61 }))
62 }
63
64 fn deserialize(
65 &self,
66 _metadata: &[u8],
67 _session: &VortexSession,
68 ) -> VortexResult<Self::Options> {
69 let instance = match _metadata {
70 [0x00] => DuplicateHandling::RightMost,
71 [0x01] => DuplicateHandling::Error,
72 _ => {
73 vortex_bail!("invalid metadata for Merge expression");
74 }
75 };
76 Ok(instance)
77 }
78
79 fn arity(&self, _options: &Self::Options) -> Arity {
80 Arity::Variadic { min: 0, max: None }
81 }
82
83 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
84 ChildName::from(Arc::from(format!("{}", child_idx)))
85 }
86
87 fn fmt_sql(
88 &self,
89 _options: &Self::Options,
90 expr: &Expression,
91 f: &mut Formatter<'_>,
92 ) -> std::fmt::Result {
93 write!(f, "merge(")?;
94 for (i, child) in expr.children().iter().enumerate() {
95 child.fmt_sql(f)?;
96 if i + 1 < expr.children().len() {
97 write!(f, ", ")?;
98 }
99 }
100 write!(f, ")")
101 }
102
103 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
104 let mut field_names = Vec::new();
105 let mut arrays = Vec::new();
106 let mut merge_nullability = Nullability::NonNullable;
107 let mut duplicate_names = HashSet::<_>::new();
108
109 for dtype in arg_dtypes {
110 let Some(fields) = dtype.as_struct_fields_opt() else {
111 vortex_bail!("merge expects struct input");
112 };
113 if dtype.is_nullable() {
114 vortex_bail!("merge expects non-nullable input");
115 }
116
117 merge_nullability |= dtype.nullability();
118
119 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
120 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
121 duplicate_names.insert(field_name.clone());
122 arrays[idx] = field_dtype;
123 } else {
124 field_names.push(field_name.clone());
125 arrays.push(field_dtype);
126 }
127 }
128 }
129
130 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
131 vortex_bail!(
132 "merge: duplicate fields in children: {}",
133 duplicate_names.into_iter().format(", ")
134 )
135 }
136
137 Ok(DType::Struct(
138 StructFields::new(FieldNames::from(field_names), arrays),
139 merge_nullability,
140 ))
141 }
142
143 fn execute(
144 &self,
145 options: &Self::Options,
146 args: &dyn ExecutionArgs,
147 ctx: &mut ExecutionCtx,
148 ) -> VortexResult<ArrayRef> {
149 let mut field_names = Vec::new();
151 let mut arrays = Vec::new();
152 let mut duplicate_names = HashSet::<_>::new();
153
154 for i in 0..args.num_inputs() {
155 let array = args.get(i)?.execute::<StructArray>(ctx)?;
156 if array.dtype().is_nullable() {
157 vortex_bail!("merge expects non-nullable input");
158 }
159
160 for (field_name, field_array) in array
161 .names()
162 .iter()
163 .zip_eq(array.iter_unmasked_fields().cloned())
164 {
165 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
167 duplicate_names.insert(field_name.clone());
168 arrays[idx] = field_array;
169 } else {
170 field_names.push(field_name.clone());
171 arrays.push(field_array);
172 }
173 }
174 }
175
176 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
177 vortex_bail!(
178 "merge: duplicate fields in children: {}",
179 duplicate_names.into_iter().format(", ")
180 )
181 }
182
183 let validity = Validity::NonNullable;
185 let len = args.row_count();
186 Ok(
187 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
188 .into_array(),
189 )
190 }
191
192 fn reduce(
193 &self,
194 options: &Self::Options,
195 node: &dyn ReduceNode,
196 ctx: &dyn ReduceCtx,
197 ) -> VortexResult<Option<ReduceNodeRef>> {
198 let mut names = Vec::with_capacity(node.child_count() * 2);
199 let mut children = Vec::with_capacity(node.child_count() * 2);
200 let mut duplicate_names = HashSet::<_>::new();
201
202 for child in (0..node.child_count()).map(|i| node.child(i)) {
203 let child_dtype = child.node_dtype()?;
204 if !child_dtype.is_struct() {
205 vortex_bail!(
206 "Merge child must return a non-nullable struct dtype, got {}",
207 child_dtype
208 )
209 }
210
211 let child_dtype = child_dtype
212 .as_struct_fields_opt()
213 .vortex_expect("expected struct");
214
215 for name in child_dtype.names().iter() {
216 if let Some(idx) = names.iter().position(|n| n == name) {
217 duplicate_names.insert(name.clone());
218 children[idx] = Arc::clone(&child);
219 } else {
220 names.push(name.clone());
221 children.push(Arc::clone(&child));
222 }
223 }
224
225 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
226 vortex_bail!(
227 "merge: duplicate fields in children: {}",
228 duplicate_names.into_iter().format(", ")
229 )
230 }
231 }
232
233 let pack_children: Vec<_> = names
234 .iter()
235 .zip(children)
236 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237 .try_collect()?;
238
239 let pack_expr = ctx.new_node(
240 Pack.bind(PackOptions {
241 names: FieldNames::from(names),
242 nullability: node.node_dtype()?.nullability(),
243 }),
244 &pack_children,
245 )?;
246
247 Ok(Some(pack_expr))
248 }
249
250 fn validity(
251 &self,
252 _options: &Self::Options,
253 _expression: &Expression,
254 ) -> VortexResult<Option<Expression>> {
255 Ok(Some(lit(true)))
256 }
257
258 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
259 true
260 }
261
262 fn is_fallible(&self, instance: &Self::Options) -> bool {
263 matches!(instance, DuplicateHandling::Error)
264 }
265}
266
/// How [`Merge`] resolves a field name that appears more than once across its inputs.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field at the position of its first occurrence, but take its value (and dtype)
    /// from the right-most input that defines it.
    RightMost,
    /// Fail with a "merge: duplicate fields in children" error. This is the default.
    #[default]
    Error,
}
276
277impl Display for DuplicateHandling {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 match self {
280 DuplicateHandling::RightMost => write!(f, "RightMost"),
281 DuplicateHandling::Error => write!(f, "Error"),
282 }
283 }
284}
285
286#[cfg(test)]
287mod tests {
288 use vortex_buffer::buffer;
289 use vortex_error::VortexResult;
290 use vortex_error::vortex_bail;
291
292 use crate::ArrayRef;
293 use crate::IntoArray;
294 use crate::ToCanonical;
295 use crate::arrays::PrimitiveArray;
296 use crate::arrays::struct_::StructArrayExt;
297 use crate::assert_arrays_eq;
298 use crate::dtype::DType;
299 use crate::dtype::Nullability::NonNullable;
300 use crate::dtype::PType::I32;
301 use crate::dtype::PType::I64;
302 use crate::dtype::PType::U32;
303 use crate::dtype::PType::U64;
304 use crate::expr::Expression;
305 use crate::expr::get_item;
306 use crate::expr::merge;
307 use crate::expr::merge_opts;
308 use crate::expr::root;
309 use crate::scalar_fn::fns::merge::DuplicateHandling;
310 use crate::scalar_fn::fns::merge::StructArray;
311 use crate::scalar_fn::fns::pack::Pack;
312
313 fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
314 let mut field_path = field_path.iter();
315
316 let Some(field) = field_path.next() else {
317 vortex_bail!("empty field path");
318 };
319
320 let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
321 for field in field_path {
322 array = array.to_struct().unmasked_field_by_name(field)?.clone();
323 }
324 Ok(array.to_primitive())
325 }
326
327 #[test]
328 pub fn test_merge_right_most() {
329 let expr = merge_opts(
330 vec![
331 get_item("0", root()),
332 get_item("1", root()),
333 get_item("2", root()),
334 ],
335 DuplicateHandling::RightMost,
336 );
337
338 let test_array = StructArray::from_fields(&[
339 (
340 "0",
341 StructArray::from_fields(&[
342 ("a", buffer![0, 0, 0].into_array()),
343 ("b", buffer![1, 1, 1].into_array()),
344 ])
345 .unwrap()
346 .into_array(),
347 ),
348 (
349 "1",
350 StructArray::from_fields(&[
351 ("b", buffer![2, 2, 2].into_array()),
352 ("c", buffer![3, 3, 3].into_array()),
353 ])
354 .unwrap()
355 .into_array(),
356 ),
357 (
358 "2",
359 StructArray::from_fields(&[
360 ("d", buffer![4, 4, 4].into_array()),
361 ("e", buffer![5, 5, 5].into_array()),
362 ])
363 .unwrap()
364 .into_array(),
365 ),
366 ])
367 .unwrap()
368 .into_array();
369 let actual_array = test_array.apply(&expr).unwrap();
370
371 assert_eq!(
372 actual_array.as_struct_typed().names(),
373 ["a", "b", "c", "d", "e"]
374 );
375
376 assert_arrays_eq!(
377 primitive_field(&actual_array, &["a"]).unwrap(),
378 PrimitiveArray::from_iter([0i32, 0, 0])
379 );
380 assert_arrays_eq!(
381 primitive_field(&actual_array, &["b"]).unwrap(),
382 PrimitiveArray::from_iter([2i32, 2, 2])
383 );
384 assert_arrays_eq!(
385 primitive_field(&actual_array, &["c"]).unwrap(),
386 PrimitiveArray::from_iter([3i32, 3, 3])
387 );
388 assert_arrays_eq!(
389 primitive_field(&actual_array, &["d"]).unwrap(),
390 PrimitiveArray::from_iter([4i32, 4, 4])
391 );
392 assert_arrays_eq!(
393 primitive_field(&actual_array, &["e"]).unwrap(),
394 PrimitiveArray::from_iter([5i32, 5, 5])
395 );
396 }
397
398 #[test]
399 #[should_panic(expected = "merge: duplicate fields in children")]
400 pub fn test_merge_error_on_dupe_return_dtype() {
401 let expr = merge_opts(
402 vec![get_item("0", root()), get_item("1", root())],
403 DuplicateHandling::Error,
404 );
405 let test_array = StructArray::try_from_iter([
406 (
407 "0",
408 StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
409 ),
410 (
411 "1",
412 StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
413 ),
414 ])
415 .unwrap()
416 .into_array();
417
418 expr.return_dtype(test_array.dtype()).unwrap();
419 }
420
421 #[test]
422 #[should_panic(expected = "merge: duplicate fields in children")]
423 pub fn test_merge_error_on_dupe_evaluate() {
424 let expr = merge_opts(
425 vec![get_item("0", root()), get_item("1", root())],
426 DuplicateHandling::Error,
427 );
428 let test_array = StructArray::try_from_iter([
429 (
430 "0",
431 StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
432 ),
433 (
434 "1",
435 StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
436 ),
437 ])
438 .unwrap()
439 .into_array();
440
441 test_array.apply(&expr).unwrap();
442 }
443
444 #[test]
445 pub fn test_empty_merge() {
446 let expr = merge(Vec::<Expression>::new());
447
448 let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
449 .unwrap()
450 .into_array();
451 let actual_array = test_array.clone().apply(&expr).unwrap();
452 assert_eq!(actual_array.len(), test_array.len());
453 assert_eq!(actual_array.as_struct_typed().nfields(), 0);
454 }
455
456 #[test]
457 pub fn test_nested_merge() {
458 let expr = merge_opts(
461 vec![get_item("0", root()), get_item("1", root())],
462 DuplicateHandling::RightMost,
463 );
464
465 let test_array = StructArray::from_fields(&[
466 (
467 "0",
468 StructArray::from_fields(&[(
469 "a",
470 StructArray::from_fields(&[
471 ("x", buffer![0, 0, 0].into_array()),
472 ("y", buffer![1, 1, 1].into_array()),
473 ])
474 .unwrap()
475 .into_array(),
476 )])
477 .unwrap()
478 .into_array(),
479 ),
480 (
481 "1",
482 StructArray::from_fields(&[(
483 "a",
484 StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
485 .unwrap()
486 .into_array(),
487 )])
488 .unwrap()
489 .into_array(),
490 ),
491 ])
492 .unwrap()
493 .into_array();
494 let actual_array = test_array.apply(&expr).unwrap().to_struct();
495
496 assert_eq!(
497 actual_array
498 .unmasked_field_by_name("a")
499 .unwrap()
500 .to_struct()
501 .names()
502 .iter()
503 .map(|name| name.as_ref())
504 .collect::<Vec<_>>(),
505 vec!["x"]
506 );
507 }
508
509 #[test]
510 pub fn test_merge_order() {
511 let expr = merge(vec![get_item("0", root()), get_item("1", root())]);
512
513 let test_array = StructArray::from_fields(&[
514 (
515 "0",
516 StructArray::from_fields(&[
517 ("a", buffer![0, 0, 0].into_array()),
518 ("c", buffer![1, 1, 1].into_array()),
519 ])
520 .unwrap()
521 .into_array(),
522 ),
523 (
524 "1",
525 StructArray::from_fields(&[
526 ("b", buffer![2, 2, 2].into_array()),
527 ("d", buffer![3, 3, 3].into_array()),
528 ])
529 .unwrap()
530 .into_array(),
531 ),
532 ])
533 .unwrap()
534 .into_array();
535 let actual_array = test_array.apply(&expr).unwrap().to_struct();
536
537 assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
538 }
539
540 #[test]
541 pub fn test_display() {
542 let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
543 assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");
544
545 let expr2 = merge(vec![get_item("a", root())]);
546 assert_eq!(expr2.to_string(), "merge($.a)");
547 }
548
549 #[test]
550 fn test_remove_merge() {
551 let dtype = DType::struct_(
552 [
553 ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
554 ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
555 ],
556 NonNullable,
557 );
558
559 let e = merge_opts(
560 [get_item("0", root()), get_item("1", root())],
561 DuplicateHandling::RightMost,
562 );
563
564 let result = e.optimize(&dtype).unwrap();
565
566 assert!(result.is::<Pack>());
567 assert_eq!(
568 result.return_dtype(&dtype).unwrap(),
569 DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
570 );
571 }
572}