vortex_expr/exprs/
merge.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
// Invokes the crate's `vtable!` macro for the `Merge` expression. `MergeVTable` is used below
// but not declared in this file, so it is presumably generated here.
// NOTE(review): confirm the generated items against the macro definition.
vtable!(Merge);
17
18/// Merge zero or more expressions that ALL return structs.
19///
20/// If any field names are duplicated, the field from later expressions wins.
21///
22/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
23/// This makes struct fields behaviour consistent with other dtypes.
24#[allow(clippy::derived_hash_with_manual_eq)]
25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
26pub struct MergeExpr {
27    values: Vec<ExprRef>,
28}
29
/// Unit struct used as the `Encoding` associated type of the [`MergeExpr`] vtable.
pub struct MergeExprEncoding;
31
32impl VTable for MergeVTable {
33    type Expr = MergeExpr;
34    type Encoding = MergeExprEncoding;
35    type Metadata = EmptyMetadata;
36
37    fn id(_encoding: &Self::Encoding) -> ExprId {
38        ExprId::new_ref("merge")
39    }
40
41    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
42        ExprEncodingRef::new_ref(MergeExprEncoding.as_ref())
43    }
44
45    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
46        Some(EmptyMetadata)
47    }
48
49    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50        expr.values.iter().collect()
51    }
52
53    fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54        Ok(MergeExpr { values: children })
55    }
56
57    fn build(
58        _encoding: &Self::Encoding,
59        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
60        children: Vec<ExprRef>,
61    ) -> VortexResult<Self::Expr> {
62        if children.is_empty() {
63            vortex_bail!(
64                "Merge expression must have at least one child, got: {:?}",
65                children
66            );
67        }
68        Ok(MergeExpr { values: children })
69    }
70
71    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
72        let len = scope.len();
73        let value_arrays = expr
74            .values
75            .iter()
76            .map(|value_expr| value_expr.unchecked_evaluate(scope))
77            .process_results(|it| it.collect::<Vec<_>>())?;
78
79        // Collect fields in order of appearance. Later fields overwrite earlier fields.
80        let mut field_names = Vec::new();
81        let mut arrays = Vec::new();
82
83        for value_array in value_arrays.iter() {
84            // TODO(marko): When nullable, we need to merge struct validity into field validity.
85            if value_array.dtype().is_nullable() {
86                todo!("merge nullable structs");
87            }
88            if !value_array.dtype().is_struct() {
89                vortex_bail!("merge expects non-nullable struct input");
90            }
91
92            let struct_array = value_array.to_struct();
93
94            for (field_name, array) in struct_array
95                .names()
96                .iter()
97                .zip_eq(struct_array.fields().iter().cloned())
98            {
99                // Update or insert field.
100                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
101                    arrays[idx] = array;
102                } else {
103                    field_names.push(field_name.clone());
104                    arrays.push(array);
105                }
106            }
107        }
108
109        // TODO(DK): When children are allowed to be nullable, this needs to change.
110        let validity = Validity::NonNullable;
111        Ok(
112            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
113                .into_array(),
114        )
115    }
116
117    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
118        let mut field_names = Vec::new();
119        let mut arrays = Vec::new();
120
121        let mut nullability = Nullability::NonNullable;
122
123        for value in expr.values.iter() {
124            let dtype = value.return_dtype(scope)?;
125            if !dtype.is_struct() {
126                vortex_bail!("merge expects struct input");
127            }
128            if dtype.is_nullable() {
129                vortex_bail!("merge expects non-nullable input");
130            }
131            nullability |= dtype.nullability();
132
133            let struct_dtype = dtype
134                .as_struct_fields_opt()
135                .vortex_expect("merge expects struct input");
136
137            for i in 0..struct_dtype.nfields() {
138                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
139                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
140                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
141                    arrays[idx] = field_dtype;
142                } else {
143                    field_names.push(field_name.clone());
144                    arrays.push(field_dtype);
145                }
146            }
147        }
148
149        Ok(DType::Struct(
150            StructFields::new(FieldNames::from(field_names), arrays),
151            nullability,
152        ))
153    }
154}
155
156impl MergeExpr {
157    pub fn new(values: Vec<ExprRef>) -> Self {
158        MergeExpr { values }
159    }
160
161    pub fn new_expr(values: Vec<ExprRef>) -> ExprRef {
162        Self::new(values).into_expr()
163    }
164}
165
166/// Creates an expression that merges struct expressions into a single struct.
167///
168/// Combines fields from all input expressions. If field names are duplicated,
169/// later expressions win. Fields are not recursively merged.
170///
171/// ```rust
172/// # use vortex_dtype::Nullability;
173/// # use vortex_expr::{merge, get_item, root};
174/// let expr = merge([get_item("a", root()), get_item("b", root())]);
175/// ```
176pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
177    let values = elements.into_iter().map(|value| value.into()).collect_vec();
178    MergeExpr::new(values).into_expr()
179}
180
181impl DisplayAs for MergeExpr {
182    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
183        match df {
184            DisplayFormat::Compact => {
185                write!(f, "merge({})", self.values.iter().format(", "),)
186            }
187            DisplayFormat::Tree => {
188                write!(f, "Merge")
189            }
190        }
191    }
192}
193
// Empty impl: `MergeExpr` overrides nothing and relies on whatever defaults `AnalysisExpr`
// provides.
impl AnalysisExpr for MergeExpr {}
195
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{MergeExpr, Scope, get_item, merge, root};

    /// Packs the named columns into a struct array, panicking if construction fails.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Follows `field_path` through nested structs starting at `array` and returns the final
    /// field as a primitive array. An empty path is an error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Three structs are merged; the duplicated field `b` takes its value from the later one.
    #[test]
    fn test_merge() {
        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ]),
            ),
            (
                "2",
                struct_of(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ]),
            ),
        ]);

        let expr = MergeExpr::new(vec![
            get_item("0", root()),
            get_item("1", root()),
            get_item("2", root()),
        ]);
        let merged = expr.evaluate(&Scope::new(input)).unwrap();

        assert_eq!(
            merged.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        let expected_columns: [(&str, [i32; 3]); 5] = [
            ("a", [0, 0, 0]),
            ("b", [2, 2, 2]),
            ("c", [3, 3, 3]),
            ("d", [4, 4, 4]),
            ("e", [5, 5, 5]),
        ];
        for (name, expected) in expected_columns {
            assert_eq!(
                primitive_field(&merged, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                expected
            );
        }
    }

    /// Merging nothing yields a struct with no fields but the scope's length.
    #[test]
    fn test_empty_merge() {
        let input = struct_of(&[("a", buffer![0, 1, 2].into_array())]);

        let expr = MergeExpr::new(Vec::new());
        let merged = expr.evaluate(&Scope::new(input.clone())).unwrap();

        assert_eq!(merged.len(), input.len());
        assert_eq!(merged.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced wholesale rather than merged field by field.
    #[test]
    fn test_nested_merge() {
        let input = struct_of(&[
            (
                "0",
                struct_of(&[(
                    "a",
                    struct_of(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ]),
                )]),
            ),
            (
                "1",
                struct_of(&[("a", struct_of(&[("x", buffer![0, 0, 0].into_array())]))]),
            ),
        ]);

        let expr = MergeExpr::new(vec![get_item("0", root()), get_item("1", root())]);
        let merged = expr
            .evaluate(&Scope::new(input.clone()))
            .unwrap()
            .to_struct();

        let inner = merged.field_by_name("a").unwrap().to_struct();
        assert_eq!(
            inner
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow the order in which they first appear across the inputs.
    #[test]
    fn test_merge_order() {
        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ]),
            ),
        ]);

        let expr = MergeExpr::new(vec![get_item("0", root()), get_item("1", root())]);
        let merged = expr
            .evaluate(&Scope::new(input.clone()))
            .unwrap()
            .to_struct();

        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    /// The compact display lists the children inside `merge(...)`.
    #[test]
    fn test_display() {
        let two_children = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(two_children.to_string(), "merge($.struct1, $.struct2)");

        let one_child = MergeExpr::new(vec![get_item("a", root())]);
        assert_eq!(one_child.to_string(), "merge($.a)");
    }
}