1pub mod transform;
5
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use itertools::Itertools as _;
11use vortex_dtype::DType;
12use vortex_dtype::FieldNames;
13use vortex_dtype::Nullability;
14use vortex_dtype::StructFields;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_utils::aliases::hash_set::HashSet;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::IntoArray as _;
22use crate::ToCanonical;
23use crate::arrays::StructArray;
24use crate::expr::ChildName;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::ExpressionView;
28use crate::expr::VTable;
29use crate::expr::VTableExt;
30use crate::validity::Validity;
31
/// Expression vtable for `vortex.merge`.
///
/// Combines the fields of several struct-typed child expressions into one
/// struct. How repeated field names are treated is selected by the
/// [`DuplicateHandling`] value attached to each expression instance.
pub struct Merge;
39
40impl VTable for Merge {
41 type Instance = DuplicateHandling;
42
43 fn id(&self) -> ExprId {
44 ExprId::new_ref("vortex.merge")
45 }
46
47 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
48 Ok(Some(match instance {
49 DuplicateHandling::RightMost => vec![0x00],
50 DuplicateHandling::Error => vec![0x01],
51 }))
52 }
53
54 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
55 let instance = match metadata {
56 [0x00] => DuplicateHandling::RightMost,
57 [0x01] => DuplicateHandling::Error,
58 _ => {
59 vortex_bail!("invalid metadata for Merge expression");
60 }
61 };
62 Ok(Some(instance))
63 }
64
65 fn validate(&self, _expr: &ExpressionView<Self>) -> VortexResult<()> {
66 Ok(())
67 }
68
69 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
70 ChildName::from(Arc::from(format!("{}", child_idx)))
71 }
72
73 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
74 write!(f, "merge(")?;
75 for (i, child) in expr.children().iter().enumerate() {
76 child.fmt_sql(f)?;
77 if i + 1 < expr.children().len() {
78 write!(f, ", ")?;
79 }
80 }
81 write!(f, ")")
82 }
83
84 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
85 let mut field_names = Vec::new();
86 let mut arrays = Vec::new();
87 let mut merge_nullability = Nullability::NonNullable;
88 let mut duplicate_names = HashSet::<_>::new();
89
90 for child in expr.children().iter() {
91 let dtype = child.return_dtype(scope)?;
92 let Some(fields) = dtype.as_struct_fields_opt() else {
93 vortex_bail!("merge expects struct input");
94 };
95 if dtype.is_nullable() {
96 vortex_bail!("merge expects non-nullable input");
97 }
98
99 merge_nullability |= dtype.nullability();
100
101 for (field_name, field_dtype) in fields.names().iter().zip_eq(fields.fields()) {
102 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
103 duplicate_names.insert(field_name.clone());
104 arrays[idx] = field_dtype;
105 } else {
106 field_names.push(field_name.clone());
107 arrays.push(field_dtype);
108 }
109 }
110 }
111
112 if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
113 vortex_bail!(
114 "merge: duplicate fields in children: {}",
115 duplicate_names.into_iter().format(", ")
116 )
117 }
118
119 Ok(DType::Struct(
120 StructFields::new(FieldNames::from(field_names), arrays),
121 merge_nullability,
122 ))
123 }
124
125 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
126 let mut field_names = Vec::new();
128 let mut arrays = Vec::new();
129 let mut duplicate_names = HashSet::<_>::new();
130
131 for child in expr.children().iter() {
132 let array = child.evaluate(scope)?;
134 if array.dtype().is_nullable() {
135 vortex_bail!("merge expects non-nullable input");
136 }
137 if !array.dtype().is_struct() {
138 vortex_bail!("merge expects struct input");
139 }
140 let array = array.to_struct();
141
142 for (field_name, array) in array.names().iter().zip_eq(array.fields().iter().cloned()) {
143 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
145 duplicate_names.insert(field_name.clone());
146 arrays[idx] = array;
147 } else {
148 field_names.push(field_name.clone());
149 arrays.push(array);
150 }
151 }
152 }
153
154 if expr.data() == &DuplicateHandling::Error && !duplicate_names.is_empty() {
155 vortex_bail!(
156 "merge: duplicate fields in children: {}",
157 duplicate_names.into_iter().format(", ")
158 )
159 }
160
161 let validity = Validity::NonNullable;
163 let len = scope.len();
164 Ok(
165 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
166 .into_array(),
167 )
168 }
169
170 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
171 true
172 }
173
174 fn is_fallible(&self, instance: &Self::Instance) -> bool {
175 matches!(instance, DuplicateHandling::Error)
176 }
177}
178
/// Policy applied by [`Merge`] when the same field name appears in more than
/// one child struct.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum DuplicateHandling {
    /// Keep the field's first position, but take its value from the last
    /// child that defines it.
    RightMost,
    /// Reject the merge with an error when any field name repeats.
    /// This is the default.
    #[default]
    Error,
}
188
189pub fn merge(elements: impl IntoIterator<Item = impl Into<Expression>>) -> Expression {
200 let values = elements.into_iter().map(|value| value.into()).collect_vec();
201 Merge.new_expr(DuplicateHandling::default(), values)
202}
203
204pub fn merge_opts(
205 elements: impl IntoIterator<Item = impl Into<Expression>>,
206 duplicate_handling: DuplicateHandling,
207) -> Expression {
208 let values = elements.into_iter().map(|value| value.into()).collect_vec();
209 Merge.new_expr(duplicate_handling, values)
210}
211
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::merge;
    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::Expression;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::merge::DuplicateHandling;
    use crate::expr::exprs::merge::merge_opts;
    use crate::expr::exprs::root::root;

    /// Walks `field_path` through nested struct arrays starting at `array`
    /// and returns the final field as a primitive array.
    ///
    /// Fails when `field_path` is empty or a named field is missing.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    /// With `RightMost`, a repeated field (`b`) keeps its first position but
    /// takes its values from the last child that defines it.
    #[test]
    pub fn test_merge_right_most() {
        let expr = merge_opts(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            ["a", "b", "c", "d", "e"]
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        // `b` comes from child "1", the right-most child that defines it.
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    /// With `Error`, a repeated field name is rejected when computing the dtype.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_return_dtype() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.return_dtype(test_array.dtype()).unwrap();
    }

    /// With `Error`, a repeated field name is rejected when evaluating.
    #[test]
    #[should_panic(expected = "merge: duplicate fields in children")]
    pub fn test_merge_error_on_dupe_evaluate() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::Error,
        );
        let test_array = StructArray::try_from_iter([
            (
                "0",
                StructArray::try_from_iter([("a", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
            (
                "1",
                StructArray::try_from_iter([("c", buffer![1]), ("b", buffer![1])]).unwrap(),
            ),
        ])
        .unwrap()
        .into_array();

        expr.evaluate(&test_array).unwrap();
    }

    /// A merge with no children yields a field-less struct whose length
    /// matches the scope array.
    #[test]
    pub fn test_empty_merge() {
        let expr = merge(Vec::<Expression>::new());

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Merging is shallow: a repeated struct-typed field is replaced as a
    /// whole by the right-most child, not merged recursively.
    #[test]
    pub fn test_nested_merge() {
        let expr = merge_opts(
            vec![get_item("0", root()), get_item("1", root())],
            DuplicateHandling::RightMost,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        // Only `x` survives: child "1"'s `a` replaced child "0"'s `a` entirely.
        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order, then field order within each child;
    /// they are not sorted by name.
    #[test]
    pub fn test_merge_order() {
        let expr = merge(vec![get_item("0", root()), get_item("1", root())]);

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["a", "c", "b", "d"]);
    }

    /// The string form is `merge(...)` with comma-separated children.
    #[test]
    pub fn test_display() {
        let expr = merge([get_item("struct1", root()), get_item("struct2", root())]);
        assert_eq!(expr.to_string(), "merge($.struct1, $.struct2)");

        let expr2 = merge(vec![get_item("a", root())]);
        assert_eq!(expr2.to_string(), "merge($.a)");
    }
}