vortex_array/arrays/struct_/compute/
zip.rs1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::IntoArray;
13use crate::arrays::StructArray;
14use crate::arrays::StructVTable;
15use crate::builtins::ArrayBuiltins;
16use crate::scalar_fn::fns::zip::ZipKernel;
17use crate::validity::Validity;
18use crate::vtable::ValidityHelper;
19
20impl ZipKernel for StructVTable {
21 fn zip(
22 if_true: &StructArray,
23 if_false: &ArrayRef,
24 mask: &ArrayRef,
25 ctx: &mut ExecutionCtx,
26 ) -> VortexResult<Option<ArrayRef>> {
27 let Some(if_false) = if_false.as_opt::<StructVTable>() else {
28 return Ok(None);
29 };
30 assert_eq!(
31 if_true.names(),
32 if_false.names(),
33 "input arrays to zip must have the same field names",
34 );
35
36 let fields = if_true
37 .unmasked_fields()
38 .iter()
39 .zip(if_false.unmasked_fields().iter())
40 .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
41 .collect::<VortexResult<Vec<_>>>()?;
42
43 let validity = match (if_true.validity(), if_false.validity()) {
44 (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
45 (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
46 (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
47
48 (v1, v2) => {
49 let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
50 let v1m = v1.to_mask(if_true.len());
51 let v2m = v2.to_mask(if_false.len());
52
53 let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
54 Validity::from_mask(
55 combined,
56 if_true.dtype.nullability() | if_false.dtype.nullability(),
57 )
58 }
59 };
60
61 Ok(Some(
62 StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
63 .into_array(),
64 ))
65 }
66}
67
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Builds a four-row struct array with a single `i32` child called `name` and the given
    /// struct-level `validity`.
    fn four_row_struct(name: &'static str, values: [i32; 4], validity: Validity) -> StructArray {
        let child = PrimitiveArray::from_iter(values).into_array();
        StructArray::new(FieldNames::from_iter([name]), vec![child], 4, validity)
    }

    /// Both inputs use per-row validity built from booleans. Each output row must take its
    /// null-ness from whichever side the mask selects.
    #[test]
    fn test_validity_zip_both_validity_array() {
        let on_true = four_row_struct(
            "field",
            [1, 2, 3, 4],
            Validity::from_iter([true, false, true, false]),
        )
        .into_array();
        let on_false = four_row_struct(
            "field",
            [10, 20, 30, 40],
            Validity::from_iter([false, true, false, true]),
        )
        .into_array();
        let selector = Mask::from_iter([false, false, true, false]).into_array();

        let zipped = selector.zip(on_true, on_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    /// One input is `AllValid` and the other has per-row validity. Rows selected from the
    /// all-valid side are never null.
    #[test]
    fn test_validity_zip_allvalid_and_array() {
        let on_true = four_row_struct("a", [1, 2, 3, 4], Validity::AllValid).into_array();
        let on_false = four_row_struct(
            "a",
            [10, 20, 30, 40],
            Validity::from_iter([false, false, true, true]),
        )
        .into_array();
        let selector = Mask::from_iter([true, false, false, false]).into_array();

        let zipped = selector.zip(on_true, on_false).unwrap();

        insta::assert_snapshot!(zipped.display_table(), @r"
        ┌───────┐
        │ a     │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}