vortex_array/expr/exprs/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::hash::Hash;
6use std::sync::Arc;
7
8use itertools::Itertools as _;
9use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
10use vortex_error::{VortexResult, vortex_bail};
11use vortex_utils::aliases::hash_set::HashSet;
12
13use crate::arrays::StructArray;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, VTable, VTableExt};
15use crate::validity::Validity;
16use crate::{Array, ArrayRef, IntoArray as _, ToCanonical};
17
/// Merge zero or more expressions that ALL return non-nullable structs into a single struct.
///
/// Fields appear in the order in which their names are first seen across the children.
/// What happens when a field name appears in more than one child is controlled by
/// [`DuplicateHandling`]:
///
/// - [`DuplicateHandling::RightMost`]: the field from the later expression wins, keeping the
///   position where the name first appeared.
/// - [`DuplicateHandling::Error`] (the default): the merge fails with an error listing the
///   duplicated names.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
pub struct Merge;
25
26impl VTable for Merge {
27    type Instance = DuplicateHandling;
28
29    fn id(&self) -> ExprId {
30        ExprId::new_ref("vortex.merge")
31    }
32
33    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
34        Ok(Some(match instance {
35            DuplicateHandling::RightMost => vec![0x00],
36            DuplicateHandling::Error => vec![0x01],
37        }))
38    }
39
40    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
41        let instance = match metadata {
42            [0x00] => DuplicateHandling::RightMost,
43            [0x01] => DuplicateHandling::Error,
44            _ => {
45                vortex_bail!("invalid metadata for Merge expression");
46            }
47        };
48        Ok(Some(instance))
49    }
50
51    fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
52        Ok(())
53    }
54
55    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
56        ChildName::from(Arc::from(format!("{}", child_idx)))
57    }
58
59    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
60        write!(f, "merge(")?;
61        for (i, child) in expr.children().iter().enumerate() {
62            child.fmt_sql(f)?;
63            if i + 1 < expr.children().len() {
64                write!(f, ", ")?;
65            }
66        }
67        write!(f, ")")
68    }
69
70    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
71        let mut field_names = Vec::new();
72        let mut arrays = Vec::new();
73        let mut merge_nullability = Nullability::NonNullable;
74        let mut duplicate_names = HashSet::<_>::new();
75
76        for child in expr.children().iter() {
77            let dtype = child.return_dtype(scope)?;
78            let Some(fields) = dtype.as_struct_fields_opt() else {
79                vortex_bail!("merge expects struct input");
80            };
81            if dtype.is_nullable() {
82                vortex_bail!("merge expects non-nullable input");
83            }
84
85            merge_nullability |= dtype.nullability();
86
87            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
88                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
89                    duplicate_names.insert(field_name.clone());
90                    arrays[idx] = field_dtype;
91                } else {
92                    field_names.push(field_name.clone());
93                    arrays.push(field_dtype);
94                }
95            }
96        }
97
98        if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
99            vortex_bail!(
100                "merge: duplicate fields in children: {}",
101                duplicate_names.into_iter().format(", ")
102            )
103        }
104
105        Ok(DType::Struct(
106            StructFields::new(FieldNames::from(field_names), arrays),
107            merge_nullability,
108        ))
109    }
110
111    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
112        // Collect fields in order of appearance. Later fields overwrite earlier fields.
113        let mut field_names = Vec::new();
114        let mut arrays = Vec::new();
115        let mut duplicate_names = HashSet::<_>::new();
116
117        for child in expr.children().iter() {
118            // TODO(marko): When nullable, we need to merge struct validity into field validity.
119            let array = child.evaluate(scope)?;
120            if array.dtype().is_nullable() {
121                vortex_bail!("merge expects non-nullable input");
122            }
123            if !array.dtype().is_struct() {
124                vortex_bail!("merge expects struct input");
125            }
126            let array = array.to_struct();
127
128            for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
129                // Update or insert field.
130                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
131                    duplicate_names.insert(field_name.clone());
132                    arrays[idx] = array;
133                } else {
134                    field_names.push(field_name.clone());
135                    arrays.push(array);
136                }
137            }
138        }
139
140        if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
141            vortex_bail!(
142                "merge: duplicate fields in children: {}",
143                duplicate_names.into_iter().format(", ")
144            )
145        }
146
147        // TODO(DK): When children are allowed to be nullable, this needs to change.
148        let validity = Validity::NonNullable;
149        let len = scope.len();
150        Ok(
151            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
152                .into_array(),
153        )
154    }
155}
156
/// What to do when merged structs share a field name.
///
/// Used as the instance data of the `Merge` expression, which serializes it as a single tag
/// byte (`0x00` for `RightMost`, `0x01` for `Error`). The default is [`DuplicateHandling::Error`].
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// If two structs share a field name, take the value from the right-most struct.
    ///
    /// The field keeps the position where its name first appeared; only its value is replaced.
    RightMost,
    /// If two structs share a field name, error.
    ///
    /// The error message lists the duplicated field names.
    #[default]
    Error,
}
166
167/// Creates an expression that merges struct expressions into a single struct.
168///
169/// Combines fields from all input expressions. If field names are duplicated,
170/// later expressions win. Fields are not recursively merged.
171///
172/// ```rust
173/// # use vortex_dtype::Nullability;
174/// # use vortex_array::expr::{merge, get_item, root};
175/// let expr = merge([get_item("a", root()), get_item("b", root())]);
176/// ```
177pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
178    let values = elements.into_iter().map(|value| value.into()).collect_vec();
179    Merge.new_expr(DuplicateHandling::default(), values)
180}
181
182pub fn merge_opts(
183    elements: impl IntoIterator<Item = impl Into<Expression>>,
184    duplicate_handling: DuplicateHandling,
185) -> Expression {
186    let values = elements.into_iter().map(|value| value.into()).collect_vec();
187    Merge.new_expr(duplicate_handling, values)
188}
189
// Unit tests for the `merge` / `merge_opts` expressions: field ordering, duplicate handling,
// empty merges, non-recursive nested merges and SQL-style display.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail};

    use super::merge;
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::{DuplicateHandling, merge_opts};
    use crate::expr::exprs::root::root;
    use crate::{Array, IntoArray, ToCanonical};

    /// Follows `field_path` through (possibly nested) struct fields of `array` and returns the
    /// leaf as a primitive array.
    ///
    /// Errors on an empty path, and propagates any error from `field_by_name`.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        // Descend one struct level per remaining path component.
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    #[test]
    pub fn test_merge_right_most() {
        // `b` appears in children "0" and "1". Under RightMost the value from "1" wins, while
        // `b` keeps the position where it first appeared.
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        // Value of `b` comes from the right-most child ("1"), not from "0".
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        // Under DuplicateHandling::Error, the shared field `b` must be rejected at type-check
        // time (`return_dtype`).
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        // Same input as above, but the duplicate must also be rejected by `evaluate`.
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.evaluate(&test_array).unwrap();
    }

    #[test]
    pub fn test_empty_merge() {
        // Merging zero children yields a struct with no fields whose length matches the scope.
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged!
        // The later `a` (fields: x) replaces the earlier `a` (fields: x, y) wholesale.

        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap().to_struct();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        // With no duplicates (so the default Error policy succeeds), fields follow child order
        // and then each child's own field order.
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    pub fn test_display() {
        // SQL-style formatting lists children comma-separated inside `merge(...)`.
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }
}