1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
10use vortex_dtype::{DType, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct Merge {
23 values: Vec<ExprRef>,
24 nullability: Nullability,
25}
26
27impl Merge {
28 pub fn new_expr(values: Vec<ExprRef>, nullability: Nullability) -> Arc<Self> {
29 Arc::new(Merge {
30 values,
31 nullability,
32 })
33 }
34
35 pub fn nullability(&self) -> Nullability {
36 self.nullability
37 }
38}
39
40pub fn merge(
41 elements: impl IntoIterator<Item = impl Into<ExprRef>>,
42 nullability: Nullability,
43) -> ExprRef {
44 let values = elements.into_iter().map(|value| value.into()).collect_vec();
45 Merge::new_expr(values, nullability)
46}
47
48impl Display for Merge {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.write_str("{")?;
51 self.values
52 .iter()
53 .format_with(", ", |expr, f| f(expr))
54 .fmt(f)?;
55 f.write_str("}")
56 }
57}
58
/// Protobuf (de)serialization hooks for [`Merge`].
///
/// Neither direction is implemented yet: both entry points return a `NotImplemented` error
/// naming the unimplemented operation and the expression id.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Merge};

    /// Serde handle for the `merge` expression.
    pub struct MergeSerde;

    impl Id for MergeSerde {
        /// Returns the identifier under which `Merge` is registered.
        fn id(&self) -> &'static str {
            "merge"
        }
    }

    impl ExprDeserialize for MergeSerde {
        /// Always fails: deserialization of `Merge` is not implemented.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the operation so the error says what is missing, not just for whom.
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Merge {
        /// Returns the same identifier as [`MergeSerde`].
        fn id(&self) -> &'static str {
            MergeSerde.id()
        }

        /// Always fails: serialization of `Merge` is not implemented.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
90
91impl AnalysisExpr for Merge {}
92
93impl VortexExpr for Merge {
94 fn as_any(&self) -> &dyn Any {
95 self
96 }
97
98 fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
99 let len = ctx.len();
100 let value_arrays = self
101 .values
102 .iter()
103 .map(|value_expr| value_expr.unchecked_evaluate(ctx))
104 .process_results(|it| it.collect::<Vec<_>>())?;
105
106 let mut field_names = Vec::new();
108 let mut arrays = Vec::new();
109
110 for value_array in value_arrays.iter() {
111 if value_array.dtype().is_nullable() {
113 todo!("merge nullable structs");
114 }
115 if !value_array.dtype().is_struct() {
116 vortex_bail!("merge expects non-nullable struct input");
117 }
118
119 let struct_array = value_array.to_struct()?;
120
121 for (i, field_name) in struct_array.names().iter().enumerate() {
122 let array = struct_array.fields()[i].clone();
123
124 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
126 arrays[idx] = array;
127 } else {
128 field_names.push(field_name.clone());
129 arrays.push(array);
130 }
131 }
132 }
133
134 let validity = match self.nullability {
135 Nullability::NonNullable => Validity::NonNullable,
136 Nullability::Nullable => Validity::AllValid,
137 };
138 Ok(
139 StructArray::try_new(FieldNames::from(field_names), arrays, len, validity)?
140 .into_array(),
141 )
142 }
143
144 fn children(&self) -> Vec<&ExprRef> {
145 self.values.iter().collect()
146 }
147
148 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
149 Self::new_expr(children, self.nullability)
150 }
151
152 fn return_dtype(&self, ctx: &ScopeDType) -> VortexResult<DType> {
153 let mut field_names = Vec::new();
154 let mut arrays = Vec::new();
155
156 for value in self.values.iter() {
157 let dtype = value.return_dtype(ctx)?;
158 if !dtype.is_struct() {
159 vortex_bail!("merge expects non-nullable struct input");
160 }
161
162 let struct_dtype = dtype
163 .as_struct()
164 .vortex_expect("merge expects struct input");
165
166 for i in 0..struct_dtype.nfields() {
167 let field_name = struct_dtype.field_name(i).vortex_expect("never OOB");
168 let field_dtype = struct_dtype.field_by_index(i).vortex_expect("never OOB");
169 if let Some(idx) = field_names.iter().position(|name| name == field_name) {
170 arrays[idx] = field_dtype;
171 } else {
172 field_names.push(field_name.clone());
173 arrays.push(field_dtype);
174 }
175 }
176 }
177
178 Ok(DType::Struct(
179 Arc::new(StructFields::new(FieldNames::from(field_names), arrays)),
180 self.nullability,
181 ))
182 }
183}
184
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Merge, Scope, VortexExpr, get_item, root};

    /// Walks `field_path` through nested struct arrays and returns the final field as a
    /// primitive array.
    ///
    /// Fails if the path is empty, if any step is not a struct containing the named field,
    /// or if the final field cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate conversion failures instead of panicking inside a fallible helper.
        array.to_primitive()
    }

    /// Merging three structs yields the union of their fields; the duplicated field `b`
    /// takes its value from the later struct.
    #[test]
    pub fn test_merge() {
        let expr = Merge::new_expr(
            vec![
                get_item("0", root()),
                get_item("1", root()),
                get_item("2", root()),
            ],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("b", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("c", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "2",
                StructArray::from_fields(&[
                    ("d", buffer![4, 4, 4].into_array()),
                    ("e", buffer![5, 5, 5].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            actual_array.as_struct_typed().names(),
            &["a".into(), "b".into(), "c".into(), "d".into(), "e".into()].into()
        );

        assert_eq!(
            primitive_field(&actual_array, &["a"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 0, 0]
        );
        assert_eq!(
            primitive_field(&actual_array, &["b"])
                .unwrap()
                .as_slice::<i32>(),
            [2, 2, 2]
        );
        assert_eq!(
            primitive_field(&actual_array, &["c"])
                .unwrap()
                .as_slice::<i32>(),
            [3, 3, 3]
        );
        assert_eq!(
            primitive_field(&actual_array, &["d"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 4, 4]
        );
        assert_eq!(
            primitive_field(&actual_array, &["e"])
                .unwrap()
                .as_slice::<i32>(),
            [5, 5, 5]
        );
    }

    /// A merge with no children produces a field-less struct with the scope's length.
    #[test]
    pub fn test_empty_merge() {
        let expr = Merge::new_expr(Vec::new(), Nullability::NonNullable);

        let test_array = StructArray::from_fields(&[("a", buffer![0, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().nfields(), 0);
    }

    /// Nested structs are replaced, not merged: the later `a` (only `x`) fully overrides the
    /// earlier `a` (`x` and `y`).
    #[test]
    pub fn test_nested_merge() {
        let expr = Merge::new_expr(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[
                        ("x", buffer![0, 0, 0].into_array()),
                        ("y", buffer![1, 1, 1].into_array()),
                    ])
                    .unwrap()
                    .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[(
                    "a",
                    StructArray::from_fields(&[("x", buffer![0, 0, 0].into_array())])
                        .unwrap()
                        .into_array(),
                )])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(&Scope::new(test_array))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array
                .field_by_name("a")
                .unwrap()
                .to_struct()
                .unwrap()
                .names()
                .iter()
                .map(|name| name.as_ref())
                .collect::<Vec<_>>(),
            vec!["x"]
        );
    }

    /// Output fields follow first-appearance order across children, not sorted order.
    #[test]
    pub fn test_merge_order() {
        let expr = Merge::new_expr(
            vec![get_item("0", root()), get_item("1", root())],
            Nullability::NonNullable,
        );

        let test_array = StructArray::from_fields(&[
            (
                "0",
                StructArray::from_fields(&[
                    ("a", buffer![0, 0, 0].into_array()),
                    ("c", buffer![1, 1, 1].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
            (
                "1",
                StructArray::from_fields(&[
                    ("b", buffer![2, 2, 2].into_array()),
                    ("d", buffer![3, 3, 3].into_array()),
                ])
                .unwrap()
                .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let actual_array = expr
            .evaluate(&Scope::new(test_array))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(
            actual_array.names(),
            &["a".into(), "c".into(), "b".into(), "d".into()].into()
        );
    }

    /// Requesting a nullable result yields a nullable struct dtype.
    #[test]
    pub fn test_merge_nullable() {
        let expr = Merge::new_expr(vec![get_item("0", root())], Nullability::Nullable);

        let test_array = StructArray::from_fields(&[(
            "0",
            StructArray::from_fields(&[
                ("a", buffer![0, 0, 0].into_array()),
                ("b", buffer![1, 1, 1].into_array()),
            ])
            .unwrap()
            .into_array(),
        )])
        .unwrap()
        .into_array();
        let actual_array = expr.evaluate(&Scope::new(test_array)).unwrap();
        assert!(actual_array.dtype().is_nullable());
    }
}