1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::IntoArray as _;
18use crate::arrays::StructArray;
19use crate::dtype::DType;
20use crate::dtype::FieldNames;
21use crate::dtype::Nullability;
22use crate::dtype::StructFields;
23use crate::expr::Expression;
24use crate::expr::lit;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::ExecutionArgs;
28use crate::scalar_fn::ReduceCtx;
29use crate::scalar_fn::ReduceNode;
30use crate::scalar_fn::ReduceNodeRef;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::fns::get_item::GetItem;
35use crate::scalar_fn::fns::pack::Pack;
36use crate::scalar_fn::fns::pack::PackOptions;
37use crate::validity::Validity;
38
/// Scalar function that merges the fields of zero or more struct-typed children into one
/// struct.
///
/// Fields appear in the order in which their names are first encountered. When the same name
/// occurs in more than one child, the [`DuplicateHandling`] option decides whether the field
/// from the right-most child wins or the merge fails.
///
/// Fields are replaced wholesale, not merged recursively: a duplicated nested struct field is
/// taken entirely from the later child.
#[derive(Clone)]
pub struct Merge;
47
48impl ScalarFnVTable for Merge {
49 type Options = DuplicateHandling;
50
51 fn id(&self) -> ScalarFnId {
52 ScalarFnId::new_ref("vortex.merge")
53 }
54
55 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56 Ok(Some(match instance {
57 DuplicateHandling::RightMost => vec![0x00],
58 DuplicateHandling::Error => vec![0x01],
59 }))
60 }
61
62 fn deserialize(
63 &self,
64 _metadata: &[u8],
65 _session: &VortexSession,
66 ) -> VortexResult<Self::Options> {
67 let instance = match _metadata {
68 [0x00] => DuplicateHandling::RightMost,
69 [0x01] => DuplicateHandling::Error,
70 _ => {
71 vortex_bail!("invalid metadata for Merge expression");
72 }
73 };
74 Ok(instance)
75 }
76
77 fn arity(&self, _options: &Self::Options) -> Arity {
78 Arity::Variadic { min: 0, max: None }
79 }
80
81 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
82 ChildName::from(Arc::from(format!("{}", child_idx)))
83 }
84
85 fn fmt_sql(
86 &self,
87 _options: &Self::Options,
88 expr: &Expression,
89 f: &mut Formatter<'_>,
90 ) -> std::fmt::Result {
91 write!(f, "merge(")?;
92 for (i, child) in expr.children().iter().enumerate() {
93 child.fmt_sql(f)?;
94 if i + 1 < expr.children().len() {
95 write!(f, ", ")?;
96 }
97 }
98 write!(f, ")")
99 }
100
101 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
102 let mut field_names = Vec::new();
103 let mut arrays = Vec::new();
104 let mut merge_nullability = Nullability::NonNullable;
105 let mut duplicate_names = HashSet::<_>::new();
106
107 for dtype in arg_dtypes {
108 let Some(fields) = dtype.as_struct_fields_opt() else {
109 vortex_bail!("merge expects struct input");
110 };
111 if dtype.is_nullable() {
112 vortex_bail!("merge expects non-nullable input");
113 }
114
115 merge_nullability |= dtype.nullability();
116
117 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
118 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
119 duplicate_names.insert(field_name.clone());
120 arrays[idx] = field_dtype;
121 } else {
122 field_names.push(field_name.clone());
123 arrays.push(field_dtype);
124 }
125 }
126 }
127
128 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
129 vortex_bail!(
130 "merge: duplicate fields in children: {}",
131 duplicate_names.into_iter().format(", ")
132 )
133 }
134
135 Ok(DType::Struct(
136 StructFields::new(FieldNames::from(field_names), arrays),
137 merge_nullability,
138 ))
139 }
140
141 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
142 let mut field_names = Vec::new();
144 let mut arrays = Vec::new();
145 let mut duplicate_names = HashSet::<_>::new();
146
147 for input in args.inputs {
148 let array = input.execute::<StructArray>(args.ctx)?;
149 if array.dtype().is_nullable() {
150 vortex_bail!("merge expects non-nullable input");
151 }
152
153 for (field_name, field_array) in array
154 .names()
155 .iter()
156 .zip_eq(array.unmasked_fields().iter().cloned())
157 {
158 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
160 duplicate_names.insert(field_name.clone());
161 arrays[idx] = field_array;
162 } else {
163 field_names.push(field_name.clone());
164 arrays.push(field_array);
165 }
166 }
167 }
168
169 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
170 vortex_bail!(
171 "merge: duplicate fields in children: {}",
172 duplicate_names.into_iter().format(", ")
173 )
174 }
175
176 let validity = Validity::NonNullable;
178 let len = args.row_count;
179 Ok(
180 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
181 .into_array(),
182 )
183 }
184
185 fn reduce(
186 &self,
187 options: &Self::Options,
188 node: &dyn ReduceNode,
189 ctx: &dyn ReduceCtx,
190 ) -> VortexResult<Option<ReduceNodeRef>> {
191 let mut names = Vec::with_capacity(node.child_count() * 2);
192 let mut children = Vec::with_capacity(node.child_count() * 2);
193 let mut duplicate_names = HashSet::<_>::new();
194
195 for child in (0..node.child_count()).map(|i| node.child(i)) {
196 let child_dtype = child.node_dtype()?;
197 if !child_dtype.is_struct() {
198 vortex_bail!(
199 "Merge child must return a non-nullable struct dtype, got {}",
200 child_dtype
201 )
202 }
203
204 let child_dtype = child_dtype
205 .as_struct_fields_opt()
206 .vortex_expect("expected struct");
207
208 for name in child_dtype.names().iter() {
209 if let Some(idx) = names.iter().position(|n| n == name) {
210 duplicate_names.insert(name.clone());
211 children[idx] = child.clone();
212 } else {
213 names.push(name.clone());
214 children.push(child.clone());
215 }
216 }
217
218 if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
219 vortex_bail!(
220 "merge: duplicate fields in children: {}",
221 duplicate_names.into_iter().format(", ")
222 )
223 }
224 }
225
226 let pack_children: Vec<_> = names
227 .iter()
228 .zip(children)
229 .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
230 .try_collect()?;
231
232 let pack_expr = ctx.new_node(
233 Pack.bind(PackOptions {
234 names: FieldNames::from(names),
235 nullability: node.node_dtype()?.nullability(),
236 }),
237 &pack_children,
238 )?;
239
240 Ok(Some(pack_expr))
241 }
242
243 fn validity(
244 &self,
245 _options: &Self::Options,
246 _expression: &Expression,
247 ) -> VortexResult<Option<Expression>> {
248 Ok(Some(lit(true)))
249 }
250
251 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
252 true
253 }
254
255 fn is_fallible(&self, instance: &Self::Options) -> bool {
256 matches!(instance, DuplicateHandling::Error)
257 }
258}
259
/// Strategy applied when two merged structs contain a field with the same name.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field from the right-most struct that defines it.
    RightMost,
    /// Reject the merge with an error. This is the default strategy.
    #[default]
    Error,
}
269
270impl Display for DuplicateHandling {
271 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
272 match self {
273 DuplicateHandling::RightMost => write!(f, "RightMost"),
274 DuplicateHandling::Error => write!(f, "Error"),
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::pack::Pack;

    /// Walks `field_path` through nested struct arrays and returns the final field as a
    /// primitive array. Fails on an empty path or an unknown field name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(first)?.clone();
        for field in rest {
            current = current.to_struct().unmasked_field_by_name(field)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// `merge($.0, $.1)` configured to fail when the children share a field name.
    fn erroring_merge_of_two() -> Expression {
        merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        )
    }

    /// Builds `{0: {a, b}, 1: {c, b}}`: two single-row children that both define field `b`.
    fn overlapping_children() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    /// With `RightMost`, a duplicated field keeps its first-seen position but takes the value
    /// from the last child that defines it.
    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // `b` must come from child "1" (value 2), not child "0" (value 1).
        for (name, value) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&actual_array, &[name]).unwrap(),
                PrimitiveArray::from_iter([value, value, value])
            );
        }
    }

    /// Duplicate field names are rejected while computing the return dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = erroring_merge_of_two();
        let test_array = overlapping_children();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// Duplicate field names are rejected when the expression is evaluated.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = erroring_merge_of_two();
        let test_array = overlapping_children();

        test_array.apply(&expr).unwrap();
    }

    /// Merging no children yields a struct with zero fields and the input's row count.
    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced, not merged recursively: the right-most `a` (only `x`)
    /// wins outright over the earlier `a` (`x` and `y`).
    #[test]
    fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(
            actual_array
                .unmasked_field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order and then field order within each child; they are not
    /// sorted by name.
    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The expression renders as `merge(...)` with comma-separated children.
    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing a merge rewrites it into a `Pack` whose dtype reflects right-most-wins.
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}