1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray as _;
20use crate::arrays::StructArray;
21use crate::arrays::struct_::StructArrayExt;
22use crate::dtype::DType;
23use crate::dtype::FieldNames;
24use crate::dtype::Nullability;
25use crate::dtype::StructFields;
26use crate::expr::Expression;
27use crate::expr::lit;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ReduceCtx;
32use crate::scalar_fn::ReduceNode;
33use crate::scalar_fn::ReduceNodeRef;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::get_item::GetItem;
38use crate::scalar_fn::fns::pack::Pack;
39use crate::scalar_fn::fns::pack::PackOptions;
40use crate::validity::Validity;
41
/// Scalar function that merges the fields of several struct-typed children into one struct.
///
/// Every child must be a non-nullable struct. Output fields appear in the order their names
/// are first seen across the children (left to right). How a field name supplied by more than
/// one child is treated is controlled by [`DuplicateHandling`]. The merge is shallow: a
/// repeated field is replaced as a whole rather than merged recursively.
#[derive(Clone)]
pub struct Merge;
50
51impl ScalarFnVTable for Merge {
52 type Options = DuplicateHandling;
53
54 fn id(&self) -> ScalarFnId {
55 static ID: CachedId = CachedId::new("vortex.merge");
56 *ID
57 }
58
59 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60 Ok(Some(match instance {
61 DuplicateHandling::RightMost => vec![0x00],
62 DuplicateHandling::Error => vec![0x01],
63 }))
64 }
65
66 fn deserialize(
67 &self,
68 _metadata: &[u8],
69 _session: &VortexSession,
70 ) -> VortexResult<Self::Options> {
71 let instance = match _metadata {
72 [0x00] => DuplicateHandling::RightMost,
73 [0x01] => DuplicateHandling::Error,
74 _ => {
75 vortex_bail!("invalid metadata for Merge expression");
76 }
77 };
78 Ok(instance)
79 }
80
81 fn arity(&self, _options: &Self::Options) -> Arity {
82 Arity::Variadic { min: 0, max: None }
83 }
84
85 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
86 ChildName::from(Arc::from(format!("{}", child_idx)))
87 }
88
89 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
90 let mut field_names = Vec::new();
91 let mut arrays = Vec::new();
92 let mut merge_nullability = Nullability::NonNullable;
93 let mut duplicate_names = HashSet::<_>::new();
94
95 for dtype in arg_dtypes {
96 let Some(fields) = dtype.as_struct_fields_opt() else {
97 vortex_bail!("merge expects struct input");
98 };
99 if dtype.is_nullable() {
100 vortex_bail!("merge expects non-nullable input");
101 }
102
103 merge_nullability |= dtype.nullability();
104
105 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
106 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
107 duplicate_names.insert(field_name.clone());
108 arrays[idx] = field_dtype;
109 } else {
110 field_names.push(field_name.clone());
111 arrays.push(field_dtype);
112 }
113 }
114 }
115
116 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
117 vortex_bail!(
118 "merge: duplicate fields in children: {}",
119 duplicate_names.into_iter().format(", ")
120 )
121 }
122
123 Ok(DType::Struct(
124 StructFields::new(FieldNames::from(field_names), arrays),
125 merge_nullability,
126 ))
127 }
128
129 fn execute(
130 &self,
131 options: &Self::Options,
132 args: &dyn ExecutionArgs,
133 ctx: &mut ExecutionCtx,
134 ) -> VortexResult<ArrayRef> {
135 let mut field_names = Vec::new();
137 let mut arrays = Vec::new();
138 let mut duplicate_names = HashSet::<_>::new();
139
140 for i in 0..args.num_inputs() {
141 let array = args.get(i)?.execute::<StructArray>(ctx)?;
142 if array.dtype().is_nullable() {
143 vortex_bail!("merge expects non-nullable input");
144 }
145
146 for (field_name, field_array) in array
147 .names()
148 .iter()
149 .zip_eq(array.iter_unmasked_fields().cloned())
150 {
151 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
153 duplicate_names.insert(field_name.clone());
154 arrays[idx] = field_array;
155 } else {
156 field_names.push(field_name.clone());
157 arrays.push(field_array);
158 }
159 }
160 }
161
162 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
163 vortex_bail!(
164 "merge: duplicate fields in children: {}",
165 duplicate_names.into_iter().format(", ")
166 )
167 }
168
169 let validity = Validity::NonNullable;
171 let len = args.row_count();
172 Ok(
173 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
174 .into_array(),
175 )
176 }
177
178 fn reduce(
179 &self,
180 options: &Self::Options,
181 node: &dyn ReduceNode,
182 ctx: &dyn ReduceCtx,
183 ) -> VortexResult<Option<ReduceNodeRef>> {
184 let mut names = Vec::with_capacity(node.child_count() * 2);
185 let mut children = Vec::with_capacity(node.child_count() * 2);
186 let mut duplicate_names = HashSet::<_>::new();
187
188 for child in (0..node.child_count()).map(|i| node.child(i)) {
189 let child_dtype = child.node_dtype()?;
190 if !child_dtype.is_struct() {
191 vortex_bail!(
192 "Merge child must return a non-nullable struct dtype, got {}",
193 child_dtype
194 )
195 }
196
197 let child_dtype = child_dtype
198 .as_struct_fields_opt()
199 .vortex_expect("expected struct");
200
201 for name in child_dtype.names().iter() {
202 if let Some(idx) = names.iter().position(|n| n == name) {
203 duplicate_names.insert(name.clone());
204 children[idx] = Arc::clone(&child);
205 } else {
206 names.push(name.clone());
207 children.push(Arc::clone(&child));
208 }
209 }
210
211 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
212 vortex_bail!(
213 "merge: duplicate fields in children: {}",
214 duplicate_names.into_iter().format(", ")
215 )
216 }
217 }
218
219 let pack_children: Vec<_> = names
220 .iter()
221 .zip(children)
222 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
223 .try_collect()?;
224
225 let pack_expr = ctx.new_node(
226 Pack.bind(PackOptions {
227 names: FieldNames::from(names),
228 nullability: node.node_dtype()?.nullability(),
229 }),
230 &pack_children,
231 )?;
232
233 Ok(Some(pack_expr))
234 }
235
236 fn validity(
237 &self,
238 _options: &Self::Options,
239 _expression: &Expression,
240 ) -> VortexResult<Option<Expression>> {
241 Ok(Some(lit(true)))
242 }
243
244 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
245 true
246 }
247
248 fn is_fallible(&self, instance: &Self::Options) -> bool {
249 matches!(instance, DuplicateHandling::Error)
250 }
251}
252
/// How [`Merge`] treats a field name that is supplied by more than one child.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// The value from the right-most child supplying the name wins; the field keeps the
    /// position at which the name first appeared.
    RightMost,
    /// Any repeated field name is an error. This is the default.
    #[default]
    Error,
}

impl Display for DuplicateHandling {
    /// Writes the bare variant name: `RightMost` or `Error`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::RightMost => "RightMost",
            Self::Error => "Error",
        };
        f.write_str(label)
    }
}
271
// Unit tests for the `Merge` scalar function: duplicate handling, field ordering, empty and
// nested merges, display formatting, and the optimizer rewrite into `Pack`.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::merge::StructArray;
    use crate::scalar_fn::fns::pack::Pack;

    /// Follows `field_path` through nested struct fields of `array` and returns the leaf field
    /// converted to a primitive array.
    ///
    /// Fails on an empty path, and propagates any lookup error from `unmasked_field_by_name`.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        #[expect(deprecated)]
        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
        for field in field_path {
            #[expect(deprecated)]
            let next = array.to_struct().unmasked_field_by_name(field)?.clone();
            array = next;
        }
        #[expect(deprecated)]
        let result = array.to_primitive();
        Ok(result)
    }

    // With `RightMost`, the duplicated field `b` keeps its first position (after `a`) but takes
    // its values from the later child "1".
    #[test]
    pub fn test_merge_right_most() {
        let mut ctx = array_session().create_execution_ctx();
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_arrays_eq!(
            primitive_field(&actual_array, &["a"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 0, 0]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["b"]).unwrap(),
            PrimitiveArray::from_iter([2i32, 2, 2]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["c"]).unwrap(),
            PrimitiveArray::from_iter([3i32, 3, 3]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["d"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 4, 4]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["e"]).unwrap(),
            PrimitiveArray::from_iter([5i32, 5, 5]),
            &mut ctx
        );
    }

    // With `Error`, a name shared by two children ("b") is rejected when computing the dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    // Same duplicate input as above, but the error must also surface when the expression is
    // applied to the array.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        test_array.apply(&expr).unwrap();
    }

    // A merge with no children yields a zero-field struct that still has the input's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    // The merge is shallow: child "1"'s nested struct `a` (fields: x) replaces child "0"'s `a`
    // (fields: x, y) as a whole instead of being merged field by field.
    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        #[expect(deprecated)]
        let inner_struct = actual_array
            .unmasked_field_by_name("a")
            .unwrap()
            .to_struct();
        assert_eq!(
            inner_struct
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    // Output fields are concatenated child by child in their original order; they are not
    // sorted by name.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    // `merge(...)` uses the default options, which display as `opts=Error`.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(
            expr.to_string(),
            "vortex.merge($.struct1, $.struct2, opts=Error)"
        );

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "vortex.merge($.a, opts=Error)");
    }

    // Optimizing a merge rewrites it into a `Pack`. The rewritten expression must report the
    // same right-most-wins dtype: `b` is U32 from child "1", not I64 from child "0".
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}