1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{ExprRef, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23 values: Vec<ExprRef>,
24 nullability: Nullability,
25}
26
27impl Merge {
28 pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> Arc<Self> {
29 Arc::new(Merge {
30 values,
31 nullability,
32 })
33 }
34
35 pub fn nullability(&self) -> Nullability {
36 self.nullability
37 }
38}
39
40pub fn merge(
41 elements: impl IntoIterator<Item = impl Into<ExprRef>>,
42 nullability: Nullability,
43) -> ExprRef {
44 let values = elements.into_iter().map(|value| value.into()).collect_vec();
45 Merge::new_expr(values, nullability)
46}
47
48impl Display for Merge {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.write_str("{")?;
51 self.values
52 .iter()
53 .format_with(", ", |expr, f| f(expr))
54 .fmt(f)?;
55 f.write_str("}")
56 }
57}
58
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Serde registration for [`Merge`] expressions, identified by the id `"merge"`.
    ///
    /// Protobuf (de)serialization of `Merge` is not implemented. Both directions return a
    /// `NotImplemented` error that names the unsupported operation.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        fn id(&self) -> &'static str {
            "merge"
        }
    }

    impl ExprDeserialize for MergeSerde {
        /// Always fails: deserializing a `Merge` expression is not supported.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the operation so the error reads "function deserialize not implemented ...".
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Merge {
        fn id(&self) -> &'static str {
            MergeSerde.id()
        }

        /// Always fails: serializing a `Merge` expression is not supported.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
90
91impl VortexExpr for Merge {
92 fn as_any(&self) -> &dyn Any {
93 self
94 }
95
96 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
97 let len = batch.len();
98 let value_arrays = self
99 .values
100 .iter()
101 .map(|value_expr| value_expr.evaluate(batch))
102 .process_results(|it| it.collect::<Vec<_>>())?;
103
104 let mut field_names = Vec::new();
106 let mut arrays = Vec::new();
107
108 for value_array in value_arrays.iter() {
109 if value_array.dtype().is_nullable() {
111 todo!("merge nullable structs");
112 }
113 if !value_array.dtype().is_struct() {
114 vortex_bail!("merge expects non-nullable struct input");
115 }
116
117 let struct_array = value_array.to_struct()?;
118
119 for (i, field_name) in struct_array.names().iter().enumerate() {
120 let array = struct_array.fields()[i].clone();
121
122 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
124 arrays[idx] = array;
125 } else {
126 field_names.push(field_name.clone());
127 arrays.push(array);
128 }
129 }
130 }
131
132 let validity = match self.nullability {
133 Nullability::NonNullable => Validity::NonNullable,
134 Nullability::Nullable => Validity::AllValid,
135 };
136 Ok(
137 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
138 .into_array(),
139 )
140 }
141
142 fn children(&self) -> Vec<&ExprRef> {
143 self.values.iter().collect()
144 }
145
146 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
147 Self::new_expr(children, self.nullability)
148 }
149
150 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
151 let mut field_names = Vec::new();
152 let mut arrays = Vec::new();
153
154 for value in self.values.iter() {
155 let dtype = value.return_dtype(scope_dtype)?;
156 if !dtype.is_struct() {
157 vortex_bail!("merge expects non-nullable struct input");
158 }
159
160 let struct_dtype = dtype
161 .as_struct()
162 .vortex_expect("merge expects struct input");
163
164 for i in 0..struct_dtype.nfields() {
165 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
166 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
167 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
168 arrays[idx] = field_dtype;
169 } else {
170 field_names.push(field_name.clone());
171 arrays.push(field_dtype);
172 }
173 }
174 }
175
176 Ok(DType::Struct(
177 Arc::new(StructDType::new(FieldNames::from(field_names), arrays)),
178 self.nullability,
179 ))
180 }
181}
182
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{GetItem, Identity, Merge, VortexExpr};

    /// Follows `field_path` through nested struct fields of `array` and returns the leaf as a
    /// primitive array. Fails on an empty path or a missing field.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        Ok(array.to_primitive().unwrap())
    }

    /// Later children override earlier fields of the same name, keeping the first-seen position.
    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
                GetItem::new_expr("2", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    /// Merging no children yields an empty struct with the batch's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new(), Nullability::NonNullable);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Merging is shallow: a later nested struct replaces the earlier one rather than being
    /// merged into it.
    #[test]
    pub fn test_nested_merge() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(test_array.as_ref())
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow child order and, within a child, that child's field order.
    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(test_array.as_ref())
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array.names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    /// The expression's nullability controls the nullability of the output dtype.
    #[test]
    pub fn test_merge_nullable() {
        let expr = Merge::new_expr(
            vec![GetItem::new_expr("0", Identity::new_expr())],
            Nullability::Nullable,
        );

        let test_array = StructArray::from_fields(&[(
            "0",
            StructArray::from_fields(&[
                ("a", buffer![0, 0, 0].into_array()),
                ("b", buffer![1, 1, 1].into_array()),
            ])
            .unwrap()
            .into_array(),
        )])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();
        assert!(actual_array.dtype().is_nullable());
    }

    /// `return_dtype` duplicates the merge rules of evaluation. This checks that both agree,
    /// including when an override changes a field's type (i32 -> i64).
    #[test]
    pub fn test_merge_return_dtype() {
        let expr = Merge::new_expr(
            vec![
                GetItem::new_expr("0", Identity::new_expr()),
                GetItem::new_expr("1", Identity::new_expr()),
            ],
            Nullability::Nullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[("b", buffer![2i64, 2, 2].into_array())])
                    .unwrap()
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();

        let expected = expr.return_dtype(test_array.dtype()).unwrap();
        let actual_array = expr.evaluate(test_array.as_ref()).unwrap();
        assert_eq!(actual_array.dtype(), &expected);
    }
}