vortex_expr/
merge.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, ArrayVariants};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
15/// Merge zero or more expressions that ALL return structs.
16///
17/// If any field names are duplicated, the field from later expressions wins.
18///
19/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
20/// This makes struct fields behaviour consistent with other dtypes.
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23    values: Vec<ExprRef>,
24}
25
26impl Merge {
27    pub fn new_expr(values: Vec<ExprRef>) -> Arc<Self> {
28        Arc::new(Merge { values })
29    }
30}
31
32pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
33    let values = elements.into_iter().map(|value| value.into()).collect_vec();
34    Merge::new_expr(values)
35}
36
37impl Display for Merge {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.write_str("{")?;
40        self.values
41            .iter()
42            .format_with(", ", |expr, f| f(expr))
43            .fmt(f)?;
44        f.write_str("}")
45    }
46}
47
48impl VortexExpr for Merge {
49    fn as_any(&self) -> &dyn Any {
50        self
51    }
52
53    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
54        let len = batch.len();
55        let value_arrays = self
56            .values
57            .iter()
58            .map(|value_expr| value_expr.evaluate(batch))
59            .process_results(|it| it.collect::<Vec<_>>())?;
60
61        // Collect fields in order of appearance. Later fields overwrite earlier fields.
62        let mut field_names = Vec::new();
63        let mut arrays = Vec::new();
64
65        for value_array in value_arrays.iter() {
66            // TODO(marko): When nullable, we need to merge struct validity into field validity.
67            if value_array.dtype().is_nullable() {
68                todo!("merge nullable structs");
69            }
70            if !value_array.dtype().is_struct() {
71                vortex_bail!("merge expects non-nullable struct input");
72            }
73
74            let struct_array = value_array
75                .as_struct_typed()
76                .vortex_expect("merge expects struct input");
77
78            for (i, field_name) in struct_array.names().iter().enumerate() {
79                let array = struct_array
80                    .maybe_null_field_by_idx(i)
81                    .vortex_expect("struct field not found");
82
83                // Update or insert field.
84                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
85                    arrays[idx] = array;
86                } else {
87                    field_names.push(field_name.clone());
88                    arrays.push(array);
89                }
90            }
91        }
92
93        Ok(StructArray::try_new(
94            FieldNames::from(field_names),
95            arrays,
96            len,
97            Validity::NonNullable,
98        )?
99        .into_array())
100    }
101
102    fn children(&self) -> Vec<&ExprRef> {
103        self.values.iter().collect()
104    }
105
106    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
107        Self::new_expr(children)
108    }
109
110    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
111        let mut field_names = Vec::new();
112        let mut arrays = Vec::new();
113
114        for value in self.values.iter() {
115            let dtype = value.return_dtype(scope_dtype)?;
116            if !dtype.is_struct() {
117                vortex_bail!("merge expects non-nullable struct input");
118            }
119
120            let struct_dtype = dtype
121                .as_struct()
122                .vortex_expect("merge expects struct input");
123
124            for i in 0..struct_dtype.nfields() {
125                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
126                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
127                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
128                    arrays[idx] = field_dtype;
129                } else {
130                    field_names.push(field_name.clone());
131                    arrays.push(field_dtype);
132                }
133            }
134        }
135
136        Ok(DType::Struct(
137            Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
138            Nullability::NonNullable,
139        ))
140    }
141}
142
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Follows `field_path` through nested struct arrays, starting at `array`, and returns
    /// the field reached at the end of the path converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array
            .as_struct_typed()
            .ok_or_else(|| vortex_err!("expected a struct"))?
            .maybe_null_field_by_name(first)?;

        for name in rest {
            current = current
                .as_struct_typed()
                .ok_or_else(|| vortex_err!("expected a struct"))?
                .maybe_null_field_by_name(name)?;
        }

        Ok(current.to_primitive().unwrap())
    }

    /// Three structs are merged; the overlapping field "b" must come from the later struct.
    #[test]
    pub fn test_merge() {
        let first = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("b", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let second = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("c", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let third = StructArray::from_fields(&[
            ("d", buffer![4, 4, 4].into_array()),
            ("e", buffer![5, 5, 5].into_array()),
        ])
        .unwrap()
        .into_array();
        let test_array = StructArray::from_fields(&[("0", first), ("1", second), ("2", third)])
            .unwrap()
            .into_array();

        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
            GetItem::new_expr("2", Identity::new_expr()),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        let merged = actual_array.as_struct_typed().unwrap();
        assert_eq!(
            merged.names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        // Expected contents of every merged field.
        let expected_fields = [
            ("a", [0, 0, 0]),
            ("b", [2, 2, 2]),
            ("c", [3, 3, 3]),
            ("d", [4, 4, 4]),
            ("e", [5, 5, 5]),
        ];
        for (name, expected) in expected_fields {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                expected
            );
        }
    }

    /// Merging no expressions yields a struct without fields that keeps the input length.
    #[test]
    pub fn test_empty_merge() {
        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();

        let expr = Merge::new_expr(Vec::new());
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(actual_array.len(), test_array.len());
        let merged = actual_array.as_struct_typed().unwrap();
        assert_eq!(merged.nfields(), 0);
    }

    /// Nested structs are not merged: the later "a" replaces the earlier "a" entirely.
    #[test]
    pub fn test_nested_merge() {
        let wide_inner = StructArray::from_fields(&[
            ("x", buffer![0, 0, 0].into_array()),
            ("y", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let narrow_inner = StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
            .unwrap()
            .into_array();

        let first = StructArray::from_fields(&[("a", wide_inner)])
            .unwrap()
            .into_array();
        let second = StructArray::from_fields(&[("a", narrow_inner)])
            .unwrap()
            .into_array();
        let test_array = StructArray::from_fields(&[("0", first), ("1", second)])
            .unwrap()
            .into_array();

        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        let field_a = actual_array
            .as_struct_typed()
            .unwrap()
            .maybe_null_field_by_name("a")
            .unwrap();
        let nested_names = field_a
            .as_struct_typed()
            .unwrap()
            .names()
            .iter()
            .map(|name| name.as_ref())
            .collect::<Vec<_>>();
        assert_eq!(nested_names, vec!["x"]);
    }

    /// Fields appear in order of first appearance across the merged structs.
    #[test]
    pub fn test_merge_order() {
        let first = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("c", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let second = StructArray::from_fields(&[
            ("b", buffer![2, 2, 2].into_array()),
            ("d", buffer![3, 3, 3].into_array()),
        ])
        .unwrap()
        .into_array();
        let test_array = StructArray::from_fields(&[("0", first), ("1", second)])
            .unwrap()
            .into_array();

        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        let merged = actual_array.as_struct_typed().unwrap();
        assert_eq!(
            merged.names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }
}