Skip to main content

vortex_array/scalar_fn/fns/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::IntoArray as _;
18use crate::arrays::StructArray;
19use crate::dtype::DType;
20use crate::dtype::FieldNames;
21use crate::dtype::Nullability;
22use crate::dtype::StructFields;
23use crate::expr::Expression;
24use crate::expr::lit;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::ExecutionArgs;
28use crate::scalar_fn::ReduceCtx;
29use crate::scalar_fn::ReduceNode;
30use crate::scalar_fn::ReduceNodeRef;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::fns::get_item::GetItem;
35use crate::scalar_fn::fns::pack::Pack;
36use crate::scalar_fn::fns::pack::PackOptions;
37use crate::validity::Validity;
38
/// Merge zero or more expressions that ALL return structs into a single struct.
///
/// Fields appear in the output in the order in which their names are first seen across the
/// inputs.
///
/// What happens when two inputs share a field name is decided by the [`DuplicateHandling`]
/// option: with [`DuplicateHandling::RightMost`] the field from the later expression wins,
/// with [`DuplicateHandling::Error`] (the default) the merge fails with an error.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[derive(Clone)]
pub struct Merge;
47
48impl ScalarFnVTable for Merge {
49    type Options = DuplicateHandling;
50
51    fn id(&self) -> ScalarFnId {
52        ScalarFnId::new_ref("vortex.merge")
53    }
54
55    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
56        Ok(Some(match instance {
57            DuplicateHandling::RightMost => vec![0x00],
58            DuplicateHandling::Error => vec![0x01],
59        }))
60    }
61
62    fn deserialize(
63        &self,
64        _metadata: &[u8],
65        _session: &VortexSession,
66    ) -> VortexResult<Self::Options> {
67        let instance = match _metadata {
68            [0x00] => DuplicateHandling::RightMost,
69            [0x01] => DuplicateHandling::Error,
70            _ => {
71                vortex_bail!("invalid metadata for Merge expression");
72            }
73        };
74        Ok(instance)
75    }
76
77    fn arity(&self, _options: &Self::Options) -> Arity {
78        Arity::Variadic { min: 0, max: None }
79    }
80
81    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
82        ChildName::from(Arc::from(format!("{}", child_idx)))
83    }
84
85    fn fmt_sql(
86        &self,
87        _options: &Self::Options,
88        expr: &Expression,
89        f: &mut Formatter<'_>,
90    ) -> std::fmt::Result {
91        write!(f, "merge(")?;
92        for (i, child) in expr.children().iter().enumerate() {
93            child.fmt_sql(f)?;
94            if i + 1 < expr.children().len() {
95                write!(f, ", ")?;
96            }
97        }
98        write!(f, ")")
99    }
100
101    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
102        let mut field_names = Vec::new();
103        let mut arrays = Vec::new();
104        let mut merge_nullability = Nullability::NonNullable;
105        let mut duplicate_names = HashSet::<_>::new();
106
107        for dtype in arg_dtypes {
108            let Some(fields) = dtype.as_struct_fields_opt() else {
109                vortex_bail!("merge expects struct input");
110            };
111            if dtype.is_nullable() {
112                vortex_bail!("merge expects non-nullable input");
113            }
114
115            merge_nullability |= dtype.nullability();
116
117            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
118                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
119                    duplicate_names.insert(field_name.clone());
120                    arrays[idx] = field_dtype;
121                } else {
122                    field_names.push(field_name.clone());
123                    arrays.push(field_dtype);
124                }
125            }
126        }
127
128        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
129            vortex_bail!(
130                "merge: duplicate fields in children: {}",
131                duplicate_names.into_iter().format(", ")
132            )
133        }
134
135        Ok(DType::Struct(
136            StructFields::new(FieldNames::from(field_names), arrays),
137            merge_nullability,
138        ))
139    }
140
141    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
142        // Collect fields in order of appearance. Later fields overwrite earlier fields.
143        let mut field_names = Vec::new();
144        let mut arrays = Vec::new();
145        let mut duplicate_names = HashSet::<_>::new();
146
147        for input in args.inputs {
148            let array = input.execute::<StructArray>(args.ctx)?;
149            if array.dtype().is_nullable() {
150                vortex_bail!("merge expects non-nullable input");
151            }
152
153            for (field_name, field_array) in array
154                .names()
155                .iter()
156                .zip_eq(array.unmasked_fields().iter().cloned())
157            {
158                // Update or insert field.
159                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
160                    duplicate_names.insert(field_name.clone());
161                    arrays[idx] = field_array;
162                } else {
163                    field_names.push(field_name.clone());
164                    arrays.push(field_array);
165                }
166            }
167        }
168
169        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
170            vortex_bail!(
171                "merge: duplicate fields in children: {}",
172                duplicate_names.into_iter().format(", ")
173            )
174        }
175
176        // TODO(DK): When children are allowed to be nullable, this needs to change.
177        let validity = Validity::NonNullable;
178        let len = args.row_count;
179        Ok(
180            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
181                .into_array(),
182        )
183    }
184
185    fn reduce(
186        &self,
187        options: &Self::Options,
188        node: &dyn ReduceNode,
189        ctx: &dyn ReduceCtx,
190    ) -> VortexResult<Option<ReduceNodeRef>> {
191        let mut names = Vec::with_capacity(node.child_count() * 2);
192        let mut children = Vec::with_capacity(node.child_count() * 2);
193        let mut duplicate_names = HashSet::<_>::new();
194
195        for child in (0..node.child_count()).map(|i| node.child(i)) {
196            let child_dtype = child.node_dtype()?;
197            if !child_dtype.is_struct() {
198                vortex_bail!(
199                    "Merge child must return a non-nullable struct dtype, got {}",
200                    child_dtype
201                )
202            }
203
204            let child_dtype = child_dtype
205                .as_struct_fields_opt()
206                .vortex_expect("expected struct");
207
208            for name in child_dtype.names().iter() {
209                if let Some(idx) = names.iter().position(|n| n == name) {
210                    duplicate_names.insert(name.clone());
211                    children[idx] = child.clone();
212                } else {
213                    names.push(name.clone());
214                    children.push(child.clone());
215                }
216            }
217
218            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
219                vortex_bail!(
220                    "merge: duplicate fields in children: {}",
221                    duplicate_names.into_iter().format(", ")
222                )
223            }
224        }
225
226        let pack_children: Vec<_> = names
227            .iter()
228            .zip(children)
229            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
230            .try_collect()?;
231
232        let pack_expr = ctx.new_node(
233            Pack.bind(PackOptions {
234                names: FieldNames::from(names),
235                nullability: node.node_dtype()?.nullability(),
236            }),
237            &pack_children,
238        )?;
239
240        Ok(Some(pack_expr))
241    }
242
243    fn validity(
244        &self,
245        _options: &Self::Options,
246        _expression: &Expression,
247    ) -> VortexResult<Option<Expression>> {
248        Ok(Some(lit(true)))
249    }
250
251    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
252        true
253    }
254
255    fn is_fallible(&self, instance: &Self::Options) -> bool {
256        matches!(instance, DuplicateHandling::Error)
257    }
258}
259
/// Policy applied when two of the structs being merged contain a field with the same name.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most struct that contains the shared field name.
    RightMost,
    /// Treat a shared field name as an error. This is the default.
    #[default]
    Error,
}

impl Display for DuplicateHandling {
    /// Writes the variant name: `RightMost` or `Error`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::RightMost => "RightMost",
            Self::Error => "Error",
        };
        f.write_str(label)
    }
}
278
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::pack::Pack;

    /// Walks `field_path` through nested structs starting at `array` and returns the field at
    /// the end of the path as a primitive array. Fails on an empty path or a missing field.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().unmasked_field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// Builds a one-row struct whose children `"0"` and `"1"` both contain a field named `b`.
    fn input_with_shared_field() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    /// With `RightMost`, a shared field (`b`) takes its value from the later struct while
    /// keeping the position of its first occurrence.
    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // Every output field is a constant column; `b` must carry the value from struct "1".
        for (name, expected) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&actual_array, &[name]).unwrap(),
                PrimitiveArray::from_iter([expected; 3])
            );
        }
    }

    /// In `Error` mode a shared field name is rejected while computing the return dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = input_with_shared_field();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// In `Error` mode a shared field name is rejected when the expression is applied.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = input_with_shared_field();

        test_array.apply(&expr).unwrap();
    }

    /// Merging no children yields a struct with no fields and the input's row count.
    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are not merged: the later `a` replaces the earlier `a` wholesale.
    #[test]
    fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(
            actual_array
                .unmasked_field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow the order of the inputs, not alphabetical order.
    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The expression renders as `merge(...)` with comma-separated children.
    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }

    /// Optimizing a merge replaces it with a `Pack` that has the same return dtype.
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}