1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::dtype::DType;
21use crate::dtype::FieldNames;
22use crate::dtype::Nullability;
23use crate::dtype::StructFields;
24use crate::expr::Expression;
25use crate::expr::lit;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ReduceCtx;
30use crate::scalar_fn::ReduceNode;
31use crate::scalar_fn::ReduceNodeRef;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::get_item::GetItem;
36use crate::scalar_fn::fns::pack::Pack;
37use crate::scalar_fn::fns::pack::PackOptions;
38use crate::validity::Validity;
39
/// Scalar function that merges zero or more struct-typed children into a single struct.
///
/// Output fields appear in first-seen order. When several children provide a field with the
/// same name, the outcome is governed by the `DuplicateHandling` option: either the right-most
/// child wins or the merge fails.
///
/// Fields are not merged recursively: a later field with the same name REPLACES the earlier
/// one wholesale, even when both are themselves structs.
#[derive(Clone, Debug)]
pub struct Merge;
48
49impl ScalarFnVTable for Merge {
50 type Options = DuplicateHandling;
51
52 fn id(&self) -> ScalarFnId {
53 ScalarFnId::new_ref("vortex.merge")
54 }
55
56 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
57 Ok(Some(match instance {
58 DuplicateHandling::RightMost => vec![0x00],
59 DuplicateHandling::Error => vec![0x01],
60 }))
61 }
62
63 fn deserialize(
64 &self,
65 _metadata: &[u8],
66 _session: &VortexSession,
67 ) -> VortexResult<Self::Options> {
68 let instance = match _metadata {
69 [0x00] => DuplicateHandling::RightMost,
70 [0x01] => DuplicateHandling::Error,
71 _ => {
72 vortex_bail!("invalid metadata for Merge expression");
73 }
74 };
75 Ok(instance)
76 }
77
78 fn arity(&self, _options: &Self::Options) -> Arity {
79 Arity::Variadic { min: 0, max: None }
80 }
81
82 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
83 ChildName::from(Arc::from(format!("{}", child_idx)))
84 }
85
86 fn fmt_sql(
87 &self,
88 _options: &Self::Options,
89 expr: &Expression,
90 f: &mut Formatter<'_>,
91 ) -> std::fmt::Result {
92 write!(f, "merge(")?;
93 for (i, child) in expr.children().iter().enumerate() {
94 child.fmt_sql(f)?;
95 if i + 1 < expr.children().len() {
96 write!(f, ", ")?;
97 }
98 }
99 write!(f, ")")
100 }
101
102 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
103 let mut field_names = Vec::new();
104 let mut arrays = Vec::new();
105 let mut merge_nullability = Nullability::NonNullable;
106 let mut duplicate_names = HashSet::<_>::new();
107
108 for dtype in arg_dtypes {
109 let Some(fields) = dtype.as_struct_fields_opt() else {
110 vortex_bail!("merge expects struct input");
111 };
112 if dtype.is_nullable() {
113 vortex_bail!("merge expects non-nullable input");
114 }
115
116 merge_nullability |= dtype.nullability();
117
118 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
119 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
120 duplicate_names.insert(field_name.clone());
121 arrays[idx] = field_dtype;
122 } else {
123 field_names.push(field_name.clone());
124 arrays.push(field_dtype);
125 }
126 }
127 }
128
129 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
130 vortex_bail!(
131 "merge: duplicate fields in children: {}",
132 duplicate_names.into_iter().format(", ")
133 )
134 }
135
136 Ok(DType::Struct(
137 StructFields::new(FieldNames::from(field_names), arrays),
138 merge_nullability,
139 ))
140 }
141
142 fn execute(
143 &self,
144 options: &Self::Options,
145 args: &dyn ExecutionArgs,
146 ctx: &mut ExecutionCtx,
147 ) -> VortexResult<ArrayRef> {
148 let mut field_names = Vec::new();
150 let mut arrays = Vec::new();
151 let mut duplicate_names = HashSet::<_>::new();
152
153 for i in 0..args.num_inputs() {
154 let array = args.get(i)?.execute::<StructArray>(ctx)?;
155 if array.dtype().is_nullable() {
156 vortex_bail!("merge expects non-nullable input");
157 }
158
159 for (field_name, field_array) in array
160 .names()
161 .iter()
162 .zip_eq(array.unmasked_fields().iter().cloned())
163 {
164 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
166 duplicate_names.insert(field_name.clone());
167 arrays[idx] = field_array;
168 } else {
169 field_names.push(field_name.clone());
170 arrays.push(field_array);
171 }
172 }
173 }
174
175 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
176 vortex_bail!(
177 "merge: duplicate fields in children: {}",
178 duplicate_names.into_iter().format(", ")
179 )
180 }
181
182 let validity = Validity::NonNullable;
184 let len = args.row_count();
185 Ok(
186 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
187 .into_array(),
188 )
189 }
190
191 fn reduce(
192 &self,
193 options: &Self::Options,
194 node: &dyn ReduceNode,
195 ctx: &dyn ReduceCtx,
196 ) -> VortexResult<Option<ReduceNodeRef>> {
197 let mut names = Vec::with_capacity(node.child_count() * 2);
198 let mut children = Vec::with_capacity(node.child_count() * 2);
199 let mut duplicate_names = HashSet::<_>::new();
200
201 for child in (0..node.child_count()).map(|i| node.child(i)) {
202 let child_dtype = child.node_dtype()?;
203 if !child_dtype.is_struct() {
204 vortex_bail!(
205 "Merge child must return a non-nullable struct dtype, got {}",
206 child_dtype
207 )
208 }
209
210 let child_dtype = child_dtype
211 .as_struct_fields_opt()
212 .vortex_expect("expected struct");
213
214 for name in child_dtype.names().iter() {
215 if let Some(idx) = names.iter().position(|n| n == name) {
216 duplicate_names.insert(name.clone());
217 children[idx] = child.clone();
218 } else {
219 names.push(name.clone());
220 children.push(child.clone());
221 }
222 }
223
224 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
225 vortex_bail!(
226 "merge: duplicate fields in children: {}",
227 duplicate_names.into_iter().format(", ")
228 )
229 }
230 }
231
232 let pack_children: Vec<_> = names
233 .iter()
234 .zip(children)
235 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
236 .try_collect()?;
237
238 let pack_expr = ctx.new_node(
239 Pack.bind(PackOptions {
240 names: FieldNames::from(names),
241 nullability: node.node_dtype()?.nullability(),
242 }),
243 &pack_children,
244 )?;
245
246 Ok(Some(pack_expr))
247 }
248
249 fn validity(
250 &self,
251 _options: &Self::Options,
252 _expression: &Expression,
253 ) -> VortexResult<Option<Expression>> {
254 Ok(Some(lit(true)))
255 }
256
257 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
258 true
259 }
260
261 fn is_fallible(&self, instance: &Self::Options) -> bool {
262 matches!(instance, DuplicateHandling::Error)
263 }
264}
265
/// What to do when two merged structs share a field name.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field from the right-most struct that provides it.
    RightMost,
    /// Reject the merge with an error. This is the default.
    #[default]
    Error,
}

impl Display for DuplicateHandling {
    /// Writes the variant name (`RightMost` or `Error`).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            DuplicateHandling::RightMost => "RightMost",
            DuplicateHandling::Error => "Error",
        })
    }
}
284
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::pack::Pack;

    /// Follows `field_path` through nested struct fields of `array` and returns the final
    /// field as a primitive array. Fails on an empty path or an unknown field name.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().unmasked_field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct().unmasked_field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// Builds a one-row struct with children `"0" = {a, b}` and `"1" = {c, b}`, i.e. two
    /// structs that both contain a field named `b`.
    fn structs_sharing_field_b() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    /// With `RightMost`, a repeated field keeps its first position but takes the value of the
    /// last struct that provides it.
    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // Every merged field is a constant column of three rows; `b` comes from struct "1".
        for (field, value) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&actual_array, &[field]).unwrap(),
                PrimitiveArray::from_iter([value; 3])
            );
        }
    }

    /// With `Error`, computing the return dtype fails on a repeated field name.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = structs_sharing_field_b();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, applying the expression fails on a repeated field name.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = structs_sharing_field_b();

        test_array.apply(&expr).unwrap();
    }

    /// Merging no children yields a struct with no fields and the input's row count.
    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced, not merged: the later `a = {x}` replaces `a = {x, y}`.
    #[test]
    fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(
            actual_array
                .unmasked_field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order and then field order within each child; they are not
    /// sorted by name.
    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The expression displays as `merge(<child>, <child>, ...)`.
    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing a merge rewrites it into a `Pack` with the merged dtype, where the repeated
    /// field `b` takes the dtype of the right-most struct.
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}