1use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
// Invokes the crate's `vtable!` macro for `Merge`. Presumably this declares the
// `MergeVTable` type that the `impl VTable` below targets — confirm against the
// macro definition.
vtable!(Merge);
17
/// Expression that merges the fields of its struct-valued children into one struct.
///
/// When several children provide a field with the same name, the value from the
/// later child replaces the earlier one, while the field keeps the position of its
/// first occurrence. Replacement is whole-field: nested structs are not merged
/// recursively.
// NOTE(review): `PartialEq` and `Hash` are both derived here, so this `allow` may be
// vestigial — confirm before removing.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MergeExpr {
    /// Child expressions, evaluated and merged in order.
    values: Vec<ExprRef>,
}
29
/// Encoding marker type for [`MergeExpr`]; used as the `VTable::Encoding` of `MergeVTable`.
pub struct MergeExprEncoding;
31
32impl VTable for MergeVTable {
33 type Expr = MergeExpr;
34 type Encoding = MergeExprEncoding;
35 type Metadata = EmptyMetadata;
36
37 fn id(_encoding: &Self::Encoding) -> ExprId {
38 ExprId::new_ref("merge")
39 }
40
41 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
42 ExprEncodingRef::new_ref(MergeExprEncoding.as_ref())
43 }
44
45 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
46 Some(EmptyMetadata)
47 }
48
49 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50 expr.values.iter().collect()
51 }
52
53 fn with_children(_expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54 Ok(MergeExpr { values: children })
55 }
56
57 fn build(
58 _encoding: &Self::Encoding,
59 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
60 children: Vec<ExprRef>,
61 ) -> VortexResult<Self::Expr> {
62 if children.is_empty() {
63 vortex_bail!(
64 "Merge expression must have at least one child, got: {:?}",
65 children
66 );
67 }
68 Ok(MergeExpr { values: children })
69 }
70
71 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
72 let len = scope.len();
73 let value_arrays = expr
74 .values
75 .iter()
76 .map(|value_expr| value_expr.unchecked_evaluate(scope))
77 .process_results(|it| it.collect::<Vec<_>>())?;
78
79 let mut field_names = Vec::new();
81 let mut arrays = Vec::new();
82
83 for value_array in value_arrays.iter() {
84 if value_array.dtype().is_nullable() {
86 todo!("merge nullable structs");
87 }
88 if !value_array.dtype().is_struct() {
89 vortex_bail!("merge expects non-nullable struct input");
90 }
91
92 let struct_array = value_array.to_struct();
93
94 for (field_name, array) in struct_array
95 .names()
96 .iter()
97 .zip_eq(struct_array.fields().iter().cloned())
98 {
99 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
101 arrays[idx] = array;
102 } else {
103 field_names.push(field_name.clone());
104 arrays.push(array);
105 }
106 }
107 }
108
109 let validity = Validity::NonNullable;
111 Ok(
112 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
113 .into_array(),
114 )
115 }
116
117 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
118 let mut field_names = Vec::new();
119 let mut arrays = Vec::new();
120
121 let mut nullability = Nullability::NonNullable;
122
123 for value in expr.values.iter() {
124 let dtype = value.return_dtype(scope)?;
125 if !dtype.is_struct() {
126 vortex_bail!("merge expects struct input");
127 }
128 if dtype.is_nullable() {
129 vortex_bail!("merge expects non-nullable input");
130 }
131 nullability |= dtype.nullability();
132
133 let struct_dtype = dtype
134 .as_struct_fields_opt()
135 .vortex_expect("merge expects struct input");
136
137 for i in 0..struct_dtype.nfields() {
138 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
139 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
140 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
141 arrays[idx] = field_dtype;
142 } else {
143 field_names.push(field_name.clone());
144 arrays.push(field_dtype);
145 }
146 }
147 }
148
149 Ok(DType::Struct(
150 StructFields::new(FieldNames::from(field_names), arrays),
151 nullability,
152 ))
153 }
154}
155
156impl MergeExpr {
157 pub fn new(values: Vec<ExprRef>) -> Self {
158 MergeExpr { values }
159 }
160
161 pub fn new_expr(values: Vec<ExprRef>) -> ExprRef {
162 Self::new(values).into_expr()
163 }
164}
165
166pub fn merge(elements: impl IntoIterator<Item = impl Into<ExprRef>>) -> ExprRef {
177 let values = elements.into_iter().map(|value| value.into()).collect_vec();
178 MergeExpr::new(values).into_expr()
179}
180
181impl DisplayAs for MergeExpr {
182 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
183 match df {
184 DisplayFormat::Compact => {
185 write!(f, "merge({})", self.values.iter().format(", "),)
186 }
187 DisplayFormat::Tree => {
188 write!(f, "Merge")
189 }
190 }
191 }
192}
193
// `MergeExpr` implements `AnalysisExpr` without overriding anything: the empty body
// means only whatever default behavior the trait provides applies.
impl AnalysisExpr for MergeExpr {}
195
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{MergeExpr, Scope, get_item, merge, root};

    /// Builds a struct array from `(name, column)` pairs, panicking on invalid input.
    fn struct_of(fields: &[(&str, ArrayRef)]) -> ArrayRef {
        StructArray::from_fields(fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the primitive column at
    /// its end. Fails on an empty path or an unknown field name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for field in rest {
            current = current.to_struct().field_by_name(field)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Overlapping field `b` takes its values from the later struct while keeping the
    /// position of its first occurrence.
    #[test]
    pub fn test_merge() {
        let expr = MergeExpr::new(vec![
            get_item("0", root()),
            get_item("1", root()),
            get_item("2", root()),
        ]);

        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ]),
            ),
            (
                "2",
                struct_of(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ]),
            ),
        ]);
        let actual_array = expr.evaluate(&Scope::new(input)).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // Every column is a constant of length three; `b` must hold the later value.
        for (name, expected) in [("a", 0), ("b", 2), ("c", 3), ("d", 4), ("e", 5)] {
            assert_eq!(
                primitive_field(&actual_array, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                [expected; 3]
            );
        }
    }

    /// Merging nothing yields a field-less struct with the scope's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = MergeExpr::new(Vec::new());

        let input = struct_of(&[("a", buffer![0, 1, 2].into_array())]);
        let expected_len = input.len();

        let actual_array = expr.evaluate(&Scope::new(input)).unwrap();
        assert_eq!(actual_array.len(), expected_len);
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// A duplicated struct-typed field is replaced wholesale, not merged recursively.
    #[test]
    pub fn test_nested_merge() {
        let expr = MergeExpr::new(vec![get_item("0", root()), get_item("1", root())]);

        let input = struct_of(&[
            (
                "0",
                struct_of(&[(
                    "a",
                    struct_of(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ]),
                )]),
            ),
            (
                "1",
                struct_of(&[("a", struct_of(&[("x", buffer![0, 0, 0].into_array())]))]),
            ),
        ]);
        let merged = expr.evaluate(&Scope::new(input)).unwrap().to_struct();

        let inner = merged.field_by_name("a").unwrap().to_struct();
        let inner_names: Vec<&str> = inner.names().iter().map(|name| name.as_ref()).collect();
        assert_eq!(inner_names, vec!["x"]);
    }

    /// Fields appear in order of first occurrence across the merged structs.
    #[test]
    pub fn test_merge_order() {
        let expr = MergeExpr::new(vec![get_item("0", root()), get_item("1", root())]);

        let input = struct_of(&[
            (
                "0",
                struct_of(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ]),
            ),
            (
                "1",
                struct_of(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ]),
            ),
        ]);
        let merged = expr.evaluate(&Scope::new(input)).unwrap().to_struct();

        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    /// Compact display lists the children inside `merge(...)`.
    #[test]
    pub fn test_display() {
        let two_children = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(two_children.to_string(), "merge($.struct1, $.struct2)");

        let one_child = MergeExpr::new(vec![get_item("a", root())]);
        assert_eq!(one_child.to_string(), "merge($.a)");
    }
}