1use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
16vtable!(Merge);
17
18#[allow(clippy::derived_hash_with_manual_eq)]
25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
26pub struct MergeExpr {
27 values: Vec<ExprRef>,
28 nullability: Nullability,
29}
30
31pub struct MergeExprEncoding;
32
33impl VTable for MergeVTable {
34 type Expr = MergeExpr;
35 type Encoding = MergeExprEncoding;
36 type Metadata = EmptyMetadata;
37
38 fn id(_encoding: &Self::Encoding) -> ExprId {
39 ExprId::new_ref("merge")
40 }
41
42 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
43 ExprEncodingRef::new_ref(MergeExprEncoding.as_ref())
44 }
45
46 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
47 Some(EmptyMetadata)
48 }
49
50 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
51 expr.values.iter().collect()
52 }
53
54 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
55 Ok(MergeExpr {
56 values: children,
57 nullability: expr.nullability,
58 })
59 }
60
61 fn build(
62 _encoding: &Self::Encoding,
63 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
64 children: Vec<ExprRef>,
65 ) -> VortexResult<Self::Expr> {
66 if children.is_empty() {
67 vortex_bail!(
68 "Merge expression must have at least one child, got: {:?}",
69 children
70 );
71 }
72 Ok(MergeExpr {
73 values: children,
74 nullability: Nullability::NonNullable, })
76 }
77
78 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
79 let len = scope.len();
80 let value_arrays = expr
81 .values
82 .iter()
83 .map(|value_expr| value_expr.unchecked_evaluate(scope))
84 .process_results(|it| it.collect::<Vec<_>>())?;
85
86 let mut field_names = Vec::new();
88 let mut arrays = Vec::new();
89
90 for value_array in value_arrays.iter() {
91 if value_array.dtype().is_nullable() {
93 todo!("merge nullable structs");
94 }
95 if !value_array.dtype().is_struct() {
96 vortex_bail!("merge expects non-nullable struct input");
97 }
98
99 let struct_array = value_array.to_struct();
100
101 for (i, field_name) in struct_array.names().iter().enumerate() {
102 let array = struct_array.fields()[i].clone();
103
104 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
106 arrays[idx] = array;
107 } else {
108 field_names.push(field_name.clone());
109 arrays.push(array);
110 }
111 }
112 }
113
114 let validity = match expr.nullability {
115 Nullability::NonNullable => Validity::NonNullable,
116 Nullability::Nullable => Validity::AllValid,
117 };
118 Ok(
119 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
120 .into_array(),
121 )
122 }
123
124 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
125 let mut field_names = Vec::new();
126 let mut arrays = Vec::new();
127
128 for value in expr.values.iter() {
129 let dtype = value.return_dtype(scope)?;
130 if !dtype.is_struct() {
131 vortex_bail!("merge expects non-nullable struct input");
132 }
133
134 let struct_dtype = dtype
135 .as_struct_fields_opt()
136 .vortex_expect("merge expects struct input");
137
138 for i in 0..struct_dtype.nfields() {
139 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
140 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
141 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
142 arrays[idx] = field_dtype;
143 } else {
144 field_names.push(field_name.clone());
145 arrays.push(field_dtype);
146 }
147 }
148 }
149
150 Ok(DType::Struct(
151 StructFields::new(FieldNames::from(field_names), arrays),
152 expr.nullability,
153 ))
154 }
155}
156
157impl MergeExpr {
158 pub fn new(values: Vec<ExprRef>, nullability: Nullability) -> Self {
159 MergeExpr {
160 values,
161 nullability,
162 }
163 }
164
165 pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> ExprRef {
166 Self::new(values, nullability).into_expr()
167 }
168
169 pub fn nullability(&self) -> Nullability {
170 self.nullability
171 }
172}
173
174pub fn merge(
185 elements: impl IntoIterator<Item = impl Into<ExprRef>>,
186 nullability: Nullability,
187) -> ExprRef {
188 let values = elements.into_iter().map(|value| value.into()).collect_vec();
189 MergeExpr::new(values, nullability).into_expr()
190}
191
192impl DisplayAs for MergeExpr {
193 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
194 match df {
195 DisplayFormat::Compact => {
196 write!(
197 f,
198 "merge({}){}",
199 self.values.iter().format(", "),
200 self.nullability
201 )
202 }
203 DisplayFormat::Tree => {
204 write!(f, "Merge")
205 }
206 }
207 }
208}
209
// Merge adds no analysis-specific behavior: the empty impl uses every default method of
// `AnalysisExpr`.
impl AnalysisExpr for MergeExpr {}
211
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{MergeExpr, Scope, get_item, merge, root};

    /// Follows `field_path` through nested struct fields of `array` and returns the leaf as a
    /// primitive array.
    ///
    /// Bails on an empty path and propagates any field-lookup error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(head)?.clone();
        for name in tail {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    #[test]
    pub fn test_merge() {
        let expr = MergeExpr::new(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let merged = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            merged.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        // "b" comes from the second child, which overrides the first.
        let expected_fields: [(&str, i32); 5] = [("a", 0), ("b", 2), ("c", 3), ("d", 4), ("e", 5)];
        for (name, value) in expected_fields {
            assert_eq!(
                primitive_field(&merged, &[name])
                    .unwrap()
                    .as_slice::<i32>(),
                [value; 3]
            );
        }
    }

    #[test]
    pub fn test_empty_merge() {
        let expr = MergeExpr::new(Vec::new(), Nullability::NonNullable);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let expected_len = test_array.len();
        let merged = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(merged.len(), expected_len);
        assert_eq!(merged.as_struct_typed().nfields(), 0);
    }

    #[test]
    pub fn test_nested_merge() {
        // Nested structs are not merged recursively: the later "a" replaces the earlier one.
        let expr = MergeExpr::new(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let merged = expr.evaluate(&Scope::new(test_array)).unwrap().to_struct();

        let inner = merged.field_by_name("a").unwrap().to_struct();
        let inner_names: Vec<_> = inner.names().iter().map(|name| name.as_ref()).collect();
        assert_eq!(inner_names, vec!["x"]);
    }

    #[test]
    pub fn test_merge_order() {
        let expr = MergeExpr::new(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let merged = expr.evaluate(&Scope::new(test_array)).unwrap().to_struct();

        // Fields keep input order; they are not sorted.
        assert_eq!(merged.names(), ["a", "c", "b", "d"]);
    }

    #[test]
    pub fn test_merge_nullable() {
        let expr = MergeExpr::new(vec![get_item("0", root())], Nullability::Nullable);

        let inner = StructArray::from_fields(&[
            ("a", buffer![0, 0, 0].into_array()),
            ("b", buffer![1, 1, 1].into_array()),
        ])
        .unwrap()
        .into_array();
        let test_array = StructArray::from_fields(&[("0", inner)])
            .unwrap()
            .into_array();
        let merged = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert!(merged.dtype().is_nullable());
    }

    #[test]
    pub fn test_display() {
        let non_nullable = merge(
            [get_item("struct1", root()), get_item("struct2", root())],
            Nullability::NonNullable,
        );
        assert_eq!(non_nullable.to_string(), "merge($.struct1, $.struct2)");

        let nullable = MergeExpr::new(vec![get_item("a", root())], Nullability::Nullable);
        assert_eq!(nullable.to_string(), "merge($.a)?");
    }
}