Skip to main content

vortex_array/expr/exprs/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_session::VortexSession;
18use vortex_utils::aliases::hash_set::HashSet;
19
20use crate::ArrayRef;
21use crate::IntoArray as _;
22use crate::arrays::StructArray;
23use crate::expr::Arity;
24use crate::expr::ChildName;
25use crate::expr::ExecutionArgs;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::GetItem;
29use crate::expr::Pack;
30use crate::expr::PackOptions;
31use crate::expr::ReduceCtx;
32use crate::expr::ReduceNode;
33use crate::expr::ReduceNodeRef;
34use crate::expr::VTable;
35use crate::expr::VTableExt;
36use crate::expr::lit;
37use crate::validity::Validity;
38
/// Merge zero or more expressions that ALL return structs.
///
/// Output fields are ordered by first appearance across the inputs. How a field name that
/// occurs in more than one input is treated depends on the [`DuplicateHandling`] option:
/// with [`DuplicateHandling::RightMost`] the field from the later expression wins, while with
/// [`DuplicateHandling::Error`] (the default) the merge fails.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
pub struct Merge;
46
47impl VTable for Merge {
48    type Options = DuplicateHandling;
49
50    fn id(&self) -> ExprId {
51        ExprId::new_ref("vortex.merge")
52    }
53
54    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
55        Ok(Some(match instance {
56            DuplicateHandling::RightMost => vec![0x00],
57            DuplicateHandling::Error => vec![0x01],
58        }))
59    }
60
61    fn deserialize(
62        &self,
63        _metadata: &[u8],
64        _session: &VortexSession,
65    ) -> VortexResult<Self::Options> {
66        let instance = match _metadata {
67            [0x00] => DuplicateHandling::RightMost,
68            [0x01] => DuplicateHandling::Error,
69            _ => {
70                vortex_bail!("invalid metadata for Merge expression");
71            }
72        };
73        Ok(instance)
74    }
75
76    fn arity(&self, _options: &Self::Options) -> Arity {
77        Arity::Variadic { min: 0, max: None }
78    }
79
80    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
81        ChildName::from(Arc::from(format!("{}", child_idx)))
82    }
83
84    fn fmt_sql(
85        &self,
86        _options: &Self::Options,
87        expr: &Expression,
88        f: &mut Formatter<'_>,
89    ) -> std::fmt::Result {
90        write!(f, "merge(")?;
91        for (i, child) in expr.children().iter().enumerate() {
92            child.fmt_sql(f)?;
93            if i + 1 < expr.children().len() {
94                write!(f, ", ")?;
95            }
96        }
97        write!(f, ")")
98    }
99
100    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
101        let mut field_names = Vec::new();
102        let mut arrays = Vec::new();
103        let mut merge_nullability = Nullability::NonNullable;
104        let mut duplicate_names = HashSet::<_>::new();
105
106        for dtype in arg_dtypes {
107            let Some(fields) = dtype.as_struct_fields_opt() else {
108                vortex_bail!("merge expects struct input");
109            };
110            if dtype.is_nullable() {
111                vortex_bail!("merge expects non-nullable input");
112            }
113
114            merge_nullability |= dtype.nullability();
115
116            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
117                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
118                    duplicate_names.insert(field_name.clone());
119                    arrays[idx] = field_dtype;
120                } else {
121                    field_names.push(field_name.clone());
122                    arrays.push(field_dtype);
123                }
124            }
125        }
126
127        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
128            vortex_bail!(
129                "merge: duplicate fields in children: {}",
130                duplicate_names.into_iter().format(", ")
131            )
132        }
133
134        Ok(DType::Struct(
135            StructFields::new(FieldNames::from(field_names), arrays),
136            merge_nullability,
137        ))
138    }
139
140    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
141        // Collect fields in order of appearance. Later fields overwrite earlier fields.
142        let mut field_names = Vec::new();
143        let mut arrays = Vec::new();
144        let mut duplicate_names = HashSet::<_>::new();
145
146        for input in args.inputs {
147            let array = input.execute::<StructArray>(args.ctx)?;
148            if array.dtype().is_nullable() {
149                vortex_bail!("merge expects non-nullable input");
150            }
151
152            for (field_name, field_array) in array
153                .names()
154                .iter()
155                .zip_eq(array.unmasked_fields().iter().cloned())
156            {
157                // Update or insert field.
158                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
159                    duplicate_names.insert(field_name.clone());
160                    arrays[idx] = field_array;
161                } else {
162                    field_names.push(field_name.clone());
163                    arrays.push(field_array);
164                }
165            }
166        }
167
168        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
169            vortex_bail!(
170                "merge: duplicate fields in children: {}",
171                duplicate_names.into_iter().format(", ")
172            )
173        }
174
175        // TODO(DK): When children are allowed to be nullable, this needs to change.
176        let validity = Validity::NonNullable;
177        let len = args.row_count;
178        Ok(
179            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
180                .into_array(),
181        )
182    }
183
184    fn reduce(
185        &self,
186        options: &Self::Options,
187        node: &dyn ReduceNode,
188        ctx: &dyn ReduceCtx,
189    ) -> VortexResult<Option<ReduceNodeRef>> {
190        let mut names = Vec::with_capacity(node.child_count() * 2);
191        let mut children = Vec::with_capacity(node.child_count() * 2);
192        let mut duplicate_names = HashSet::<_>::new();
193
194        for child in (0..node.child_count()).map(|i| node.child(i)) {
195            let child_dtype = child.node_dtype()?;
196            if !child_dtype.is_struct() {
197                vortex_bail!(
198                    "Merge child must return a non-nullable struct dtype, got {}",
199                    child_dtype
200                )
201            }
202
203            let child_dtype = child_dtype
204                .as_struct_fields_opt()
205                .vortex_expect("expected struct");
206
207            for name in child_dtype.names().iter() {
208                if let Some(idx) = names.iter().position(|n| n == name) {
209                    duplicate_names.insert(name.clone());
210                    children[idx] = child.clone();
211                } else {
212                    names.push(name.clone());
213                    children.push(child.clone());
214                }
215            }
216
217            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
218                vortex_bail!(
219                    "merge: duplicate fields in children: {}",
220                    duplicate_names.into_iter().format(", ")
221                )
222            }
223        }
224
225        let pack_children: Vec<_> = names
226            .iter()
227            .zip(children)
228            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
229            .try_collect()?;
230
231        let pack_expr = ctx.new_node(
232            Pack.bind(PackOptions {
233                names: FieldNames::from(names),
234                nullability: node.node_dtype()?.nullability(),
235            }),
236            &pack_children,
237        )?;
238
239        Ok(Some(pack_expr))
240    }
241
242    fn validity(
243        &self,
244        _options: &Self::Options,
245        _expression: &Expression,
246    ) -> VortexResult<Option<Expression>> {
247        Ok(Some(lit(true)))
248    }
249
250    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
251        true
252    }
253
254    fn is_fallible(&self, instance: &Self::Options) -> bool {
255        matches!(instance, DuplicateHandling::Error)
256    }
257}
258
/// Policy applied when two merged structs contain a field with the same name.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most struct that defines the shared field.
    RightMost,
    /// Treat a shared field name as an error. This is the default policy.
    #[default]
    Error,
}

impl Display for DuplicateHandling {
    /// Writes the variant name (`RightMost` or `Error`) verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::RightMost => "RightMost",
            Self::Error => "Error",
        };
        f.write_str(label)
    }
}
277
278/// Creates an expression that merges struct expressions into a single struct.
279///
280/// Combines fields from all input expressions. If field names are duplicated,
281/// later expressions win. Fields are not recursively merged.
282///
283/// ```rust
284/// # use vortex_dtype::Nullability;
285/// # use vortex_array::expr::{merge, get_item, root};
286/// let expr = merge([get_item("a", root()), get_item("b", root())]);
287/// ```
288pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
289    let values = elements.into_iter().map(|value| value.into()).collect_vec();
290    Merge.new_expr(DuplicateHandling::default(), values)
291}
292
293pub fn merge_opts(
294    elements: impl IntoIterator<Item = impl Into<Expression>>,
295    duplicate_handling: DuplicateHandling,
296) -> Expression {
297    let values = elements.into_iter().map(|value| value.into()).collect_vec();
298    Merge.new_expr(duplicate_handling, values)
299}
300
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PType::I64;
    use vortex_dtype::PType::U32;
    use vortex_dtype::PType::U64;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::merge;
    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::expr::Expression;
    use crate::expr::Pack;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Follows `field_path` through nested struct arrays and returns the final field as a
    /// primitive array. An empty path is an error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(first)?.clone();
        for segment in rest {
            current = current.to_struct().unmasked_field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Merge of the root fields `"0"` and `"1"` under the given duplicate policy.
    fn merge_zero_and_one(policy: DuplicateHandling) -> Expression {
        merge_opts(vec![get_item("0", root()), get_item("1", root())], policy)
    }

    /// Root struct whose children `"0"` = {a, b} and `"1"` = {c, b} both define field `b`.
    fn structs_sharing_b() -> StructArray {
        let lhs = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
        let rhs = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();
        StructArray::try_from_iter([("0", lhs), ("1", rhs)]).unwrap()
    }

    #[test]
    pub fn test_merge_right_most() {
        let first = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("b", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let second = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("c", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let third = StructArray::from_fields(&[
            ("d", buffer![4, 4, 4].into_array()),
            ("e", buffer![5, 5, 5].into_array()),
        ])
        .unwrap()
        .into_array();

        let test_array = StructArray::from_fields(&[("0", first), ("1", second), ("2", third)])
            .unwrap()
            .into_array();

        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // `b` is defined by both "0" and "1"; the right-most definition (all 2s) must win.
        for (name, value) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&actual_array, &[name]).unwrap(),
                PrimitiveArray::from_iter([value, value, value])
            );
        }
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_zero_and_one(DuplicateHandling::Error);
        let test_array = structs_sharing_b().into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_zero_and_one(DuplicateHandling::Error);
        let test_array = structs_sharing_b().into_array();

        test_array.apply(&expr).unwrap();
    }

    #[test]
    pub fn test_empty_merge() {
        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();

        let expr = merge(Vec::<Expression>::new());
        let actual_array = test_array.clone().apply(&expr).unwrap();

        // No children means no fields, but the row count is preserved.
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged!

        let wide_inner = StructArray::from_fields(&[
            ("x", buffer![0, 0, 0].into_array()),
            ("y", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let narrow_inner = StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
            .unwrap()
            .into_array();

        let left = StructArray::from_fields(&[("a", wide_inner)])
            .unwrap()
            .into_array();
        let right = StructArray::from_fields(&[("a", narrow_inner)])
            .unwrap()
            .into_array();

        let test_array = StructArray::from_fields(&[("0", left), ("1", right)])
            .unwrap()
            .into_array();

        let expr = merge_zero_and_one(DuplicateHandling::RightMost);
        let actual_array = test_array.clone().apply(&expr).unwrap().to_struct();

        // The right-most `a` replaces the left one wholesale, so only `x` survives.
        let merged_a = actual_array.unmasked_field_by_name("a").unwrap().to_struct();
        let inner_names: Vec<&str> = merged_a.names().iter().map(|name| name.as_ref()).collect();
        assert_eq!(inner_names, vec!["x"]);
    }

    #[test]
    pub fn test_merge_order() {
        let left = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("c", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let right = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("d", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();

        let test_array = StructArray::from_fields(&[("0", left), ("1", right)])
            .unwrap()
            .into_array();

        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);
        let actual_array = test_array.clone().apply(&expr).unwrap().to_struct();

        // Fields keep the order in which they were first seen, not alphabetical order.
        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    pub fn test_display() {
        let two_children = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(two_children.to_string(), "merge($.struct1, $.struct2)");

        let one_child = merge(vec![get_item("a", root())]);
        assert_eq!(one_child.to_string(), "merge($.a)");
    }

    #[test]
    fn test_remove_merge() {
        let left_dtype = DType::struct_([("a", I32), ("b", I64)], NonNullable);
        let right_dtype = DType::struct_([("b", U32), ("c", U64)], NonNullable);
        let dtype = DType::struct_([("0", left_dtype), ("1", right_dtype)], NonNullable);

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        // The merge is rewritten into a `Pack`; `b` takes the right-most dtype (U32).
        assert!(result.is::<Pack>());
        let expected = DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable);
        assert_eq!(result.return_dtype(&dtype).unwrap(), expected);
    }
}