vortex_expr/
merge.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
15/// Merge zero or more expressions that ALL return structs.
16///
17/// If any field names are duplicated, the field from later expressions wins.
18///
19/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
20/// This makes struct fields behaviour consistent with other dtypes.
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23    values: Vec<ExprRef>,
24    nullability: Nullability,
25}
26
27impl Merge {
28    pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> Arc<Self> {
29        Arc::new(Merge {
30            values,
31            nullability,
32        })
33    }
34
35    pub fn nullability(&self) -> Nullability {
36        self.nullability
37    }
38}
39
40pub fn merge(
41    elements: impl IntoIterator<Item = impl Into<ExprRef>>,
42    nullability: Nullability,
43) -> ExprRef {
44    let values = elements.into_iter().map(|value| value.into()).collect_vec();
45    Merge::new_expr(values, nullability)
46}
47
48impl Display for Merge {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.write_str("{")?;
51        self.values
52            .iter()
53            .format_with(", ", |expr, f| f(expr))
54            .fmt(f)?;
55        f.write_str("}")
56    }
57}
58
/// Protobuf (de)serialization hooks for [`Merge`].
///
/// Neither direction is implemented yet: both entry points return a `NotImplemented`
/// error that names the missing operation and the expression id.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Serde handle for the `merge` expression.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        /// Identifier under which the merge expression is registered.
        fn id(&self) -> &'static str {
            "merge"
        }
    }

    impl ExprDeserialize for MergeSerde {
        /// Always fails: deserializing a merge expression is not implemented.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the operation so the error says *what* is unimplemented, not just for whom.
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Merge {
        /// Same identifier as [`MergeSerde`], so both directions agree on the id.
        fn id(&self) -> &'static str {
            MergeSerde.id()
        }

        /// Always fails: serializing a merge expression is not implemented.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            // Name the operation so the error says *what* is unimplemented, not just for whom.
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
90
91impl AnalysisExpr for Merge {}
92
93impl VortexExpr for Merge {
94    fn as_any(&self) -> &dyn Any {
95        self
96    }
97
98    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
99        let len = ctx.len();
100        let value_arrays = self
101            .values
102            .iter()
103            .map(|value_expr| value_expr.unchecked_evaluate(ctx))
104            .process_results(|it| it.collect::<Vec<_>>())?;
105
106        // Collect fields in order of appearance. Later fields overwrite earlier fields.
107        let mut field_names = Vec::new();
108        let mut arrays = Vec::new();
109
110        for value_array in value_arrays.iter() {
111            // TODO(marko): When nullable, we need to merge struct validity into field validity.
112            if value_array.dtype().is_nullable() {
113                todo!("merge nullable structs");
114            }
115            if !value_array.dtype().is_struct() {
116                vortex_bail!("merge expects non-nullable struct input");
117            }
118
119            let struct_array = value_array.to_struct()?;
120
121            for (i, field_name) in struct_array.names().iter().enumerate() {
122                let array = struct_array.fields()[i].clone();
123
124                // Update or insert field.
125                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
126                    arrays[idx] = array;
127                } else {
128                    field_names.push(field_name.clone());
129                    arrays.push(array);
130                }
131            }
132        }
133
134        let validity = match self.nullability {
135            Nullability::NonNullable => Validity::NonNullable,
136            Nullability::Nullable => Validity::AllValid,
137        };
138        Ok(
139            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
140                .into_array(),
141        )
142    }
143
144    fn children(&self) -> Vec<&ExprRef> {
145        self.values.iter().collect()
146    }
147
148    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
149        Self::new_expr(children, self.nullability)
150    }
151
152    fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
153        let mut field_names = Vec::new();
154        let mut arrays = Vec::new();
155
156        for value in self.values.iter() {
157            let dtype = value.return_dtype(ctx)?;
158            if !dtype.is_struct() {
159                vortex_bail!("merge expects non-nullable struct input");
160            }
161
162            let struct_dtype = dtype
163                .as_struct()
164                .vortex_expect("merge expects struct input");
165
166            for i in 0..struct_dtype.nfields() {
167                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
168                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
169                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
170                    arrays[idx] = field_dtype;
171                } else {
172                    field_names.push(field_name.clone());
173                    arrays.push(field_dtype);
174                }
175            }
176        }
177
178        Ok(DType::Struct(
179            Arc::new(StructFields::new(FieldNames::from(field_names), arrays)),
180            self.nullability,
181        ))
182    }
183}
184
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Merge, Scope, VortexExpr, get_item, root};

    /// Follows `field_path` through nested structs, starting at `array`, and returns the
    /// field at the end of the path as a primitive array.
    ///
    /// Fails if the path is empty, if any step cannot be resolved as a named struct field,
    /// or if the final field cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate a failed conversion to the caller instead of panicking in the helper.
        array.to_primitive()
    }

    /// Three structs are merged; the duplicated field `b` takes its values from the later one.
    #[test]
    fn test_merge() {
        let expr = Merge::new_expr(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        let expected_fields: [(&str, [i32; 3]); 5] = [
            ("a", [0, 0, 0]),
            ("b", [2, 2, 2]),
            ("c", [3, 3, 3]),
            ("d", [4, 4, 4]),
            ("e", [5, 5, 5]),
        ];
        for (name, expected) in expected_fields {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                expected,
                "unexpected values for field {name}"
            );
        }
    }

    /// A merge without children yields a struct with no fields but the scope's length.
    #[test]
    fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new(), Nullability::NonNullable);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = Merge::new_expr(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(&Scope::new(test_array))
            .unwrap()
            .to_struct()
            .unwrap();

        // The later `a` replaces the earlier one wholesale, so `y` is gone.
        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Fields appear in order of first appearance, not sorted by name.
    #[test]
    fn test_merge_order() {
        let expr = Merge::new_expr(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(&Scope::new(test_array))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array.names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    /// Requesting a nullable result makes the output dtype nullable.
    #[test]
    fn test_merge_nullable() {
        let expr = Merge::new_expr(vec![get_item("0", root())], Nullability::Nullable);

        let test_array = StructArray::from_fields(&[(
            "0",
            StructArray::from_fields(&[
                ("a", buffer![0, 0, 0].into_array()),
                ("b", buffer![1, 1, 1].into_array()),
            ])
            .unwrap()
            .into_array(),
        )])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();
        assert!(actual_array.dtype().is_nullable());
    }
}