vortex_expr/
merge.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, ArrayVariants};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
/// Merge zero or more expressions that ALL return structs.
///
/// Output fields appear in the order in which their names are first encountered
/// across the merged expressions.
///
/// If any field names are duplicated, the field from later expressions wins, while
/// keeping the position of the first occurrence of that name.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Merge {
    /// Child expressions to merge, in merge order (later entries take precedence).
    values: Vec<ExprRef>,
}
25
26impl Merge {
27    pub fn new_expr(values: Vec<ExprRef>) -> Arc<Self> {
28        Arc::new(Merge { values })
29    }
30}
31
32pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
33    let values = elements.into_iter().map(|value| value.into()).collect_vec();
34    Merge::new_expr(values)
35}
36
37impl Display for Merge {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.write_str("{")?;
40        self.values
41            .iter()
42            .format_with(", ", |expr, f| f(expr))
43            .fmt(f)?;
44        f.write_str("}")
45    }
46}
47
/// Protobuf (de)serialization hooks for [`Merge`].
///
/// Neither direction is implemented yet; both report a `NotImplemented` error that
/// names the unsupported operation and the expression id.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Serde handle identifying the `merge` expression.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        fn id(&self) -> &'static str {
            "merge"
        }
    }

    impl ExprDeserialize for MergeSerde {
        /// Always fails: deserializing a merge expression is not implemented.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the operation so the error says what is missing, not just for whom.
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Merge {
        fn id(&self) -> &'static str {
            MergeSerde.id()
        }

        /// Always fails: serializing a merge expression is not implemented.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
79
80impl VortexExpr for Merge {
81    fn as_any(&self) -> &dyn Any {
82        self
83    }
84
85    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
86        let len = batch.len();
87        let value_arrays = self
88            .values
89            .iter()
90            .map(|value_expr| value_expr.evaluate(batch))
91            .process_results(|it| it.collect::<Vec<_>>())?;
92
93        // Collect fields in order of appearance. Later fields overwrite earlier fields.
94        let mut field_names = Vec::new();
95        let mut arrays = Vec::new();
96
97        for value_array in value_arrays.iter() {
98            // TODO(marko): When nullable, we need to merge struct validity into field validity.
99            if value_array.dtype().is_nullable() {
100                todo!("merge nullable structs");
101            }
102            if !value_array.dtype().is_struct() {
103                vortex_bail!("merge expects non-nullable struct input");
104            }
105
106            let struct_array = value_array
107                .as_struct_typed()
108                .vortex_expect("merge expects struct input");
109
110            for (i, field_name) in struct_array.names().iter().enumerate() {
111                let array = struct_array
112                    .maybe_null_field_by_idx(i)
113                    .vortex_expect("struct field not found");
114
115                // Update or insert field.
116                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
117                    arrays[idx] = array;
118                } else {
119                    field_names.push(field_name.clone());
120                    arrays.push(array);
121                }
122            }
123        }
124
125        Ok(StructArray::try_new(
126            FieldNames::from(field_names),
127            arrays,
128            len,
129            Validity::NonNullable,
130        )?
131        .into_array())
132    }
133
134    fn children(&self) -> Vec<&ExprRef> {
135        self.values.iter().collect()
136    }
137
138    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
139        Self::new_expr(children)
140    }
141
142    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
143        let mut field_names = Vec::new();
144        let mut arrays = Vec::new();
145
146        for value in self.values.iter() {
147            let dtype = value.return_dtype(scope_dtype)?;
148            if !dtype.is_struct() {
149                vortex_bail!("merge expects non-nullable struct input");
150            }
151
152            let struct_dtype = dtype
153                .as_struct()
154                .vortex_expect("merge expects struct input");
155
156            for i in 0..struct_dtype.nfields() {
157                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
158                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
159                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
160                    arrays[idx] = field_dtype;
161                } else {
162                    field_names.push(field_name.clone());
163                    arrays.push(field_dtype);
164                }
165            }
166        }
167
168        Ok(DType::Struct(
169            Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
170            Nullability::NonNullable,
171        ))
172    }
173}
174
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Builds a struct array from `(name, array)` pairs, panicking on invalid input.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs starting at `array` and returns the
    /// final field as a primitive array.
    ///
    /// Fails if the path is empty, if any step is not a struct or lacks the named
    /// field, or if the final field cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array
            .as_struct_typed()
            .ok_or_else(|| vortex_err!("expected a struct"))?
            .maybe_null_field_by_name(first)?;

        for field in rest {
            current = current
                .as_struct_typed()
                .ok_or_else(|| vortex_err!("expected a struct"))?
                .maybe_null_field_by_name(field)?;
        }

        // Propagate a conversion failure instead of panicking inside a fallible helper.
        current.to_primitive()
    }

    /// Asserts that the top-level field `name` of `array` holds exactly `expected`.
    fn assert_i32_field(array: &dyn Array, name: &str, expected: &[i32]) {
        assert_eq!(
            primitive_field(array, &[name]).unwrap().as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
            GetItem::new_expr("2", Identity::new_expr()),
        ]);

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ]),
            ),
            (
                "2",
                struct_of(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ]),
            ),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_i32_field(&actual_array, "a", &[0, 0, 0]);
        // "b" appears in both "0" and "1"; the later one wins.
        assert_i32_field(&actual_array, "b", &[2, 2, 2]);
        assert_i32_field(&actual_array, "c", &[3, 3, 3]);
        assert_i32_field(&actual_array, "d", &[4, 4, 4]);
        assert_i32_field(&actual_array, "e", &[5, 5, 5]);
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new());

        let test_array = struct_of(&[("a", buffer![0, 1, 2].into_array())]);
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().unwrap().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[(
                    "a",
                    struct_of(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ]),
                )]),
            ),
            (
                "1",
                struct_of(&[("a", struct_of(&[("x", buffer![0, 0, 0].into_array())]))]),
            ),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        // The later "a" replaces the earlier one wholesale, so only "x" remains.
        assert_eq!(
            actual_array
                .as_struct_typed()
                .unwrap()
                .maybe_null_field_by_name("a")
                .unwrap()
                .as_struct_typed()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(vec![
            GetItem::new_expr("0", Identity::new_expr()),
            GetItem::new_expr("1", Identity::new_expr()),
        ]);

        let test_array = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ]),
            ),
        ]);
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        // Fields keep order of first appearance rather than being sorted by name.
        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }
}