1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::StructArrayExt;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::get_item::GetItem;
37use crate::scalar_fn::fns::pack::Pack;
38use crate::scalar_fn::fns::pack::PackOptions;
39use crate::validity::Validity;
40
/// Scalar function (`vortex.merge`) that combines the fields of zero or more struct-typed
/// inputs into a single struct.
///
/// Fields appear in the output in order of first appearance across the inputs. When a field
/// name occurs more than once, the later occurrence replaces the earlier one in place (or the
/// merge fails, depending on [`DuplicateHandling`]). Replacement is wholesale: nested structs
/// sharing a name are not merged recursively.
#[derive(Clone)]
pub struct Merge;
49
50impl ScalarFnVTable for Merge {
51 type Options = DuplicateHandling;
52
53 fn id(&self) -> ScalarFnId {
54 ScalarFnId::from("vortex.merge")
55 }
56
57 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
58 Ok(Some(match instance {
59 DuplicateHandling::RightMost => vec![0x00],
60 DuplicateHandling::Error => vec![0x01],
61 }))
62 }
63
64 fn deserialize(
65 &self,
66 _metadata: &[u8],
67 _session: &VortexSession,
68 ) -> VortexResult<Self::Options> {
69 let instance = match _metadata {
70 [0x00] => DuplicateHandling::RightMost,
71 [0x01] => DuplicateHandling::Error,
72 _ => {
73 vortex_bail!("invalid metadata for Merge expression");
74 }
75 };
76 Ok(instance)
77 }
78
79 fn arity(&self, _options: &Self::Options) -> Arity {
80 Arity::Variadic { min: 0, max: None }
81 }
82
83 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
84 ChildName::from(Arc::from(format!("{}", child_idx)))
85 }
86
87 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
88 let mut field_names = Vec::new();
89 let mut arrays = Vec::new();
90 let mut merge_nullability = Nullability::NonNullable;
91 let mut duplicate_names = HashSet::<_>::new();
92
93 for dtype in arg_dtypes {
94 let Some(fields) = dtype.as_struct_fields_opt() else {
95 vortex_bail!("merge expects struct input");
96 };
97 if dtype.is_nullable() {
98 vortex_bail!("merge expects non-nullable input");
99 }
100
101 merge_nullability |= dtype.nullability();
102
103 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
104 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
105 duplicate_names.insert(field_name.clone());
106 arrays[idx] = field_dtype;
107 } else {
108 field_names.push(field_name.clone());
109 arrays.push(field_dtype);
110 }
111 }
112 }
113
114 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
115 vortex_bail!(
116 "merge: duplicate fields in children: {}",
117 duplicate_names.into_iter().format(", ")
118 )
119 }
120
121 Ok(DType::Struct(
122 StructFields::new(FieldNames::from(field_names), arrays),
123 merge_nullability,
124 ))
125 }
126
127 fn execute(
128 &self,
129 options: &Self::Options,
130 args: &dyn ExecutionArgs,
131 ctx: &mut ExecutionCtx,
132 ) -> VortexResult<ArrayRef> {
133 let mut field_names = Vec::new();
135 let mut arrays = Vec::new();
136 let mut duplicate_names = HashSet::<_>::new();
137
138 for i in 0..args.num_inputs() {
139 let array = args.get(i)?.execute::<StructArray>(ctx)?;
140 if array.dtype().is_nullable() {
141 vortex_bail!("merge expects non-nullable input");
142 }
143
144 for (field_name, field_array) in array
145 .names()
146 .iter()
147 .zip_eq(array.iter_unmasked_fields().cloned())
148 {
149 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
151 duplicate_names.insert(field_name.clone());
152 arrays[idx] = field_array;
153 } else {
154 field_names.push(field_name.clone());
155 arrays.push(field_array);
156 }
157 }
158 }
159
160 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
161 vortex_bail!(
162 "merge: duplicate fields in children: {}",
163 duplicate_names.into_iter().format(", ")
164 )
165 }
166
167 let validity = Validity::NonNullable;
169 let len = args.row_count();
170 Ok(
171 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
172 .into_array(),
173 )
174 }
175
176 fn reduce(
177 &self,
178 options: &Self::Options,
179 node: &dyn ReduceNode,
180 ctx: &dyn ReduceCtx,
181 ) -> VortexResult<Option<ReduceNodeRef>> {
182 let mut names = Vec::with_capacity(node.child_count() * 2);
183 let mut children = Vec::with_capacity(node.child_count() * 2);
184 let mut duplicate_names = HashSet::<_>::new();
185
186 for child in (0..node.child_count()).map(|i| node.child(i)) {
187 let child_dtype = child.node_dtype()?;
188 if !child_dtype.is_struct() {
189 vortex_bail!(
190 "Merge child must return a non-nullable struct dtype, got {}",
191 child_dtype
192 )
193 }
194
195 let child_dtype = child_dtype
196 .as_struct_fields_opt()
197 .vortex_expect("expected struct");
198
199 for name in child_dtype.names().iter() {
200 if let Some(idx) = names.iter().position(|n| n == name) {
201 duplicate_names.insert(name.clone());
202 children[idx] = Arc::clone(&child);
203 } else {
204 names.push(name.clone());
205 children.push(Arc::clone(&child));
206 }
207 }
208
209 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
210 vortex_bail!(
211 "merge: duplicate fields in children: {}",
212 duplicate_names.into_iter().format(", ")
213 )
214 }
215 }
216
217 let pack_children: Vec<_> = names
218 .iter()
219 .zip(children)
220 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
221 .try_collect()?;
222
223 let pack_expr = ctx.new_node(
224 Pack.bind(PackOptions {
225 names: FieldNames::from(names),
226 nullability: node.node_dtype()?.nullability(),
227 }),
228 &pack_children,
229 )?;
230
231 Ok(Some(pack_expr))
232 }
233
234 fn validity(
235 &self,
236 _options: &Self::Options,
237 _expression: &Expression,
238 ) -> VortexResult<Option<Expression>> {
239 Ok(Some(lit(true)))
240 }
241
242 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
243 true
244 }
245
246 fn is_fallible(&self, instance: &Self::Options) -> bool {
247 matches!(instance, DuplicateHandling::Error)
248 }
249}
250
/// Policy applied by [`Merge`] when the same field name appears in more than one input.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field from the right-most (last) input that defines it; the field stays at
    /// the position of its first appearance.
    RightMost,
    /// Fail with an error if any field name is duplicated. This is the default.
    #[default]
    Error,
}
260
261impl Display for DuplicateHandling {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 match self {
264 DuplicateHandling::RightMost => write!(f, "RightMost"),
265 DuplicateHandling::Error => write!(f, "Error"),
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::merge::StructArray;
    use crate::scalar_fn::fns::pack::Pack;

    /// Builds a struct array from `(name, child)` pairs, panicking if construction fails.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the leaf as a primitive array.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        #[expect(deprecated)]
        let mut current = array.to_struct().unmasked_field_by_name(head)?.clone();
        for segment in rest {
            #[expect(deprecated)]
            let child = current.to_struct().unmasked_field_by_name(segment)?.clone();
            current = child;
        }
        #[expect(deprecated)]
        let leaf = current.to_primitive();
        Ok(leaf)
    }

    #[test]
    pub fn test_merge_right_most() {
        let children = vec![
            get_item("0", root()),
            get_item("1", root()),
            get_item("2", root()),
        ];
        let expr = merge_opts(children, DuplicateHandling::RightMost);

        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ]),
            ),
            (
                "2",
                struct_of(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ]),
            ),
        ]);
        let merged = input.apply(&expr).unwrap();

        assert_eq!(
            merged.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // "b" must come from child "1" (value 2), not from child "0" (value 1).
        for (field, value) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&merged, &[field]).unwrap(),
                PrimitiveArray::from_iter([value; 3])
            );
        }
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );

        let left = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
        let right = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();
        let input = StructArray::try_from_iter([("0", left), ("1", right)])
            .unwrap()
            .into_array();

        expr.return_dtype(input.dtype()).unwrap();
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );

        let left = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
        let right = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();
        let input = StructArray::try_from_iter([("0", left), ("1", right)])
            .unwrap()
            .into_array();

        input.apply(&expr).unwrap();
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let input = struct_of(&[("a", buffer![0, 1, 2].into_array())]);
        let merged = input.clone().apply(&expr).unwrap();

        // With no children the result is an empty struct that keeps the row count.
        assert_eq!(merged.len(), input.len());
        assert_eq!(merged.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let left_inner = struct_of(&[
            ("x", buffer![0, 0, 0].into_array()),
            ("y", buffer![1, 1, 1].into_array()),
        ]);
        let right_inner = struct_of(&[("x", buffer![0, 0, 0].into_array())]);
        let input = struct_of(&[
            ("0", struct_of(&[("a", left_inner)])),
            ("1", struct_of(&[("a", right_inner)])),
        ]);

        #[expect(deprecated)]
        let merged = input.apply(&expr).unwrap().to_struct();

        // The right-most "a" replaces the left one wholesale, so "y" is gone.
        #[expect(deprecated)]
        let inner_struct = merged.unmasked_field_by_name("a").unwrap().to_struct();
        assert_eq!(
            inner_struct
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ]),
            ),
        ]);

        #[expect(deprecated)]
        let merged = input.apply(&expr).unwrap().to_struct();

        // Output order is order of appearance, not alphabetical.
        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    pub fn test_display() {
        let two_children = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(
            two_children.to_string(),
            "vortex.merge($.struct1, $.struct2, opts=Error)"
        );

        let one_child = merge(vec![get_item("a", root())]);
        assert_eq!(one_child.to_string(), "vortex.merge($.a, opts=Error)");
    }

    #[test]
    fn test_remove_merge() {
        let left = DType::struct_([("a", I32), ("b", I64)], NonNullable);
        let right = DType::struct_([("b", U32), ("c", U64)], NonNullable);
        let scope = DType::struct_([("0", left), ("1", right)], NonNullable);

        let expr = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let optimized = expr.optimize(&scope).unwrap();

        assert!(optimized.is::<Pack>());
        assert_eq!(
            optimized.return_dtype(&scope).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}