Skip to main content

vortex_array/scalar_fn/fns/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::StructArrayExt;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::get_item::GetItem;
37use crate::scalar_fn::fns::pack::Pack;
38use crate::scalar_fn::fns::pack::PackOptions;
39use crate::validity::Validity;
40
/// Merge zero or more expressions that ALL return non-nullable structs into a single struct.
///
/// Fields appear in the order in which their names are first seen across the children.
/// What happens when a field name occurs in more than one child is controlled by this
/// function's options, [`DuplicateHandling`]:
///
/// - [`DuplicateHandling::RightMost`]: the field from the later (right-most) expression wins.
///   It keeps the position at which the name was first seen.
/// - [`DuplicateHandling::Error`] (the default): merging fails with an error naming the
///   duplicated fields.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[derive(Clone)]
pub struct Merge;
49
50impl ScalarFnVTable for Merge {
51    type Options = DuplicateHandling;
52
53    fn id(&self) -> ScalarFnId {
54        ScalarFnId::from("vortex.merge")
55    }
56
57    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
58        Ok(Some(match instance {
59            DuplicateHandling::RightMost => vec![0x00],
60            DuplicateHandling::Error => vec![0x01],
61        }))
62    }
63
64    fn deserialize(
65        &self,
66        _metadata: &[u8],
67        _session: &VortexSession,
68    ) -> VortexResult<Self::Options> {
69        let instance = match _metadata {
70            [0x00] => DuplicateHandling::RightMost,
71            [0x01] => DuplicateHandling::Error,
72            _ => {
73                vortex_bail!("invalid metadata for Merge expression");
74            }
75        };
76        Ok(instance)
77    }
78
79    fn arity(&self, _options: &Self::Options) -> Arity {
80        Arity::Variadic { min: 0, max: None }
81    }
82
83    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
84        ChildName::from(Arc::from(format!("{}", child_idx)))
85    }
86
87    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
88        let mut field_names = Vec::new();
89        let mut arrays = Vec::new();
90        let mut merge_nullability = Nullability::NonNullable;
91        let mut duplicate_names = HashSet::<_>::new();
92
93        for dtype in arg_dtypes {
94            let Some(fields) = dtype.as_struct_fields_opt() else {
95                vortex_bail!("merge expects struct input");
96            };
97            if dtype.is_nullable() {
98                vortex_bail!("merge expects non-nullable input");
99            }
100
101            merge_nullability |= dtype.nullability();
102
103            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
104                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
105                    duplicate_names.insert(field_name.clone());
106                    arrays[idx] = field_dtype;
107                } else {
108                    field_names.push(field_name.clone());
109                    arrays.push(field_dtype);
110                }
111            }
112        }
113
114        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
115            vortex_bail!(
116                "merge: duplicate fields in children: {}",
117                duplicate_names.into_iter().format(", ")
118            )
119        }
120
121        Ok(DType::Struct(
122            StructFields::new(FieldNames::from(field_names), arrays),
123            merge_nullability,
124        ))
125    }
126
127    fn execute(
128        &self,
129        options: &Self::Options,
130        args: &dyn ExecutionArgs,
131        ctx: &mut ExecutionCtx,
132    ) -> VortexResult<ArrayRef> {
133        // Collect fields in order of appearance. Later fields overwrite earlier fields.
134        let mut field_names = Vec::new();
135        let mut arrays = Vec::new();
136        let mut duplicate_names = HashSet::<_>::new();
137
138        for i in 0..args.num_inputs() {
139            let array = args.get(i)?.execute::<StructArray>(ctx)?;
140            if array.dtype().is_nullable() {
141                vortex_bail!("merge expects non-nullable input");
142            }
143
144            for (field_name, field_array) in array
145                .names()
146                .iter()
147                .zip_eq(array.iter_unmasked_fields().cloned())
148            {
149                // Update or insert field.
150                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
151                    duplicate_names.insert(field_name.clone());
152                    arrays[idx] = field_array;
153                } else {
154                    field_names.push(field_name.clone());
155                    arrays.push(field_array);
156                }
157            }
158        }
159
160        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
161            vortex_bail!(
162                "merge: duplicate fields in children: {}",
163                duplicate_names.into_iter().format(", ")
164            )
165        }
166
167        // TODO(DK): When children are allowed to be nullable, this needs to change.
168        let validity = Validity::NonNullable;
169        let len = args.row_count();
170        Ok(
171            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
172                .into_array(),
173        )
174    }
175
176    fn reduce(
177        &self,
178        options: &Self::Options,
179        node: &dyn ReduceNode,
180        ctx: &dyn ReduceCtx,
181    ) -> VortexResult<Option<ReduceNodeRef>> {
182        let mut names = Vec::with_capacity(node.child_count() * 2);
183        let mut children = Vec::with_capacity(node.child_count() * 2);
184        let mut duplicate_names = HashSet::<_>::new();
185
186        for child in (0..node.child_count()).map(|i| node.child(i)) {
187            let child_dtype = child.node_dtype()?;
188            if !child_dtype.is_struct() {
189                vortex_bail!(
190                    "Merge child must return a non-nullable struct dtype, got {}",
191                    child_dtype
192                )
193            }
194
195            let child_dtype = child_dtype
196                .as_struct_fields_opt()
197                .vortex_expect("expected struct");
198
199            for name in child_dtype.names().iter() {
200                if let Some(idx) = names.iter().position(|n| n == name) {
201                    duplicate_names.insert(name.clone());
202                    children[idx] = Arc::clone(&child);
203                } else {
204                    names.push(name.clone());
205                    children.push(Arc::clone(&child));
206                }
207            }
208
209            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
210                vortex_bail!(
211                    "merge: duplicate fields in children: {}",
212                    duplicate_names.into_iter().format(", ")
213                )
214            }
215        }
216
217        let pack_children: Vec<_> = names
218            .iter()
219            .zip(children)
220            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
221            .try_collect()?;
222
223        let pack_expr = ctx.new_node(
224            Pack.bind(PackOptions {
225                names: FieldNames::from(names),
226                nullability: node.node_dtype()?.nullability(),
227            }),
228            &pack_children,
229        )?;
230
231        Ok(Some(pack_expr))
232    }
233
234    fn validity(
235        &self,
236        _options: &Self::Options,
237        _expression: &Expression,
238    ) -> VortexResult<Option<Expression>> {
239        Ok(Some(lit(true)))
240    }
241
242    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
243        true
244    }
245
246    fn is_fallible(&self, instance: &Self::Options) -> bool {
247        matches!(instance, DuplicateHandling::Error)
248    }
249}
250
/// Policy applied when two or more merged structs contain a field with the same name.
///
/// The default is [`DuplicateHandling::Error`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most struct that defines the shared field name.
    RightMost,
    /// Reject a shared field name with an error.
    #[default]
    Error,
}
260
261impl Display for DuplicateHandling {
262    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263        match self {
264            DuplicateHandling::RightMost => write!(f, "RightMost"),
265            DuplicateHandling::Error => write!(f, "Error"),
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::merge::StructArray;
    use crate::scalar_fn::fns::pack::Pack;

    /// Follows `field_path` through nested struct fields of `array` and returns the final
    /// field converted to a [`PrimitiveArray`].
    ///
    /// Errors if `field_path` is empty or if any named field is missing.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        #[expect(deprecated)]
        let mut array = array.to_struct().unmasked_field_by_name(field)?.clone();
        // Descend one struct level per remaining path component.
        for field in field_path {
            #[expect(deprecated)]
            let next = array.to_struct().unmasked_field_by_name(field)?.clone();
            array = next;
        }
        #[expect(deprecated)]
        let result = array.to_primitive();
        Ok(result)
    }

    /// With `RightMost`, the shared field `b` takes child "1"'s values.
    /// It keeps the position where `b` was first seen (after `a`).
    #[test]
    pub fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = test_array.apply(&expr).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_arrays_eq!(
            primitive_field(&actual_array, &["a"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 0, 0])
        );
        // `b` comes from child "1" (2s), not child "0" (1s).
        assert_arrays_eq!(
            primitive_field(&actual_array, &["b"]).unwrap(),
            PrimitiveArray::from_iter([2i32, 2, 2])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["c"]).unwrap(),
            PrimitiveArray::from_iter([3i32, 3, 3])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["d"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 4, 4])
        );
        assert_arrays_eq!(
            primitive_field(&actual_array, &["e"]).unwrap(),
            PrimitiveArray::from_iter([5i32, 5, 5])
        );
    }

    /// Under `Error`, field `b` present in both children makes dtype inference fail.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// Under `Error`, field `b` present in both children makes evaluation fail.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        test_array.apply(&expr).unwrap();
    }

    /// Merging zero children yields a struct with no fields.
    /// It keeps the same row count as the input.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// A struct-typed field shared by two children is replaced wholesale, not merged.
    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        #[expect(deprecated)]
        let inner_struct = actual_array
            .unmasked_field_by_name("a")
            .unwrap()
            .to_struct();
        // Only child "1"'s inner field `x` survives; child "0"'s `y` is dropped.
        assert_eq!(
            inner_struct
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// With disjoint children, output fields follow child order, then each child's own
    /// field order; they are not sorted.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        #[expect(deprecated)]
        let actual_array = test_array.apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// `merge(...)` without explicit options displays with `opts=Error`.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(
            expr.to_string(),
            "vortex.merge($.struct1, $.struct2, opts=Error)"
        );

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "vortex.merge($.a, opts=Error)");
    }

    /// Optimization rewrites the merge into a `Pack` expression (see `Merge::reduce`).
    /// The result has the same merged dtype, with `b` taken from child "1" (`U32`).
    #[test]
    fn test_remove_merge() {
        let dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let e = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let result = e.optimize(&dtype).unwrap();

        assert!(result.is::<Pack>());
        assert_eq!(
            result.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}