vortex_array/expr/exprs/merge/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4pub mod transform;
5
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use itertools::Itertools as _;
11use vortex_dtype::DType;
12use vortex_dtype::FieldNames;
13use vortex_dtype::Nullability;
14use vortex_dtype::StructFields;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::IntoArray as _;
22use crate::ToCanonical;
23use crate::arrays::StructArray;
24use crate::expr::ChildName;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::ExpressionView;
28use crate::expr::VTable;
29use crate::expr::VTableExt;
30use crate::validity::Validity;
31
/// Merge zero or more expressions that ALL return non-nullable structs into a single struct.
///
/// Fields appear in the order in which their names are first seen across the children.
/// What happens when a field name occurs in more than one child is controlled by the
/// expression's [`DuplicateHandling`] instance:
/// - [`DuplicateHandling::Error`] (the default) rejects the merge.
/// - [`DuplicateHandling::RightMost`] keeps the field from the later expression, at the position
///   where the name first appeared.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
pub struct Merge;
39
40impl VTable for Merge {
41    type Instance = DuplicateHandling;
42
43    fn id(&self) -> ExprId {
44        ExprId::new_ref("vortex.merge")
45    }
46
47    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
48        Ok(Some(match instance {
49            DuplicateHandling::RightMost => vec![0x00],
50            DuplicateHandling::Error => vec![0x01],
51        }))
52    }
53
54    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
55        let instance = match metadata {
56            [0x00] => DuplicateHandling::RightMost,
57            [0x01] => DuplicateHandling::Error,
58            _ => {
59                vortex_bail!("invalid metadata for Merge expression");
60            }
61        };
62        Ok(Some(instance))
63    }
64
65    fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
66        Ok(())
67    }
68
69    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
70        ChildName::from(Arc::from(format!("{}", child_idx)))
71    }
72
73    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
74        write!(f, "merge(")?;
75        for (i, child) in expr.children().iter().enumerate() {
76            child.fmt_sql(f)?;
77            if i + 1 < expr.children().len() {
78                write!(f, ", ")?;
79            }
80        }
81        write!(f, ")")
82    }
83
84    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
85        let mut field_names = Vec::new();
86        let mut arrays = Vec::new();
87        let mut merge_nullability = Nullability::NonNullable;
88        let mut duplicate_names = HashSet::<_>::new();
89
90        for child in expr.children().iter() {
91            let dtype = child.return_dtype(scope)?;
92            let Some(fields) = dtype.as_struct_fields_opt() else {
93                vortex_bail!("merge expects struct input");
94            };
95            if dtype.is_nullable() {
96                vortex_bail!("merge expects non-nullable input");
97            }
98
99            merge_nullability |= dtype.nullability();
100
101            for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
102                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
103                    duplicate_names.insert(field_name.clone());
104                    arrays[idx] = field_dtype;
105                } else {
106                    field_names.push(field_name.clone());
107                    arrays.push(field_dtype);
108                }
109            }
110        }
111
112        if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
113            vortex_bail!(
114                "merge: duplicate fields in children: {}",
115                duplicate_names.into_iter().format(", ")
116            )
117        }
118
119        Ok(DType::Struct(
120            StructFields::new(FieldNames::from(field_names), arrays),
121            merge_nullability,
122        ))
123    }
124
125    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
126        // Collect fields in order of appearance. Later fields overwrite earlier fields.
127        let mut field_names = Vec::new();
128        let mut arrays = Vec::new();
129        let mut duplicate_names = HashSet::<_>::new();
130
131        for child in expr.children().iter() {
132            // TODO(marko): When nullable, we need to merge struct validity into field validity.
133            let array = child.evaluate(scope)?;
134            if array.dtype().is_nullable() {
135                vortex_bail!("merge expects non-nullable input");
136            }
137            if !array.dtype().is_struct() {
138                vortex_bail!("merge expects struct input");
139            }
140            let array = array.to_struct();
141
142            for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
143                // Update or insert field.
144                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
145                    duplicate_names.insert(field_name.clone());
146                    arrays[idx] = array;
147                } else {
148                    field_names.push(field_name.clone());
149                    arrays.push(array);
150                }
151            }
152        }
153
154        if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
155            vortex_bail!(
156                "merge: duplicate fields in children: {}",
157                duplicate_names.into_iter().format(", ")
158            )
159        }
160
161        // TODO(DK): When children are allowed to be nullable, this needs to change.
162        let validity = Validity::NonNullable;
163        let len = scope.len();
164        Ok(
165            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
166                .into_array(),
167        )
168    }
169
170    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
171        true
172    }
173
174    fn is_fallible(&self, instance: &Self::Instance) -> bool {
175        matches!(instance, DuplicateHandling::Error)
176    }
177}
178
/// Policy used by [`Merge`] when more than one of the merged structs has a field with the same name.
///
/// Defaults to [`DuplicateHandling::Error`].
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Default)]
pub enum DuplicateHandling {
    /// Keep the value from the right-most (last) struct that defines the shared field name.
    RightMost,
    /// Reject the merge with an error when any field name is shared.
    #[default]
    Error,
}
188
189/// Creates an expression that merges struct expressions into a single struct.
190///
191/// Combines fields from all input expressions. If field names are duplicated,
192/// later expressions win. Fields are not recursively merged.
193///
194/// ```rust
195/// # use vortex_dtype::Nullability;
196/// # use vortex_array::expr::{merge, get_item, root};
197/// let expr = merge([get_item("a", root()), get_item("b", root())]);
198/// ```
199pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
200    let values = elements.into_iter().map(|value| value.into()).collect_vec();
201    Merge.new_expr(DuplicateHandling::default(), values)
202}
203
204pub fn merge_opts(
205    elements: impl IntoIterator<Item = impl Into<Expression>>,
206    duplicate_handling: DuplicateHandling,
207) -> Expression {
208    let values = elements.into_iter().map(|value| value.into()).collect_vec();
209    Merge.new_expr(duplicate_handling, values)
210}
211
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Merge;
    use super::merge;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::VTable;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Follows `field_path` through nested struct arrays and returns the leaf as a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// Two struct columns `"0"` = {a, b} and `"1"` = {c, b} that share the field name `b`.
    fn duplicate_b_fixture() -> ArrayRef {
        StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array()
    }

    #[test]
    fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // The statically computed dtype must describe the evaluated array exactly.
        assert_eq!(
            &expr.return_dtype(test_array.dtype()).unwrap(),
            actual_array.dtype()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    #[test]
    fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = duplicate_b_fixture();

        let Err(err) = expr.return_dtype(test_array.dtype()) else {
            panic!("return_dtype must reject duplicated field names");
        };
        assert!(
            err.to_string()
                .contains("merge: duplicate fields in children"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = duplicate_b_fixture();

        let Err(err) = expr.evaluate(&test_array) else {
            panic!("evaluate must reject duplicated field names");
        };
        assert!(
            err.to_string()
                .contains("merge: duplicate fields in children"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_merge_rejects_non_struct_input() {
        let expr = merge([root()]);
        let test_array = buffer![1i32, 2, 3].into_array();

        let Err(err) = expr.return_dtype(test_array.dtype()) else {
            panic!("return_dtype must reject non-struct input");
        };
        assert!(
            err.to_string().contains("merge expects struct input"),
            "unexpected error: {err}"
        );

        let Err(err) = expr.evaluate(&test_array) else {
            panic!("evaluate must reject non-struct input");
        };
        assert!(
            err.to_string().contains("merge expects struct input"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn test_serde_round_trip() {
        for handling in [DuplicateHandling::RightMost, DuplicateHandling::Error] {
            let bytes = Merge.serialize(&handling).unwrap().unwrap();
            assert_eq!(Merge.deserialize(&bytes).unwrap(), Some(handling));
        }
        assert!(Merge.deserialize(&[0x02]).is_err());
        assert!(Merge.deserialize(&[]).is_err());
    }

    #[test]
    fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }
}