vortex_array/arrays/struct_/compute/
zip.rs1use std::ops::BitAnd;
5use std::ops::BitOr;
6use std::ops::Not;
7
8use vortex_error::VortexResult;
9
10use crate::ArrayRef;
11use crate::ExecutionCtx;
12use crate::arrays::StructArray;
13use crate::arrays::StructVTable;
14use crate::builtins::ArrayBuiltins;
15use crate::scalar_fn::fns::zip::ZipKernel;
16use crate::validity::Validity;
17use crate::vtable::ValidityHelper;
18
19impl ZipKernel for StructVTable {
20 fn zip(
21 if_true: &StructArray,
22 if_false: &ArrayRef,
23 mask: &ArrayRef,
24 ctx: &mut ExecutionCtx,
25 ) -> VortexResult<Option<ArrayRef>> {
26 let Some(if_false) = if_false.as_opt::<StructVTable>() else {
27 return Ok(None);
28 };
29 assert_eq!(
30 if_true.names(),
31 if_false.names(),
32 "input arrays to zip must have the same field names",
33 );
34
35 let fields = if_true
36 .unmasked_fields()
37 .iter()
38 .zip(if_false.unmasked_fields().iter())
39 .map(|(t, f)| ArrayBuiltins::zip(mask, t.clone(), f.clone()))
40 .collect::<VortexResult<Vec<_>>>()?;
41
42 let validity = match (if_true.validity(), if_false.validity()) {
43 (&Validity::NonNullable, &Validity::NonNullable) => Validity::NonNullable,
44 (&Validity::AllValid, &Validity::AllValid) => Validity::AllValid,
45 (&Validity::AllInvalid, &Validity::AllInvalid) => Validity::AllInvalid,
46
47 (v1, v2) => {
48 let mask_mask = mask.try_to_mask_fill_null_false(ctx)?;
49 let v1m = v1.to_mask(if_true.len());
50 let v2m = v2.to_mask(if_false.len());
51
52 let combined = (v1m.bitand(&mask_mask)).bitor(&v2m.bitand(&mask_mask.not()));
53 Validity::from_mask(
54 combined,
55 if_true.dtype.nullability() | if_false.dtype.nullability(),
56 )
57 }
58 };
59
60 Ok(Some(
61 StructArray::try_new(if_true.names().clone(), fields, if_true.len(), validity)?
62 .to_array(),
63 ))
64 }
65}
66
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::FieldNames;
    use crate::validity::Validity;

    /// Both branches carry a validity bitmap: each output row takes its validity from the
    /// branch the mask selects for that row.
    #[test]
    fn test_validity_zip_both_validity_array() {
        let if_true = StructArray::new(
            FieldNames::from_iter(["field"]),
            vec![PrimitiveArray::from_iter([1, 2, 3, 4]).into_array()],
            4,
            Validity::from_iter([true, false, true, false]),
        )
        .into_array();

        let if_false = StructArray::new(
            FieldNames::from_iter(["field"]),
            vec![PrimitiveArray::from_iter([10, 20, 30, 40]).into_array()],
            4,
            Validity::from_iter([false, true, false, true]),
        )
        .into_array();

        let mask = Mask::from_iter([false, false, true, false]);

        // `if_true` / `if_false` are not used after this, so they are moved rather than cloned.
        let result = mask.into_array().zip(if_true, if_false).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ field │
        ├───────┤
        │ null  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    /// An all-valid branch zipped with a bitmap-validity branch keeps nulls only where the
    /// bitmap branch is selected and null.
    #[test]
    fn test_validity_zip_allvalid_and_array() {
        let if_true = StructArray::new(
            FieldNames::from_iter(["a"]),
            vec![PrimitiveArray::from_iter([1, 2, 3, 4]).into_array()],
            4,
            Validity::AllValid,
        )
        .into_array();

        let if_false = StructArray::new(
            FieldNames::from_iter(["a"]),
            vec![PrimitiveArray::from_iter([10, 20, 30, 40]).into_array()],
            4,
            Validity::from_iter([false, false, true, true]),
        )
        .into_array();

        let mask = Mask::from_iter([true, false, false, false]);

        let result = mask.into_array().zip(if_true, if_false).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ a     │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ null  │
        ├───────┤
        │ 30i32 │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }

    /// A non-nullable branch zipped with an all-valid branch has no nulls at all, whichever
    /// rows the mask selects.
    #[test]
    fn test_validity_zip_non_nullable_and_allvalid() {
        let if_true = StructArray::new(
            FieldNames::from_iter(["a"]),
            vec![PrimitiveArray::from_iter([1, 2, 3, 4]).into_array()],
            4,
            Validity::NonNullable,
        )
        .into_array();

        let if_false = StructArray::new(
            FieldNames::from_iter(["a"]),
            vec![PrimitiveArray::from_iter([10, 20, 30, 40]).into_array()],
            4,
            Validity::AllValid,
        )
        .into_array();

        let mask = Mask::from_iter([true, false, true, false]);

        let result = mask.into_array().zip(if_true, if_false).unwrap();

        insta::assert_snapshot!(result.display_table(), @r"
        ┌───────┐
        │ a     │
        ├───────┤
        │ 1i32  │
        ├───────┤
        │ 20i32 │
        ├───────┤
        │ 3i32  │
        ├───────┤
        │ 40i32 │
        └───────┘
        ");
    }
}