Skip to main content

vortex_array/scalar_fn/fns/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::StructArrayExt;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::get_item::GetItem;
37use crate::scalar_fn::fns::pack::Pack;
38use crate::scalar_fn::fns::pack::PackOptions;
39use crate::validity::Validity;
40
/// Scalar function that merges zero or more expressions that ALL return non-nullable structs.
///
/// Fields are collected in order of first appearance across the children. What happens when
/// two children share a field name is controlled by [`DuplicateHandling`]: with
/// [`DuplicateHandling::RightMost`] the field from the later expression wins (it takes over the
/// position of the earlier field), while with [`DuplicateHandling::Error`] the merge fails.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[derive(Clone)]
pub struct Merge;
49
50impl ScalarFnVTable for Merge {
51    type Options = DuplicateHandling;
52
53    fn id(&self) -> ScalarFnId {
54        ScalarFnId::from("vortex.merge")
55    }
56
57    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
58        Ok(Some(match instance {
59            DuplicateHandling::RightMost => vec![0x00],
60            DuplicateHandling::Error => vec![0x01],
61        }))
62    }
63
64    fn deserialize(
65        &self,
66        _metadata: &[u8],
67        _session: &VortexSession,
68    ) -> VortexResult<Self::Options> {
69        let instance = match _metadata {
70            [0x00] => DuplicateHandling::RightMost,
71            [0x01] => DuplicateHandling::Error,
72            _ => {
73                vortex_bail!("invalid metadata for Merge expression");
74            }
75        };
76        Ok(instance)
77    }
78
79    fn arity(&self, _options: &Self::Options) -> Arity {
80        Arity::Variadic { min: 0, max: None }
81    }
82
83    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
84        ChildName::from(Arc::from(format!("{}", child_idx)))
85    }
86
87    fn fmt_sql(
88        &self,
89        _options: &Self::Options,
90        expr: &Expression,
91        f: &mut Formatter<'_>,
92    ) -> std::fmt::Result {
93        write!(f, "merge(")?;
94        for (i, child) in expr.children().iter().enumerate() {
95            child.fmt_sql(f)?;
96            if i + 1 < expr.children().len() {
97                write!(f, ", ")?;
98            }
99        }
100        write!(f, ")")
101    }
102
103    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
104        let mut field_names = Vec::new();
105        let mut arrays = Vec::new();
106        let mut merge_nullability = Nullability::NonNullable;
107        let mut duplicate_names = HashSet::<_>::new();
108
109        for dtype in arg_dtypes {
110            let Some(fields) = dtype.as_struct_fields_opt() else {
111                vortex_bail!("merge expects struct input");
112            };
113            if dtype.is_nullable() {
114                vortex_bail!("merge expects non-nullable input");
115            }
116
117            merge_nullability |= dtype.nullability();
118
119            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
120                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
121                    duplicate_names.insert(field_name.clone());
122                    arrays[idx] = field_dtype;
123                } else {
124                    field_names.push(field_name.clone());
125                    arrays.push(field_dtype);
126                }
127            }
128        }
129
130        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
131            vortex_bail!(
132                "merge: duplicate fields in children: {}",
133                duplicate_names.into_iter().format(", ")
134            )
135        }
136
137        Ok(DType::Struct(
138            StructFields::new(FieldNames::from(field_names), arrays),
139            merge_nullability,
140        ))
141    }
142
143    fn execute(
144        &self,
145        options: &Self::Options,
146        args: &dyn ExecutionArgs,
147        ctx: &mut ExecutionCtx,
148    ) -> VortexResult<ArrayRef> {
149        // Collect fields in order of appearance. Later fields overwrite earlier fields.
150        let mut field_names = Vec::new();
151        let mut arrays = Vec::new();
152        let mut duplicate_names = HashSet::<_>::new();
153
154        for i in 0..args.num_inputs() {
155            let array = args.get(i)?.execute::<StructArray>(ctx)?;
156            if array.dtype().is_nullable() {
157                vortex_bail!("merge expects non-nullable input");
158            }
159
160            for (field_name, field_array) in array
161                .names()
162                .iter()
163                .zip_eq(array.iter_unmasked_fields().cloned())
164            {
165                // Update or insert field.
166                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
167                    duplicate_names.insert(field_name.clone());
168                    arrays[idx] = field_array;
169                } else {
170                    field_names.push(field_name.clone());
171                    arrays.push(field_array);
172                }
173            }
174        }
175
176        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
177            vortex_bail!(
178                "merge: duplicate fields in children: {}",
179                duplicate_names.into_iter().format(", ")
180            )
181        }
182
183        // TODO(DK): When children are allowed to be nullable, this needs to change.
184        let validity = Validity::NonNullable;
185        let len = args.row_count();
186        Ok(
187            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
188                .into_array(),
189        )
190    }
191
192    fn reduce(
193        &self,
194        options: &Self::Options,
195        node: &dyn ReduceNode,
196        ctx: &dyn ReduceCtx,
197    ) -> VortexResult<Option<ReduceNodeRef>> {
198        let mut names = Vec::with_capacity(node.child_count() * 2);
199        let mut children = Vec::with_capacity(node.child_count() * 2);
200        let mut duplicate_names = HashSet::<_>::new();
201
202        for child in (0..node.child_count()).map(|i| node.child(i)) {
203            let child_dtype = child.node_dtype()?;
204            if !child_dtype.is_struct() {
205                vortex_bail!(
206                    "Merge child must return a non-nullable struct dtype, got {}",
207                    child_dtype
208                )
209            }
210
211            let child_dtype = child_dtype
212                .as_struct_fields_opt()
213                .vortex_expect("expected struct");
214
215            for name in child_dtype.names().iter() {
216                if let Some(idx) = names.iter().position(|n| n == name) {
217                    duplicate_names.insert(name.clone());
218                    children[idx] = Arc::clone(&child);
219                } else {
220                    names.push(name.clone());
221                    children.push(Arc::clone(&child));
222                }
223            }
224
225            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
226                vortex_bail!(
227                    "merge: duplicate fields in children: {}",
228                    duplicate_names.into_iter().format(", ")
229                )
230            }
231        }
232
233        let pack_children: Vec<_> = names
234            .iter()
235            .zip(children)
236            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237            .try_collect()?;
238
239        let pack_expr = ctx.new_node(
240            Pack.bind(PackOptions {
241                names: FieldNames::from(names),
242                nullability: node.node_dtype()?.nullability(),
243            }),
244            &pack_children,
245        )?;
246
247        Ok(Some(pack_expr))
248    }
249
250    fn validity(
251        &self,
252        _options: &Self::Options,
253        _expression: &Expression,
254    ) -> VortexResult<Option<Expression>> {
255        Ok(Some(lit(true)))
256    }
257
    /// Always reports `true`, regardless of the configured options.
    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
        true
    }
261
262    fn is_fallible(&self, instance: &Self::Options) -> bool {
263        matches!(instance, DuplicateHandling::Error)
264    }
265}
266
/// Policy applied when merged structs share a field name.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most struct that defines the shared field.
    RightMost,
    /// Treat a shared field name as an error. This is the default policy.
    #[default]
    Error,
}

/// Renders the policy as its variant name.
impl Display for DuplicateHandling {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            Self::RightMost => "RightMost",
            Self::Error => "Error",
        };
        f.write_str(variant_name)
    }
}
285
286#[cfg(test)]
287mod tests {
288    use vortex_buffer::buffer;
289    use vortex_error::VortexResult;
290    use vortex_error::vortex_bail;
291
292    use crate::ArrayRef;
293    use crate::IntoArray;
294    #[expect(deprecated)]
295    use crate::ToCanonical as _;
296    use crate::arrays::PrimitiveArray;
297    use crate::arrays::struct_::StructArrayExt;
298    use crate::assert_arrays_eq;
299    use crate::dtype::DType;
300    use crate::dtype::Nullability::NonNullable;
301    use crate::dtype::PType::I32;
302    use crate::dtype::PType::I64;
303    use crate::dtype::PType::U32;
304    use crate::dtype::PType::U64;
305    use crate::expr::Expression;
306    use crate::expr::get_item;
307    use crate::expr::merge;
308    use crate::expr::merge_opts;
309    use crate::expr::root;
310    use crate::scalar_fn::fns::merge::DuplicateHandling;
311    use crate::scalar_fn::fns::merge::StructArray;
312    use crate::scalar_fn::fns::pack::Pack;
313
314    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
315        let mut field_path = field_path.iter();
316
317        let Some(field) = field_path.next() else {
318            vortex_bail!("empty field path");
319        };
320
321        #[expect(deprecated)]
322        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
323        for field in field_path {
324            #[expect(deprecated)]
325            let next = array.to_struct().unmasked_field_by_name(field)?.clone();
326            array = next;
327        }
328        #[expect(deprecated)]
329        let result = array.to_primitive();
330        Ok(result)
331    }
332
333    #[test]
334    pub fn test_merge_right_most() {
335        let expr = merge_opts(
336            vec![
337                get_item("0", root()),
338                get_item("1", root()),
339                get_item("2", root()),
340            ],
341            DuplicateHandling::RightMost,
342        );
343
344        let test_array = StructArray::from_fields(&[
345            (
346                "0",
347                StructArray::from_fields(&[
348                    ("a", buffer![0, 0, 0].into_array()),
349                    ("b", buffer![1, 1, 1].into_array()),
350                ])
351                .unwrap()
352                .into_array(),
353            ),
354            (
355                "1",
356                StructArray::from_fields(&[
357                    ("b", buffer![2, 2, 2].into_array()),
358                    ("c", buffer![3, 3, 3].into_array()),
359                ])
360                .unwrap()
361                .into_array(),
362            ),
363            (
364                "2",
365                StructArray::from_fields(&[
366                    ("d", buffer![4, 4, 4].into_array()),
367                    ("e", buffer![5, 5, 5].into_array()),
368                ])
369                .unwrap()
370                .into_array(),
371            ),
372        ])
373        .unwrap()
374        .into_array();
375        let actual_array = test_array.apply(&expr).unwrap();
376
377        assert_eq!(
378            actual_array.as_struct_typed().names(),
379            ["a", "b", "c", "d", "e"]
380        );
381
382        assert_arrays_eq!(
383            primitive_field(&actual_array, &["a"]).unwrap(),
384            PrimitiveArray::from_iter([0i32, 0, 0])
385        );
386        assert_arrays_eq!(
387            primitive_field(&actual_array, &["b"]).unwrap(),
388            PrimitiveArray::from_iter([2i32, 2, 2])
389        );
390        assert_arrays_eq!(
391            primitive_field(&actual_array, &["c"]).unwrap(),
392            PrimitiveArray::from_iter([3i32, 3, 3])
393        );
394        assert_arrays_eq!(
395            primitive_field(&actual_array, &["d"]).unwrap(),
396            PrimitiveArray::from_iter([4i32, 4, 4])
397        );
398        assert_arrays_eq!(
399            primitive_field(&actual_array, &["e"]).unwrap(),
400            PrimitiveArray::from_iter([5i32, 5, 5])
401        );
402    }
403
404    #[test]
405    #[should_panic(expected = "merge: duplicate fields in children")]
406    pub fn test_merge_error_on_dupe_return_dtype() {
407        let expr = merge_opts(
408            vec![get_item("0", root()), get_item("1", root())],
409            DuplicateHandling::Error,
410        );
411        let test_array = StructArray::try_from_iter([
412            (
413                "0",
414                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
415            ),
416            (
417                "1",
418                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
419            ),
420        ])
421        .unwrap()
422        .into_array();
423
424        expr.return_dtype(test_array.dtype()).unwrap();
425    }
426
427    #[test]
428    #[should_panic(expected = "merge: duplicate fields in children")]
429    pub fn test_merge_error_on_dupe_evaluate() {
430        let expr = merge_opts(
431            vec![get_item("0", root()), get_item("1", root())],
432            DuplicateHandling::Error,
433        );
434        let test_array = StructArray::try_from_iter([
435            (
436                "0",
437                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
438            ),
439            (
440                "1",
441                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
442            ),
443        ])
444        .unwrap()
445        .into_array();
446
447        test_array.apply(&expr).unwrap();
448    }
449
450    #[test]
451    pub fn test_empty_merge() {
452        let expr = merge(Vec::<Expression>::new());
453
454        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
455            .unwrap()
456            .into_array();
457        let actual_array = test_array.clone().apply(&expr).unwrap();
458        assert_eq!(actual_array.len(), test_array.len());
459        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
460    }
461
462    #[test]
463    pub fn test_nested_merge() {
464        // Nested structs are not merged!
465
466        let expr = merge_opts(
467            vec![get_item("0", root()), get_item("1", root())],
468            DuplicateHandling::RightMost,
469        );
470
471        let test_array = StructArray::from_fields(&[
472            (
473                "0",
474                StructArray::from_fields(&[(
475                    "a",
476                    StructArray::from_fields(&[
477                        ("x", buffer![0, 0, 0].into_array()),
478                        ("y", buffer![1, 1, 1].into_array()),
479                    ])
480                    .unwrap()
481                    .into_array(),
482                )])
483                .unwrap()
484                .into_array(),
485            ),
486            (
487                "1",
488                StructArray::from_fields(&[(
489                    "a",
490                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
491                        .unwrap()
492                        .into_array(),
493                )])
494                .unwrap()
495                .into_array(),
496            ),
497        ])
498        .unwrap()
499        .into_array();
500        #[expect(deprecated)]
501        let actual_array = test_array.apply(&expr).unwrap().to_struct();
502
503        #[expect(deprecated)]
504        let inner_struct = actual_array
505            .unmasked_field_by_name("a")
506            .unwrap()
507            .to_struct();
508        assert_eq!(
509            inner_struct
510                .names()
511                .iter()
512                .map(|name| name.as_ref())
513                .collect::<Vec<_>>(),
514            vec!["x"]
515        );
516    }
517
518    #[test]
519    pub fn test_merge_order() {
520        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);
521
522        let test_array = StructArray::from_fields(&[
523            (
524                "0",
525                StructArray::from_fields(&[
526                    ("a", buffer![0, 0, 0].into_array()),
527                    ("c", buffer![1, 1, 1].into_array()),
528                ])
529                .unwrap()
530                .into_array(),
531            ),
532            (
533                "1",
534                StructArray::from_fields(&[
535                    ("b", buffer![2, 2, 2].into_array()),
536                    ("d", buffer![3, 3, 3].into_array()),
537                ])
538                .unwrap()
539                .into_array(),
540            ),
541        ])
542        .unwrap()
543        .into_array();
544        #[expect(deprecated)]
545        let actual_array = test_array.apply(&expr).unwrap().to_struct();
546
547        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
548    }
549
550    #[test]
551    pub fn test_display() {
552        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
553        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");
554
555        let expr2 = merge(vec![get_item("a", root())]);
556        assert_eq!(expr2.to_string(), "merge($.a)");
557    }
558
559    #[test]
560    fn test_remove_merge() {
561        let dtype = DType::struct_(
562            [
563                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
564                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
565            ],
566            NonNullable,
567        );
568
569        let e = merge_opts(
570            [get_item("0", root()), get_item("1", root())],
571            DuplicateHandling::RightMost,
572        );
573
574        let result = e.optimize(&dtype).unwrap();
575
576        assert!(result.is::<Pack>());
577        assert_eq!(
578            result.return_dtype(&dtype).unwrap(),
579            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
580        );
581    }
582}