vortex_expr/exprs/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
// Invokes the crate's `vtable!` macro for `Merge`.
// NOTE(review): presumably this generates the `MergeVTable` type that the `VTable` impl in this
// file is written against — confirm against the macro definition.
vtable!(Merge);
17
/// Merge zero or more expressions that ALL return structs.
///
/// If any field names are duplicated, the field from later expressions wins.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MergeExpr {
    /// Struct-producing child expressions, in merge order: a field produced by a later child
    /// replaces a same-named field produced by an earlier child.
    values: Vec<ExprRef>,
    /// Nullability of the merged struct itself. It selects the validity of the array built by
    /// `evaluate` and the nullability of the dtype reported by `return_dtype`.
    nullability: Nullability,
}
30
/// Field-less encoding marker for [`MergeExpr`]; it is the `Encoding` associated type of the
/// `VTable` implementation for `MergeVTable`.
pub struct MergeExprEncoding;
32
33impl VTable for MergeVTable {
34    type Expr = MergeExpr;
35    type Encoding = MergeExprEncoding;
36    type Metadata = EmptyMetadata;
37
38    fn id(_encoding: &Self::Encoding) -> ExprId {
39        ExprId::new_ref("merge")
40    }
41
42    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
43        ExprEncodingRef::new_ref(MergeExprEncoding.as_ref())
44    }
45
46    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
47        Some(EmptyMetadata)
48    }
49
50    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
51        expr.values.iter().collect()
52    }
53
54    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
55        Ok(MergeExpr {
56            values: children,
57            nullability: expr.nullability,
58        })
59    }
60
61    fn build(
62        _encoding: &Self::Encoding,
63        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
64        children: Vec<ExprRef>,
65    ) -> VortexResult<Self::Expr> {
66        if children.is_empty() {
67            vortex_bail!(
68                "Merge expression must have at least one child, got: {:?}",
69                children
70            );
71        }
72        Ok(MergeExpr {
73            values: children,
74            nullability: Nullability::NonNullable, // Default to non-nullable
75        })
76    }
77
78    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
79        let len = scope.len();
80        let value_arrays = expr
81            .values
82            .iter()
83            .map(|value_expr| value_expr.unchecked_evaluate(scope))
84            .process_results(|it| it.collect::<Vec<_>>())?;
85
86        // Collect fields in order of appearance. Later fields overwrite earlier fields.
87        let mut field_names = Vec::new();
88        let mut arrays = Vec::new();
89
90        for value_array in value_arrays.iter() {
91            // TODO(marko): When nullable, we need to merge struct validity into field validity.
92            if value_array.dtype().is_nullable() {
93                todo!("merge nullable structs");
94            }
95            if !value_array.dtype().is_struct() {
96                vortex_bail!("merge expects non-nullable struct input");
97            }
98
99            let struct_array = value_array.to_struct();
100
101            for (i, field_name) in struct_array.names().iter().enumerate() {
102                let array = struct_array.fields()[i].clone();
103
104                // Update or insert field.
105                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
106                    arrays[idx] = array;
107                } else {
108                    field_names.push(field_name.clone());
109                    arrays.push(array);
110                }
111            }
112        }
113
114        let validity = match expr.nullability {
115            Nullability::NonNullable => Validity::NonNullable,
116            Nullability::Nullable => Validity::AllValid,
117        };
118        Ok(
119            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
120                .into_array(),
121        )
122    }
123
124    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
125        let mut field_names = Vec::new();
126        let mut arrays = Vec::new();
127
128        for value in expr.values.iter() {
129            let dtype = value.return_dtype(scope)?;
130            if !dtype.is_struct() {
131                vortex_bail!("merge expects non-nullable struct input");
132            }
133
134            let struct_dtype = dtype
135                .as_struct_fields_opt()
136                .vortex_expect("merge expects struct input");
137
138            for i in 0..struct_dtype.nfields() {
139                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
140                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
141                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
142                    arrays[idx] = field_dtype;
143                } else {
144                    field_names.push(field_name.clone());
145                    arrays.push(field_dtype);
146                }
147            }
148        }
149
150        Ok(DType::Struct(
151            StructFields::new(FieldNames::from(field_names), arrays),
152            expr.nullability,
153        ))
154    }
155}
156
157impl MergeExpr {
158    pub fn new(values: Vec<ExprRef>, nullability: Nullability) -> Self {
159        MergeExpr {
160            values,
161            nullability,
162        }
163    }
164
165    pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> ExprRef {
166        Self::new(values, nullability).into_expr()
167    }
168
169    pub fn nullability(&self) -> Nullability {
170        self.nullability
171    }
172}
173
174/// Creates an expression that merges struct expressions into a single struct.
175///
176/// Combines fields from all input expressions. If field names are duplicated,
177/// later expressions win. Fields are not recursively merged.
178///
179/// ```rust
180/// # use vortex_dtype::Nullability;
181/// # use vortex_expr::{merge, get_item, root};
182/// let expr = merge([get_item("a", root()), get_item("b", root())], Nullability::NonNullable);
183/// ```
184pub fn merge(
185    elements: impl IntoIterator<Item = impl Into<ExprRef>>,
186    nullability: Nullability,
187) -> ExprRef {
188    let values = elements.into_iter().map(|value| value.into()).collect_vec();
189    MergeExpr::new(values, nullability).into_expr()
190}
191
192impl DisplayAs for MergeExpr {
193    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194        match df {
195            DisplayFormat::Compact => {
196                write!(
197                    f,
198                    "merge({}){}",
199                    self.values.iter().format(", "),
200                    self.nullability
201                )
202            }
203            DisplayFormat::Tree => {
204                write!(f, "Merge")
205            }
206        }
207    }
208}
209
// The impl body is empty, so `MergeExpr` overrides nothing from `AnalysisExpr`.
// NOTE(review): presumably the trait's default methods apply — confirm against its definition.
impl AnalysisExpr for MergeExpr {}
211
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{MergeExpr, Scope, get_item, merge, root};

    /// Builds a three-row `i32` column in which every row holds `value`.
    fn constant_column(value: i32) -> ArrayRef {
        buffer![value, value, value].into_array()
    }

    /// Builds a struct array from `(name, column)` pairs, panicking if construction fails.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the final field as a primitive
    /// array. Fails on an empty path or on an unknown field name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    #[test]
    fn test_merge() {
        let expr = MergeExpr::new(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            Nullability::NonNullable,
        );

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[("a", constant_column(0)), ("b", constant_column(1))]),
            ),
            (
                "1",
                struct_of(&[("b", constant_column(2)), ("c", constant_column(3))]),
            ),
            (
                "2",
                struct_of(&[("d", constant_column(4)), ("e", constant_column(5))]),
            ),
        ]);
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // "b" appears in both "0" and "1"; the later value (2) must win.
        for (name, expected) in [("a", 0), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                [expected; 3],
                "unexpected values for field {name}"
            );
        }
    }

    #[test]
    fn test_empty_merge() {
        let expr = MergeExpr::new(Vec::new(), Nullability::NonNullable);

        let test_array = struct_of(&[("a", buffer![0, 1, 2].into_array())]);
        let expected_len = test_array.len();

        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();
        assert_eq!(actual_array.len(), expected_len);
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = MergeExpr::new(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[(
                    "a",
                    struct_of(&[("x", constant_column(0)), ("y", constant_column(1))]),
                )]),
            ),
            (
                "1",
                struct_of(&[("a", struct_of(&[("x", constant_column(0))]))]),
            ),
        ]);
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap().to_struct();

        // The later "a" replaces the earlier one wholesale, so "y" is gone.
        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    fn test_merge_order() {
        let expr = MergeExpr::new(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[("a", constant_column(0)), ("c", constant_column(1))]),
            ),
            (
                "1",
                struct_of(&[("b", constant_column(2)), ("d", constant_column(3))]),
            ),
        ]);
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap().to_struct();

        // Fields keep their order of first appearance; they are not sorted by name.
        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    fn test_merge_nullable() {
        let expr = MergeExpr::new(vec![get_item("0", root())], Nullability::Nullable);

        let test_array = struct_of(&[(
            "0",
            struct_of(&[("a", constant_column(0)), ("b", constant_column(1))]),
        )]);
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();
        assert!(actual_array.dtype().is_nullable());
    }

    #[test]
    fn test_display() {
        let expr = merge(
            [get_item("struct1", root()), get_item("struct2", root())],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = MergeExpr::new(vec![get_item("a", root())], Nullability::Nullable);
        assert_eq!(expr2.to_string(), "merge($.a)?");
    }
}