vortex_array/arrays/primitive/compute/
zip.rs1use std::mem::MaybeUninit;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::Primitive;
17use crate::arrays::PrimitiveArray;
18use crate::dtype::NativePType;
19use crate::match_each_native_ptype;
20use crate::scalar_fn::fns::zip::ZipKernel;
21use crate::scalar_fn::fns::zip::zip_validity;
22
23impl ZipKernel for Primitive {
30 fn zip(
31 if_true: ArrayView<'_, Primitive>,
32 if_false: &ArrayRef,
33 mask: &ArrayRef,
34 ctx: &mut ExecutionCtx,
35 ) -> VortexResult<Option<ArrayRef>> {
36 let Some(if_false) = if_false.as_opt::<Primitive>() else {
37 return Ok(None);
38 };
39
40 if if_true.ptype() != if_false.ptype() {
41 vortex_bail!(
42 "zip requires if_true and if_false to share a primitive type, got {} and {}",
43 if_true.ptype(),
44 if_false.ptype()
45 );
46 }
47
48 let mask = mask.try_to_mask_fill_null_false(ctx)?;
50 match &mask {
51 Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None),
53 Mask::Values(_) => {}
54 }
55
56 let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask)?;
57
58 let array = match_each_native_ptype!(if_true.ptype(), |T| {
64 let values =
65 select_values::<T>(if_true.as_slice::<T>(), if_false.as_slice::<T>(), &mask);
66 PrimitiveArray::new(values.freeze(), validity).into_array()
67 });
68 Ok(Some(array))
69 }
70}
71
72fn select_values<T: NativePType>(
74 true_values: &[T],
75 false_values: &[T],
76 mask: &Mask,
77) -> BufferMut<T> {
78 let len = true_values.len();
79 let mut out = BufferMut::<T>::with_capacity(len);
80 {
81 let out_slice = out.spare_capacity_mut();
82
83 let mask_bits = mask
84 .values()
85 .vortex_expect("mask is Mask::Values")
86 .bit_buffer();
87 let chunks = mask_bits.chunks();
90
91 let mut base = 0;
92 for word in chunks.iter() {
93 let end = base + 64;
94 select_block(
95 word,
96 &true_values[base..end],
97 &false_values[base..end],
98 &mut out_slice[base..end],
99 );
100 base = end;
101 }
102
103 let remainder = chunks.remainder_len();
104 if remainder > 0 {
105 let end = base + remainder;
106 select_block(
107 chunks.remainder_bits(),
108 &true_values[base..end],
109 &false_values[base..end],
110 &mut out_slice[base..end],
111 );
112 }
113 }
114
115 unsafe { out.set_len(len) };
117 out
118}
119
120#[inline]
124fn select_block<T: NativePType>(
125 word: u64,
126 true_values: &[T],
127 false_values: &[T],
128 out: &mut [MaybeUninit<T>],
129) {
130 let n = out.len();
131 let true_values = &true_values[..n];
132 let false_values = &false_values[..n];
133 for j in 0..n {
134 let pick = (word >> j) & 1 == 1;
135 out[j].write(if pick {
136 true_values[j]
137 } else {
138 false_values[j]
139 });
140 }
141}
142
#[cfg(test)]
mod tests {
    #![allow(
        clippy::cast_possible_truncation,
        reason = "test fixtures use small indices that fit the target widths"
    )]

    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;

    /// Builds a mask from `selection`, zips the two inputs with it, and executes the result.
    fn run_zip(
        selection: &[bool],
        if_true: ArrayRef,
        if_false: ArrayRef,
    ) -> VortexResult<ArrayRef> {
        let mask = Mask::from_iter(selection.iter().copied());
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let zipped = mask
            .into_array()
            .zip(if_true, if_false)?
            .execute::<ArrayRef>(&mut ctx)?;
        Ok(zipped)
    }

    /// 150 elements: two full 64-bit mask words plus a 22-bit remainder, with a set bit right
    /// at the start of the second word (index 64).
    #[test]
    fn zip_nonnull_spans_mask_chunks() -> VortexResult<()> {
        const LEN: usize = 150;
        let selection: Vec<bool> = (0..LEN)
            .map(|i| i.is_multiple_of(3) || i == 64)
            .collect();

        let when_true = PrimitiveArray::from_iter(0..LEN as i64).into_array();
        let when_false =
            PrimitiveArray::from_iter((0..LEN as i64).map(|i| 1_000 + i)).into_array();

        let actual = run_zip(&selection, when_true, when_false)?;
        assert!(actual.is::<Primitive>());

        let expected =
            PrimitiveArray::from_iter(selection.iter().enumerate().map(|(i, &take_true)| {
                let v = i as i64;
                if take_true { v } else { 1_000 + v }
            }))
            .into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    /// 130 elements with nulls on both sides: the result must carry both the value and the
    /// validity of whichever side the mask selects.
    #[test]
    fn zip_nullable_selects_values_and_validity() -> VortexResult<()> {
        const LEN: usize = 130;
        let selection: Vec<bool> = (0..LEN).map(|i| i.is_multiple_of(2)).collect();

        // `if_true` is null at multiples of 4, `if_false` at multiples of 5.
        let when_true =
            PrimitiveArray::from_option_iter((0..LEN as i64).map(|i| (i % 4 != 0).then_some(i)))
                .into_array();
        let when_false = PrimitiveArray::from_option_iter(
            (0..LEN as i64).map(|i| (i % 5 != 0).then_some(1_000 + i)),
        )
        .into_array();

        let actual = run_zip(&selection, when_true, when_false)?;
        assert!(actual.is::<Primitive>());

        let expected = PrimitiveArray::from_option_iter(selection.iter().enumerate().map(
            |(i, &take_true)| {
                let v = i as i64;
                if take_true {
                    (!i.is_multiple_of(4)).then_some(v)
                } else {
                    (!i.is_multiple_of(5)).then_some(1_000 + v)
                }
            },
        ))
        .into_array();
        assert_arrays_eq!(actual, expected);
        Ok(())
    }
}