Skip to main content

vortex_array/scalar_fn/fns/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_session::VortexSession;
14use vortex_utils::aliases::hash_set::HashSet;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray as _;
19use crate::arrays::StructArray;
20use crate::dtype::DType;
21use crate::dtype::FieldNames;
22use crate::dtype::Nullability;
23use crate::dtype::StructFields;
24use crate::expr::Expression;
25use crate::expr::lit;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ReduceCtx;
30use crate::scalar_fn::ReduceNode;
31use crate::scalar_fn::ReduceNodeRef;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::fns::get_item::GetItem;
36use crate::scalar_fn::fns::pack::Pack;
37use crate::scalar_fn::fns::pack::PackOptions;
38use crate::validity::Validity;
39
/// Scalar function that combines zero or more struct-valued expressions into one struct.
///
/// Fields appear in the output in the order they are first encountered across the children.
/// What happens when two children share a field name is governed by the
/// [`DuplicateHandling`] option: either the field from the later child replaces the earlier
/// one, or the merge fails.
///
/// NOTE: Replacement is shallow. A later struct-typed field REPLACES an earlier field of the
/// same name wholesale; nested structs are never merged recursively. This keeps the treatment
/// of struct fields consistent with every other dtype.
#[derive(Clone)]
pub struct Merge;
48
49impl ScalarFnVTable for Merge {
50    type Options = DuplicateHandling;
51
52    fn id(&self) -> ScalarFnId {
53        ScalarFnId::new_ref("vortex.merge")
54    }
55
56    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
57        Ok(Some(match instance {
58            DuplicateHandling::RightMost => vec![0x00],
59            DuplicateHandling::Error => vec![0x01],
60        }))
61    }
62
63    fn deserialize(
64        &self,
65        _metadata: &[u8],
66        _session: &VortexSession,
67    ) -> VortexResult<Self::Options> {
68        let instance = match _metadata {
69            [0x00] => DuplicateHandling::RightMost,
70            [0x01] => DuplicateHandling::Error,
71            _ => {
72                vortex_bail!("invalid metadata for Merge expression");
73            }
74        };
75        Ok(instance)
76    }
77
78    fn arity(&self, _options: &Self::Options) -> Arity {
79        Arity::Variadic { min: 0, max: None }
80    }
81
82    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
83        ChildName::from(Arc::from(format!("{}", child_idx)))
84    }
85
86    fn fmt_sql(
87        &self,
88        _options: &Self::Options,
89        expr: &Expression,
90        f: &mut Formatter<'_>,
91    ) -> std::fmt::Result {
92        write!(f, "merge(")?;
93        for (i, child) in expr.children().iter().enumerate() {
94            child.fmt_sql(f)?;
95            if i + 1 < expr.children().len() {
96                write!(f, ", ")?;
97            }
98        }
99        write!(f, ")")
100    }
101
102    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
103        let mut field_names = Vec::new();
104        let mut arrays = Vec::new();
105        let mut merge_nullability = Nullability::NonNullable;
106        let mut duplicate_names = HashSet::<_>::new();
107
108        for dtype in arg_dtypes {
109            let Some(fields) = dtype.as_struct_fields_opt() else {
110                vortex_bail!("merge expects struct input");
111            };
112            if dtype.is_nullable() {
113                vortex_bail!("merge expects non-nullable input");
114            }
115
116            merge_nullability |= dtype.nullability();
117
118            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
119                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
120                    duplicate_names.insert(field_name.clone());
121                    arrays[idx] = field_dtype;
122                } else {
123                    field_names.push(field_name.clone());
124                    arrays.push(field_dtype);
125                }
126            }
127        }
128
129        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
130            vortex_bail!(
131                "merge: duplicate fields in children: {}",
132                duplicate_names.into_iter().format(", ")
133            )
134        }
135
136        Ok(DType::Struct(
137            StructFields::new(FieldNames::from(field_names), arrays),
138            merge_nullability,
139        ))
140    }
141
142    fn execute(
143        &self,
144        options: &Self::Options,
145        args: &dyn ExecutionArgs,
146        ctx: &mut ExecutionCtx,
147    ) -> VortexResult<ArrayRef> {
148        // Collect fields in order of appearance. Later fields overwrite earlier fields.
149        let mut field_names = Vec::new();
150        let mut arrays = Vec::new();
151        let mut duplicate_names = HashSet::<_>::new();
152
153        for i in 0..args.num_inputs() {
154            let array = args.get(i)?.execute::<StructArray>(ctx)?;
155            if array.dtype().is_nullable() {
156                vortex_bail!("merge expects non-nullable input");
157            }
158
159            for (field_name, field_array) in array
160                .names()
161                .iter()
162                .zip_eq(array.unmasked_fields().iter().cloned())
163            {
164                // Update or insert field.
165                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
166                    duplicate_names.insert(field_name.clone());
167                    arrays[idx] = field_array;
168                } else {
169                    field_names.push(field_name.clone());
170                    arrays.push(field_array);
171                }
172            }
173        }
174
175        if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
176            vortex_bail!(
177                "merge: duplicate fields in children: {}",
178                duplicate_names.into_iter().format(", ")
179            )
180        }
181
182        // TODO(DK): When children are allowed to be nullable, this needs to change.
183        let validity = Validity::NonNullable;
184        let len = args.row_count();
185        Ok(
186            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
187                .into_array(),
188        )
189    }
190
191    fn reduce(
192        &self,
193        options: &Self::Options,
194        node: &dyn ReduceNode,
195        ctx: &dyn ReduceCtx,
196    ) -> VortexResult<Option<ReduceNodeRef>> {
197        let mut names = Vec::with_capacity(node.child_count() * 2);
198        let mut children = Vec::with_capacity(node.child_count() * 2);
199        let mut duplicate_names = HashSet::<_>::new();
200
201        for child in (0..node.child_count()).map(|i| node.child(i)) {
202            let child_dtype = child.node_dtype()?;
203            if !child_dtype.is_struct() {
204                vortex_bail!(
205                    "Merge child must return a non-nullable struct dtype, got {}",
206                    child_dtype
207                )
208            }
209
210            let child_dtype = child_dtype
211                .as_struct_fields_opt()
212                .vortex_expect("expected struct");
213
214            for name in child_dtype.names().iter() {
215                if let Some(idx) = names.iter().position(|n| n == name) {
216                    duplicate_names.insert(name.clone());
217                    children[idx] = child.clone();
218                } else {
219                    names.push(name.clone());
220                    children.push(child.clone());
221                }
222            }
223
224            if options == &DuplicateHandling::Error && !duplicate_names.is_empty() {
225                vortex_bail!(
226                    "merge: duplicate fields in children: {}",
227                    duplicate_names.into_iter().format(", ")
228                )
229            }
230        }
231
232        let pack_children: Vec<_> = names
233            .iter()
234            .zip(children)
235            .map(|(name, child)| ctx.new_node(GetItem.bind(name.clone()), &[child]))
236            .try_collect()?;
237
238        let pack_expr = ctx.new_node(
239            Pack.bind(PackOptions {
240                names: FieldNames::from(names),
241                nullability: node.node_dtype()?.nullability(),
242            }),
243            &pack_children,
244        )?;
245
246        Ok(Some(pack_expr))
247    }
248
249    fn validity(
250        &self,
251        _options: &Self::Options,
252        _expression: &Expression,
253    ) -> VortexResult<Option<Expression>> {
254        Ok(Some(lit(true)))
255    }
256
257    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
258        true
259    }
260
261    fn is_fallible(&self, instance: &Self::Options) -> bool {
262        matches!(instance, DuplicateHandling::Error)
263    }
264}
265
/// Policy applied when two of the merged structs carry a field with the same name.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most struct that carries the shared field name.
    RightMost,
    /// Treat a shared field name as an error. This is the default policy.
    #[default]
    Error,
}
275
276impl Display for DuplicateHandling {
277    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
278        match self {
279            DuplicateHandling::RightMost => write!(f, "RightMost"),
280            DuplicateHandling::Error => write!(f, "Error"),
281        }
282    }
283}
284
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType::I32;
    use crate::dtype::PType::I64;
    use crate::dtype::PType::U32;
    use crate::dtype::PType::U64;
    use crate::expr::Expression;
    use crate::expr::get_item;
    use crate::expr::merge;
    use crate::expr::merge_opts;
    use crate::expr::root;
    use crate::scalar_fn::fns::merge::DuplicateHandling;
    use crate::scalar_fn::fns::pack::Pack;

    /// Follows `field_path` through nested struct fields of `array` and returns the field at
    /// the end of the path as a primitive array. An empty path is an error.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(head)?.clone();
        for segment in tail {
            current = current.to_struct().unmasked_field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// One-row input whose children `"0"` and `"1"` both carry a field named `b`.
    fn children_sharing_field_b() -> ArrayRef {
        let left = StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap();
        let right = StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap();

        StructArray::try_from_iter([("0", left), ("1", right)])
            .unwrap()
            .into_array()
    }

    /// With `RightMost`, a shared field keeps its first position but takes the later value.
    #[test]
    pub fn test_merge_right_most() {
        let merge_expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let first = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("b", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let second = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("c", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let third = StructArray::from_fields(&[
            ("d", buffer![4, 4, 4].into_array()),
            ("e", buffer![5, 5, 5].into_array()),
        ])
        .unwrap()
        .into_array();

        let input = StructArray::from_fields(&[("0", first), ("1", second), ("2", third)])
            .unwrap()
            .into_array();
        let merged = input.apply(&merge_expr).unwrap();

        assert_eq!(
            merged.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // `b` appears in both "0" and "1"; the value from "1" must win.
        for (name, value) in [("a", 0i32), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_arrays_eq!(
                primitive_field(&merged, &[name]).unwrap(),
                PrimitiveArray::from_iter([value; 3])
            );
        }
    }

    /// With `Error`, computing the return dtype over duplicated field names must fail.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let merge_expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let input = children_sharing_field_b();

        merge_expr.return_dtype(input.dtype()).unwrap();
    }

    /// With `Error`, evaluating over duplicated field names must fail.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let merge_expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let input = children_sharing_field_b();

        input.apply(&merge_expr).unwrap();
    }

    /// Merging no children yields a field-less struct with the input's length.
    #[test]
    pub fn test_empty_merge() {
        let merge_expr = merge(Vec::<Expression>::new());

        let input = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let merged = input.clone().apply(&merge_expr).unwrap();

        assert_eq!(merged.len(), input.len());
        assert_eq!(merged.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced wholesale, never merged field by field.
    #[test]
    pub fn test_nested_merge() {
        let merge_expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let inner_xy = StructArray::from_fields(&[
            ("x", buffer![0, 0, 0].into_array()),
            ("y", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let inner_x = StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
            .unwrap()
            .into_array();

        let input = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[("a", inner_xy)])
                    .unwrap()
                    .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[("a", inner_x)])
                    .unwrap()
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let merged = input.clone().apply(&merge_expr).unwrap().to_struct();

        let nested_names = merged
            .unmasked_field_by_name("a")
            .unwrap()
            .to_struct()
            .names()
            .iter()
            .map(|name| name.as_ref())
            .collect::<Vec<_>>();
        assert_eq!(nested_names, vec!["x"]);
    }

    /// Output fields follow the order of first appearance across the children.
    #[test]
    pub fn test_merge_order() {
        let merge_expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let left = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("c", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let right = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("d", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();

        let input = StructArray::from_fields(&[("0", left), ("1", right)])
            .unwrap()
            .into_array();
        let merged = input.clone().apply(&merge_expr).unwrap().to_struct();

        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    /// The textual form is `merge(...)` with comma-separated children.
    #[test]
    pub fn test_display() {
        let two_children = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(two_children.to_string(), "merge($.struct1, $.struct2)");

        let one_child = merge(vec![get_item("a", root())]);
        assert_eq!(one_child.to_string(), "merge($.a)");
    }

    /// Optimization rewrites a merge into a `Pack` with the merged dtype.
    #[test]
    fn test_remove_merge() {
        let scope_dtype = DType::struct_(
            [
                ("0", DType::struct_([("a", I32), ("b", I64)], NonNullable)),
                ("1", DType::struct_([("b", U32), ("c", U64)], NonNullable)),
            ],
            NonNullable,
        );

        let merge_expr = merge_opts(
            [get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let optimized = merge_expr.optimize(&scope_dtype).unwrap();

        assert!(optimized.is::<Pack>());
        assert_eq!(
            optimized.return_dtype(&scope_dtype).unwrap(),
            DType::struct_([("a", I32), ("b", U32), ("c", U64)], NonNullable)
        );
    }
}