vortex_array/arrays/struct_/compute/
zip.rs1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Struct;
15use crate::arrays::StructArray;
16use crate::arrays::struct_::StructArrayExt;
17use crate::builtins::ArrayBuiltins;
18use crate::scalar_fn::fns::zip::ZipKernel;
19use crate::validity::Validity;
20
21impl ZipKernel for Struct {
22 fn zip(
23 if_true: ArrayView<'_, Struct>,
24 if_false: &ArrayRef,
25 mask: &ArrayRef,
26 ctx: &mut ExecutionCtx,
27 ) -> VortexResult<Option<ArrayRef>> {
28 let Some(if_false) = if_false.as_opt::<Struct>() else {
29 return Ok(None);
30 };
31 assert_eq!(
32 if_true.names(),
33 if_false.names(),
34 "input arrays to zip must have the same field names",
35 );
36
37 let fields = if_true
38 .iter_unmasked_fields()
39 .zip(if_false.iter_unmasked_fields())
40 .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
41 .collect::<VortexResult<Vec<_>>>()?;
42
43 let v1 = if_true.validity()?;
44 let v2 = if_false.validity()?;
45 let validity = match (&v1, &v2) {
46 (Validity::NonNullable, Validity::NonNullable) => Validity::NonNullable,
47 (Validity::AllValid, Validity::AllValid) => Validity::AllValid,
48 (Validity::AllInvalid, Validity::AllInvalid) => Validity::AllInvalid,
49
50 (v1, v2) => {
51 let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
52 let v1m = v1.execute_mask(if_true.len(), ctx)?;
53 let v2m = v2.execute_mask(if_false.len(), ctx)?;
54
55 let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
56 Validity::from_mask(
57 combined,
58 if_true.dtype().nullability() | if_false.dtype().nullability(),
59 )
60 }
61 };
62
63 Ok(Some(
64 StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
65 .into_array(),
66 ))
67 }
68}
69
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Builds a struct array with one `i32` child field called `name` holding `values`, using
    /// `validity` as the struct-level validity.
    fn single_field_struct<const N: usize>(
        name: &'static str,
        values: [i32; N],
        validity: Validity,
    ) -> StructArray {
        StructArray::new(
            FieldNames::from_iter([name]),
            vec![PrimitiveArray::from_iter(values).into_array()],
            N,
            validity,
        )
    }

    #[test]
    fn test_validity_zip_both_validity_array() {
        // Both sides carry explicit, complementary struct-level validity.
        let if_true = single_field_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        );
        let if_false = single_field_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        );
        let selector = Mask::from_iter([false, false, true, false]).into_array();

        let zipped = selector
            .zip(if_true.into_array(), if_false.into_array())
            .unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    #[test]
    fn test_validity_zip_allvalid_and_array() {
        // One side is entirely valid, the other carries an explicit validity array.
        let if_true = single_field_struct("a", [1, 2, 3, 4], Validity::AllValid);
        let if_false = single_field_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        );
        let selector = Mask::from_iter([true, false, false, false]).into_array();

        let zipped = selector
            .zip(if_true.into_array(), if_false.into_array())
            .unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ a │
        ├───────┤
        │ 1i32 │
        ├───────┤
        │ null │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}