Skip to main content

vortex_array/scalar_fn/fns/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15use vortex_utils::aliases::hash_set::HashSet;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray as _;
20use crate::arrays::StructArray;
21use crate::arrays::struct_::StructArrayExt;
22use crate::dtype::DType;
23use crate::dtype::FieldNames;
24use crate::dtype::Nullability;
25use crate::dtype::StructFields;
26use crate::expr::Expression;
27use crate::expr::lit;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ReduceCtx;
32use crate::scalar_fn::ReduceNode;
33use crate::scalar_fn::ReduceNodeRef;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::get_item::GetItem;
38use crate::scalar_fn::fns::pack::Pack;
39use crate::scalar_fn::fns::pack::PackOptions;
40use crate::validity::Validity;
41
/// Merge zero or more expressions that ALL return non-nullable structs into a single struct.
///
/// Output fields appear in the order in which their names are first seen across the children.
/// What happens when a field name appears in more than one child is controlled by the
/// [`DuplicateHandling`] option:
///
/// - [`DuplicateHandling::Error`] (the default) rejects the merge.
/// - [`DuplicateHandling::RightMost`] keeps the field at the position where its name was first
///   seen, but takes its value from the right-most (last) child that defines it.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[derive(Clone)]
pub struct Merge;
50
51impl ScalarFnVTable for Merge {
52    type Options = DuplicateHandling;
53
54    fn id(&self) -> ScalarFnId {
55        static ID: CachedId = CachedId::new("vortex.merge");
56        *ID
57    }
58
59    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60        Ok(Some(match instance {
61            DuplicateHandling::RightMost => vec![0x00],
62            DuplicateHandling::Error => vec![0x01],
63        }))
64    }
65
66    fn deserialize(
67        &self,
68        _metadata: &[u8],
69        _session: &VortexSession,
70    ) -> VortexResult<Self::Options> {
71        let instance = match _metadata {
72            [0x00] => DuplicateHandling::RightMost,
73            [0x01] => DuplicateHandling::Error,
74            _ => {
75                vortex_bail!("invalid metadata for Merge expression");
76            }
77        };
78        Ok(instance)
79    }
80
81    fn arity(&self, _options: &Self::Options) -> Arity {
82        Arity::Variadic { min: 0, max: None }
83    }
84
85    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
86        ChildName::from(Arc::from(format!("{}", child_idx)))
87    }
88
89    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
90        let mut field_names = Vec::new();
91        let mut arrays = Vec::new();
92        let mut merge_nullability = Nullability::NonNullable;
93        let mut duplicate_names = HashSet::<_>::new();
94
95        for dtype in arg_dtypes {
96            let Some(fields) = dtype.as_struct_fields_opt() else {
97                vortex_bail!("merge expects struct input");
98            };
99            if dtype.is_nullable() {
100                vortex_bail!("merge expects non-nullable input");
101            }
102
103            merge_nullability |= dtype.nullability();
104
105            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
106                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
107                    duplicate_names.insert(field_name.clone());
108                    arrays[idx] = field_dtype;
109                } else {
110                    field_names.push(field_name.clone());
111                    arrays.push(field_dtype);
112                }
113            }
114        }
115
116        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
117            vortex_bail!(
118                "merge: duplicate fields in children: {}",
119                duplicate_names.into_iter().format(", ")
120            )
121        }
122
123        Ok(DType::Struct(
124            StructFields::new(FieldNames::from(field_names), arrays),
125            merge_nullability,
126        ))
127    }
128
129    fn execute(
130        &self,
131        options: &Self::Options,
132        args: &dyn ExecutionArgs,
133        ctx: &mut ExecutionCtx,
134    ) -> VortexResult<ArrayRef> {
135        // Collect fields in order of appearance. Later fields overwrite earlier fields.
136        let mut field_names = Vec::new();
137        let mut arrays = Vec::new();
138        let mut duplicate_names = HashSet::<_>::new();
139
140        for i in 0..args.num_inputs() {
141            let array = args.get(i)?.execute::<StructArray>(ctx)?;
142            if array.dtype().is_nullable() {
143                vortex_bail!("merge expects non-nullable input");
144            }
145
146            for (field_name, field_array) in array
147                .names()
148                .iter()
149                .zip_eq(array.iter_unmasked_fields().cloned())
150            {
151                // Update or insert field.
152                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
153                    duplicate_names.insert(field_name.clone());
154                    arrays[idx] = field_array;
155                } else {
156                    field_names.push(field_name.clone());
157                    arrays.push(field_array);
158                }
159            }
160        }
161
162        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
163            vortex_bail!(
164                "merge: duplicate fields in children: {}",
165                duplicate_names.into_iter().format(", ")
166            )
167        }
168
169        // TODO(DK): When children are allowed to be nullable, this needs to change.
170        let validity = Validity::NonNullable;
171        let len = args.row_count();
172        Ok(
173            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
174                .into_array(),
175        )
176    }
177
178    fn reduce(
179        &self,
180        options: &Self::Options,
181        node: &dyn ReduceNode,
182        ctx: &dyn ReduceCtx,
183    ) -> VortexResult<Option<ReduceNodeRef>> {
184        let mut names = Vec::with_capacity(node.child_count() * 2);
185        let mut children = Vec::with_capacity(node.child_count() * 2);
186        let mut duplicate_names = HashSet::<_>::new();
187
188        for child in (0..node.child_count()).map(|i| node.child(i)) {
189            let child_dtype = child.node_dtype()?;
190            if !child_dtype.is_struct() {
191                vortex_bail!(
192                    "Merge child must return a non-nullable struct dtype, got {}",
193                    child_dtype
194                )
195            }
196
197            let child_dtype = child_dtype
198                .as_struct_fields_opt()
199                .vortex_expect("expected struct");
200
201            for name in child_dtype.names().iter() {
202                if let Some(idx) = names.iter().position(|n| n == name) {
203                    duplicate_names.insert(name.clone());
204                    children[idx] = Arc::clone(&child);
205                } else {
206                    names.push(name.clone());
207                    children.push(Arc::clone(&child));
208                }
209            }
210
211            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
212                vortex_bail!(
213                    "merge: duplicate fields in children: {}",
214                    duplicate_names.into_iter().format(", ")
215                )
216            }
217        }
218
219        let pack_children: Vec<_> = names
220            .iter()
221            .zip(children)
222            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
223            .try_collect()?;
224
225        let pack_expr = ctx.new_node(
226            Pack.bind(PackOptions {
227                names: FieldNames::from(names),
228                nullability: node.node_dtype()?.nullability(),
229            }),
230            &pack_children,
231        )?;
232
233        Ok(Some(pack_expr))
234    }
235
236    fn validity(
237        &self,
238        _options: &Self::Options,
239        _expression: &Expression,
240    ) -> VortexResult<Option<Expression>> {
241        Ok(Some(lit(true)))
242    }
243
244    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
245        true
246    }
247
248    fn is_fallible(&self, instance: &Self::Options) -> bool {
249        matches!(instance, DuplicateHandling::Error)
250    }
251}
252
/// Policy applied when two or more merged structs define a field with the same name.
///
/// The default policy is [`DuplicateHandling::Error`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum DuplicateHandling {
    /// Keep the value supplied by the right-most struct that defines the shared field name.
    RightMost,
    /// Treat a shared field name as an error.
    #[default]
    Error,
}
262
263impl Display for DuplicateHandling {
264    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
265        match self {
266            DuplicateHandling::RightMost => write!(f, "RightMost"),
267            DuplicateHandling::Error => write!(f, "Error"),
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::merge::StructArray;
    use crate::scalar_fn::fns::pack::Pack;

    /// Follows `field_path` through nested struct arrays and returns the leaf as primitives.
    ///
    /// Errors if the path is empty or if any segment names a field that does not exist.
    #[expect(deprecated)]
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(head)?.clone();
        for segment in tail {
            let child = current.to_struct().unmasked_field_by_name(segment)?.clone();
            current = child;
        }
        Ok(current.to_primitive())
    }

    /// Builds `merge(get_item("0", root()), get_item("1", root()))` with the given policy.
    fn merge_zero_and_one(handling: DuplicateHandling) -> Expression {
        merge_opts(vec![get_item("0", root()), get_item("1", root())], handling)
    }

    /// A one-row struct whose children "0" and "1" both contain a field named "b".
    fn overlapping_structs() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    #[test]
    pub fn test_merge_right_most() {
        let mut ctx = array_session().create_execution_ctx();
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let first = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("b", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let second = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("c", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let third = StructArray::from_fields(&[
            ("d", buffer![4, 4, 4].into_array()),
            ("e", buffer![5, 5, 5].into_array()),
        ])
        .unwrap()
        .into_array();
        let input = StructArray::from_fields(&[("0", first), ("1", second), ("2", third)])
            .unwrap()
            .into_array();

        let merged = input.apply(&expr).unwrap();

        // "b" keeps its first-seen slot but takes its value from the right-most struct.
        assert_eq!(merged.as_struct_typed().names(), ["a", "b", "c", "d", "e"]);

        let expectations = [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)];
        for (name, value) in expectations {
            assert_arrays_eq!(
                primitive_field(&merged, &[name]).unwrap(),
                PrimitiveArray::from_iter([value; 3]),
                &mut ctx
            );
        }
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let input = overlapping_structs();
        merge_zero_and_one(DuplicateHandling::Error)
            .return_dtype(input.dtype())
            .unwrap();
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let input = overlapping_structs();
        input
            .apply(&merge_zero_and_one(DuplicateHandling::Error))
            .unwrap();
    }

    #[test]
    pub fn test_empty_merge() {
        let input = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();

        let merged = input
            .clone()
            .apply(&merge(Vec::<Expression>::new()))
            .unwrap();

        // Merging nothing yields a zero-field struct that still has the input's row count.
        assert_eq!(merged.len(), input.len());
        assert_eq!(merged.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged: the later "a" replaces the earlier one wholesale.
        let inner_xy = StructArray::from_fields(&[
            ("x", buffer![0, 0, 0].into_array()),
            ("y", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let inner_x = StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
            .unwrap()
            .into_array();

        let input = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[("a", inner_xy)])
                    .unwrap()
                    .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[("a", inner_x)])
                    .unwrap()
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        #[expect(deprecated)]
        let merged = input
            .apply(&merge_zero_and_one(DuplicateHandling::RightMost))
            .unwrap()
            .to_struct();

        #[expect(deprecated)]
        let inner = merged.unmasked_field_by_name("a").unwrap().to_struct();

        let inner_names: Vec<&str> = inner.names().iter().map(|name| name.as_ref()).collect();
        assert_eq!(inner_names, vec!["x"]);
    }

    #[test]
    pub fn test_merge_order() {
        let left = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("c", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let right = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("d", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let input = StructArray::from_fields(&[("0", left), ("1", right)])
            .unwrap()
            .into_array();

        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        #[expect(deprecated)]
        let merged = input.apply(&expr).unwrap().to_struct();

        // Field order follows child order, not alphabetical order.
        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    pub fn test_display() {
        let cases = [
            (
                merge([get_item("struct1", root()), get_item("struct2", root())]),
                "vortex.merge($.struct1, $.struct2, opts=Error)",
            ),
            (
                merge(vec![get_item("a", root())]),
                "vortex.merge($.a, opts=Error)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }

    #[test]
    fn test_remove_merge() {
        let left = DType::struct_([("a", I32), ("b", I64)], NonNullable);
        let right = DType::struct_([("b", U32), ("c", U64)], NonNullable);
        let dtype = DType::struct_([("0", left), ("1", right)], NonNullable);

        let optimized = merge_zero_and_one(DuplicateHandling::RightMost)
            .optimize(&dtype)
            .unwrap();

        // Optimization should rewrite the merge into a pack of get_items.
        assert!(optimized.is::<Pack>());
        assert_eq!(
            optimized.return_dtype(&dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}