vortex_expr/
merge.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
/// Merge zero or more expressions that ALL return structs.
///
/// If any field names are duplicated, the field from later expressions wins.
///
/// NOTE: Fields are not recursively merged, i.e. the later field REPLACES the earlier field.
/// This makes struct fields behaviour consistent with other dtypes.
///
/// Output fields keep the order in which their names first appear across the inputs: a later
/// duplicate replaces the earlier value in place rather than moving the field to the end.
/// Merging zero expressions yields a struct with no fields whose length equals the input batch.
///
/// Nullable struct inputs are not yet supported when evaluating.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Merge {
    /// Child expressions, each expected to evaluate to a struct; merged left to right.
    values: Vec<ExprRef>,
    /// Nullability of the resulting struct dtype. When `Nullable`, evaluation produces an
    /// all-valid struct.
    nullability: Nullability,
}
26
27impl Merge {
28    pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> Arc<Self> {
29        Arc::new(Merge {
30            values,
31            nullability,
32        })
33    }
34
35    pub fn nullability(&self) -> Nullability {
36        self.nullability
37    }
38}
39
40pub fn merge(
41    elements: impl IntoIterator<Item = impl Into<ExprRef>>,
42    nullability: Nullability,
43) -> ExprRef {
44    let values = elements.into_iter().map(|value| value.into()).collect_vec();
45    Merge::new_expr(values, nullability)
46}
47
48impl Display for Merge {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.write_str("{")?;
51        self.values
52            .iter()
53            .format_with(", ", |expr, f| f(expr))
54            .fmt(f)?;
55        f.write_str("}")
56    }
57}
58
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Identifier under which [`Merge`] expressions are (de)serialized.
    const MERGE_ID: &str = "merge";

    /// Deserializer registered for [`Merge`] expressions.
    ///
    /// Protobuf support for `Merge` is not implemented: both serialization and deserialization
    /// return a `NotImplemented` error.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        fn id(&self) -> &'static str {
            MERGE_ID
        }
    }

    impl ExprDeserialize for MergeSerde {
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "", self.id())
        }
    }

    impl ExprSerializable for Merge {
        fn id(&self) -> &'static str {
            MERGE_ID
        }

        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "", self.id())
        }
    }
}
90
91impl VortexExpr for Merge {
92    fn as_any(&self) -> &dyn Any {
93        self
94    }
95
96    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
97        let len = batch.len();
98        let value_arrays = self
99            .values
100            .iter()
101            .map(|value_expr| value_expr.evaluate(batch))
102            .process_results(|it| it.collect::<Vec<_>>())?;
103
104        // Collect fields in order of appearance. Later fields overwrite earlier fields.
105        let mut field_names = Vec::new();
106        let mut arrays = Vec::new();
107
108        for value_array in value_arrays.iter() {
109            // TODO(marko): When nullable, we need to merge struct validity into field validity.
110            if value_array.dtype().is_nullable() {
111                todo!("merge nullable structs");
112            }
113            if !value_array.dtype().is_struct() {
114                vortex_bail!("merge expects non-nullable struct input");
115            }
116
117            let struct_array = value_array.to_struct()?;
118
119            for (i, field_name) in struct_array.names().iter().enumerate() {
120                let array = struct_array.fields()[i].clone();
121
122                // Update or insert field.
123                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
124                    arrays[idx] = array;
125                } else {
126                    field_names.push(field_name.clone());
127                    arrays.push(array);
128                }
129            }
130        }
131
132        let validity = match self.nullability {
133            Nullability::NonNullable => Validity::NonNullable,
134            Nullability::Nullable => Validity::AllValid,
135        };
136        Ok(
137            StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
138                .into_array(),
139        )
140    }
141
142    fn children(&self) -> Vec<&ExprRef> {
143        self.values.iter().collect()
144    }
145
146    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
147        Self::new_expr(children, self.nullability)
148    }
149
150    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
151        let mut field_names = Vec::new();
152        let mut arrays = Vec::new();
153
154        for value in self.values.iter() {
155            let dtype = value.return_dtype(scope_dtype)?;
156            if !dtype.is_struct() {
157                vortex_bail!("merge expects non-nullable struct input");
158            }
159
160            let struct_dtype = dtype
161                .as_struct()
162                .vortex_expect("merge expects struct input");
163
164            for i in 0..struct_dtype.nfields() {
165                let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
166                let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
167                if let Some(idx) = field_names.iter().position(|name| name == field_name) {
168                    arrays[idx] = field_dtype;
169                } else {
170                    field_names.push(field_name.clone());
171                    arrays.push(field_dtype);
172                }
173            }
174        }
175
176        Ok(DType::Struct(
177            Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
178            self.nullability,
179        ))
180    }
181}
182
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Follows `field_path` through nested struct fields of `array` and returns the leaf field
    /// canonicalized to a primitive array.
    ///
    /// Errors on an empty path, and propagates errors from struct canonicalization, field
    /// lookup, and primitive canonicalization of the leaf.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        array.to_primitive()
    }

    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
                GetItem::new_expr("2", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new(), Nullability::NonNullable);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged!

        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(test_array.as_ref())
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(test_array.as_ref())
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array.names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    #[test]
    pub fn test_merge_nullable() {
        let expr = Merge::new_expr(
            vec![GetItem::new_expr("0", Identity::new_expr())],
            Nullability::Nullable,
        );

        let test_array = StructArray::from_fields(&[(
            "0",
            StructArray::from_fields(&[
                ("a", buffer![0, 0, 0].into_array()),
                ("b", buffer![1, 1, 1].into_array()),
            ])
            .unwrap()
            .into_array(),
        )])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();
        assert!(actual_array.dtype().is_nullable());
    }

    /// `return_dtype` must describe exactly what `evaluate` produces, including a later field
    /// replacing an earlier one that has a different dtype.
    #[test]
    pub fn test_merge_return_dtype_matches_evaluate() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[("b", buffer![2i64, 2, 2].into_array())])
                    .unwrap()
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expected_dtype = expr.return_dtype(test_array.dtype()).unwrap();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();
        assert_eq!(actual_array.dtype(), &expected_dtype);
    }
}