vortex_array/expr/exprs/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_utils::aliases::hash_set::HashSet;
18use vortex_vector::Datum;
19
20use crate::Array;
21use crate::ArrayRef;
22use crate::IntoArray as _;
23use crate::ToCanonical;
24use crate::arrays::StructArray;
25use crate::expr::Arity;
26use crate::expr::ChildName;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Expression;
30use crate::expr::GetItem;
31use crate::expr::Pack;
32use crate::expr::PackOptions;
33use crate::expr::ReduceCtx;
34use crate::expr::ReduceNode;
35use crate::expr::ReduceNodeRef;
36use crate::expr::VTable;
37use crate::expr::VTableExt;
38use crate::validity::Validity;
39
/// Merge zero or more expressions that ALL return structs.
///
/// The resulting struct lists fields in order of first appearance across the children.
///
/// What happens when several children share a field name is controlled by the
/// [`DuplicateHandling`] option: either the field from the later expression wins, or the merge
/// fails with an error.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
pub struct Merge;
47
48impl VTable for Merge {
49    type Options = DuplicateHandling;
50
51    fn id(&self) -> ExprId {
52        ExprId::new_ref("vortex.merge")
53    }
54
55    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56        Ok(Some(match instance {
57            DuplicateHandling::RightMost => vec![0x00],
58            DuplicateHandling::Error => vec![0x01],
59        }))
60    }
61
62    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
63        let instance = match metadata {
64            [0x00] => DuplicateHandling::RightMost,
65            [0x01] => DuplicateHandling::Error,
66            _ => {
67                vortex_bail!("invalid metadata for Merge expression");
68            }
69        };
70        Ok(instance)
71    }
72
73    fn arity(&self, _options: &Self::Options) -> Arity {
74        Arity::Variadic { min: 0, max: None }
75    }
76
77    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
78        ChildName::from(Arc::from(format!("{}", child_idx)))
79    }
80
81    fn fmt_sql(
82        &self,
83        _options: &Self::Options,
84        expr: &Expression,
85        f: &mut Formatter<'_>,
86    ) -> std::fmt::Result {
87        write!(f, "merge(")?;
88        for (i, child) in expr.children().iter().enumerate() {
89            child.fmt_sql(f)?;
90            if i + 1 < expr.children().len() {
91                write!(f, ", ")?;
92            }
93        }
94        write!(f, ")")
95    }
96
97    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
98        let mut field_names = Vec::new();
99        let mut arrays = Vec::new();
100        let mut merge_nullability = Nullability::NonNullable;
101        let mut duplicate_names = HashSet::<_>::new();
102
103        for dtype in arg_dtypes {
104            let Some(fields) = dtype.as_struct_fields_opt() else {
105                vortex_bail!("merge expects struct input");
106            };
107            if dtype.is_nullable() {
108                vortex_bail!("merge expects non-nullable input");
109            }
110
111            merge_nullability |= dtype.nullability();
112
113            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
114                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
115                    duplicate_names.insert(field_name.clone());
116                    arrays[idx] = field_dtype;
117                } else {
118                    field_names.push(field_name.clone());
119                    arrays.push(field_dtype);
120                }
121            }
122        }
123
124        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
125            vortex_bail!(
126                "merge: duplicate fields in children: {}",
127                duplicate_names.into_iter().format(", ")
128            )
129        }
130
131        Ok(DType::Struct(
132            StructFields::new(FieldNames::from(field_names), arrays),
133            merge_nullability,
134        ))
135    }
136
137    fn evaluate(
138        &self,
139        options: &Self::Options,
140        expr: &Expression,
141        scope: &ArrayRef,
142    ) -> VortexResult<ArrayRef> {
143        // Collect fields in order of appearance. Later fields overwrite earlier fields.
144        let mut field_names = Vec::new();
145        let mut arrays = Vec::new();
146        let mut duplicate_names = HashSet::<_>::new();
147
148        for child in expr.children().iter() {
149            // TODO(marko): When nullable, we need to merge struct validity into field validity.
150            let array = child.evaluate(scope)?;
151            if array.dtype().is_nullable() {
152                vortex_bail!("merge expects non-nullable input");
153            }
154            if !array.dtype().is_struct() {
155                vortex_bail!("merge expects struct input");
156            }
157            let array = array.to_struct();
158
159            for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
160                // Update or insert field.
161                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
162                    duplicate_names.insert(field_name.clone());
163                    arrays[idx] = array;
164                } else {
165                    field_names.push(field_name.clone());
166                    arrays.push(array);
167                }
168            }
169        }
170
171        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
172            vortex_bail!(
173                "merge: duplicate fields in children: {}",
174                duplicate_names.into_iter().format(", ")
175            )
176        }
177
178        // TODO(DK): When children are allowed to be nullable, this needs to change.
179        let validity = Validity::NonNullable;
180        let len = scope.len();
181        Ok(
182            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
183                .into_array(),
184        )
185    }
186
187    fn execute(&self, _data: &Self::Options, _args: ExecutionArgs) -> VortexResult<Datum> {
188        todo!()
189    }
190
191    fn reduce(
192        &self,
193        options: &Self::Options,
194        node: &dyn ReduceNode,
195        ctx: &dyn ReduceCtx,
196    ) -> VortexResult<Option<ReduceNodeRef>> {
197        let merge_dtype = node.node_dtype()?;
198        let mut names = Vec::with_capacity(node.child_count() * 2);
199        let mut children = Vec::with_capacity(node.child_count() * 2);
200        let mut duplicate_names = HashSet::<_>::new();
201
202        for child in node.children() {
203            let child_dtype = child.node_dtype()?;
204            if !child_dtype.is_struct() {
205                vortex_bail!(
206                    "Merge child must return a non-nullable struct dtype, got {}",
207                    child_dtype
208                )
209            }
210
211            let child_dtype = child_dtype
212                .as_struct_fields_opt()
213                .vortex_expect("expected struct");
214
215            for name in child_dtype.names().iter() {
216                if let Some(idx) = names.iter().position(|n| n == name) {
217                    duplicate_names.insert(name.clone());
218                    children[idx] = child.clone();
219                } else {
220                    names.push(name.clone());
221                    children.push(child.clone());
222                }
223            }
224
225            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
226                vortex_bail!(
227                    "merge: duplicate fields in children: {}",
228                    duplicate_names.into_iter().format(", ")
229                )
230            }
231        }
232
233        let pack_children: Vec<_> = names
234            .iter()
235            .zip(children)
236            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
237            .try_collect()?;
238
239        let pack_expr = ctx.new_node(
240            Pack.bind(PackOptions {
241                names: FieldNames::from(names),
242                nullability: merge_dtype.nullability(),
243            }),
244            &pack_children,
245        )?;
246
247        Ok(Some(pack_expr))
248    }
249
250    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
251        true
252    }
253
254    fn is_fallible(&self, instance: &Self::Options) -> bool {
255        matches!(instance, DuplicateHandling::Error)
256    }
257}
258
/// Policy applied when two merged structs contain a field with the same name.
///
/// The default policy is [`DuplicateHandling::Error`].
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value coming from the right-most struct that defines the field.
    RightMost,
    /// Treat a shared field name as an error.
    #[default]
    Error,
}

/// Renders the variant name verbatim: `RightMost` or `Error`.
impl Display for DuplicateHandling {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::RightMost => "RightMost",
            Self::Error => "Error",
        };
        f.write_str(label)
    }
}
277
278/// Creates an expression that merges struct expressions into a single struct.
279///
280/// Combines fields from all input expressions. If field names are duplicated,
281/// later expressions win. Fields are not recursively merged.
282///
283/// ```rust
284/// # use vortex_dtype::Nullability;
285/// # use vortex_array::expr::{merge, get_item, root};
286/// let expr = merge([get_item("a", root()), get_item("b", root())]);
287/// ```
288pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
289    let values = elements.into_iter().map(|value| value.into()).collect_vec();
290    Merge.new_expr(DuplicateHandling::default(), values)
291}
292
293pub fn merge_opts(
294    elements: impl IntoIterator<Item = impl Into<Expression>>,
295    duplicate_handling: DuplicateHandling,
296) -> Expression {
297    let values = elements.into_iter().map(|value| value.into()).collect_vec();
298    Merge.new_expr(duplicate_handling, values)
299}
300
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::PType::I64;
    use vortex_dtype::PType::U32;
    use vortex_dtype::PType::U64;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::merge;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::Pack;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Builds a struct array from `(name, array)` pairs, panicking on construction errors.
    fn struct_array(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Root struct with two struct fields, `"0"` and `"1"`, that both contain a field `"b"`.
    fn overlapping_structs() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs and returns the final field as a primitive
    /// array. Fails on an empty path or an unknown field name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// With `RightMost`, a shared field keeps its original position but takes the later value.
    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = struct_array(&[
            (
                "0",
                struct_array(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_array(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ]),
            ),
            (
                "2",
                struct_array(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ]),
            ),
        ]);
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        for (name, expected) in [
            ("a", [0, 0, 0]),
            ("b", [2, 2, 2]),
            ("c", [3, 3, 3]),
            ("d", [4, 4, 4]),
            ("e", [5, 5, 5]),
        ] {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                expected,
                "unexpected values for field {name}"
            );
        }
    }

    /// With `Error`, duplicate names are already rejected when computing the return dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = overlapping_structs();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, duplicate names are rejected during evaluation.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = overlapping_structs();

        expr.evaluate(&test_array).unwrap();
    }

    /// Merging no children yields a field-less struct with the length of the scope.
    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = struct_array(&[("a", buffer![0, 1, 2].into_array())]);
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = struct_array(&[
            (
                "0",
                struct_array(&[(
                    "a",
                    struct_array(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ]),
                )]),
            ),
            (
                "1",
                struct_array(&[(
                    "a",
                    struct_array(&[("x", buffer![0, 0, 0].into_array())]),
                )]),
            ),
        ]);
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        // The later `a` (only `x`) replaces the earlier `a` (`x` and `y`) wholesale.
        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow the order in which they first appear across the children.
    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = struct_array(&[
            (
                "0",
                struct_array(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_array(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ]),
            ),
        ]);
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing a merge replaces it with a `Pack` that has the same return dtype.
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}