Skip to main content

vortex_array/arrays/primitive/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::mem::MaybeUninit;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::Primitive;
17use crate::arrays::PrimitiveArray;
18use crate::dtype::NativePType;
19use crate::match_each_native_ptype;
20use crate::scalar_fn::fns::zip::ZipKernel;
21use crate::scalar_fn::fns::zip::zip_validity;
22
23/// A dedicated primitive zip kernel that selects values branchlessly per row.
24///
25/// The generic zip path copies runs of `if_true`/`if_false` between mask boundaries, which is fast
26/// for clustered masks but degrades to per-element work on fragmented masks. This kernel instead
27/// walks the mask as 64-bit chunks and blends both sides per row without a data-dependent branch,
28/// so the inner loop stays branch-free and auto-vectorizable regardless of mask shape.
29impl ZipKernel for Primitive {
30    fn zip(
31        if_true: ArrayView<'_, Primitive>,
32        if_false: &ArrayRef,
33        mask: &ArrayRef,
34        ctx: &mut ExecutionCtx,
35    ) -> VortexResult<Option<ArrayRef>> {
36        let Some(if_false) = if_false.as_opt::<Primitive>() else {
37            return Ok(None);
38        };
39
40        if if_true.ptype() != if_false.ptype() {
41            vortex_bail!(
42                "zip requires if_true and if_false to share a primitive type, got {} and {}",
43                if_true.ptype(),
44                if_false.ptype()
45            );
46        }
47
48        // Null mask entries select `if_false`, matching `Zip`'s SQL ELSE semantics.
49        let mask = mask.try_to_mask_fill_null_false(ctx)?;
50        match &mask {
51            // Defer trivial masks to the generic zip, which just casts the surviving side.
52            Mask::AllTrue(_) | Mask::AllFalse(_) => return Ok(None),
53            Mask::Values(_) => {}
54        }
55
56        let validity = zip_validity(if_true.validity()?, if_false.validity()?, &mask)?;
57
58        // TODO(perf): inspect the mask's true_count (and validity) to special-case heavily-skewed
59        // masks. When one side dominates (true_count near 0 or near len), it is cheaper to bulk
60        // copy — or mutate in place, if the dominant side is uniquely owned — that side's values
61        // and validity, then conditionally pull in only the minority rows from the other side,
62        // rather than blending every row.
63        let array = match_each_native_ptype!(if_true.ptype(), |T| {
64            let values =
65                select_values::<T>(if_true.as_slice::<T>(), if_false.as_slice::<T>(), &mask);
66            PrimitiveArray::new(values.freeze(), validity).into_array()
67        });
68        Ok(Some(array))
69    }
70}
71
72/// Branchlessly blend `if_true` and `if_false` per row into a fresh value buffer.
73fn select_values<T: NativePType>(
74    true_values: &[T],
75    false_values: &[T],
76    mask: &Mask,
77) -> BufferMut<T> {
78    let len = true_values.len();
79    let mut out = BufferMut::<T>::with_capacity(len);
80    {
81        let out_slice = out.spare_capacity_mut();
82
83        let mask_bits = mask
84            .values()
85            .vortex_expect("mask is Mask::Values")
86            .bit_buffer();
87        // TODO(perf): `unaligned_chunks` is a faster single-buffer iterator than `chunks`; switch to
88        // it here, handling its lead/trailing padding.
89        let chunks = mask_bits.chunks();
90
91        let mut base = 0;
92        for word in chunks.iter() {
93            let end = base + 64;
94            select_block(
95                word,
96                &true_values[base..end],
97                &false_values[base..end],
98                &mut out_slice[base..end],
99            );
100            base = end;
101        }
102
103        let remainder = chunks.remainder_len();
104        if remainder > 0 {
105            let end = base + remainder;
106            select_block(
107                chunks.remainder_bits(),
108                &true_values[base..end],
109                &false_values[base..end],
110                &mut out_slice[base..end],
111            );
112        }
113    }
114
115    // SAFETY: `select_block` initialized every slot covered by the chunks plus remainder, i.e. `len`.
116    unsafe { out.set_len(len) };
117    out
118}
119
120/// Blend one 64-bit mask chunk's worth of rows: bit `j` (LSB-first) keeps `true_values[j]`, an unset
121/// bit keeps `false_values[j]`. Slices are trimmed to the output length up front so the compiler can
122/// elide bounds checks and lower the body to a vector blend / conditional move.
123#[inline]
124fn select_block<T: NativePType>(
125    word: u64,
126    true_values: &[T],
127    false_values: &[T],
128    out: &mut [MaybeUninit<T>],
129) {
130    let n = out.len();
131    let true_values = &true_values[..n];
132    let false_values = &false_values[..n];
133    for j in 0..n {
134        let pick = (word >> j) & 1 == 1;
135        out[j].write(if pick {
136            true_values[j]
137        } else {
138            false_values[j]
139        });
140    }
141}
142
#[cfg(test)]
mod tests {
    #![allow(
        clippy::cast_possible_truncation,
        reason = "test fixtures use small indices that fit the target widths"
    )]

    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::arrays::Primitive;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;

    /// Builds a mask array from `bits`, zips the two sides under it and executes the result.
    fn run_zip(bits: &[bool], if_true: ArrayRef, if_false: ArrayRef) -> VortexResult<ArrayRef> {
        let mask = Mask::from_iter(bits.iter().copied());
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let zipped = mask
            .into_array()
            .zip(if_true, if_false)?
            .execute::<ArrayRef>(&mut ctx)?;
        Ok(zipped)
    }

    /// Non-nullable inputs: the result must match a per-row reference on both sides of the
    /// 64-bit chunk boundary (rows 63/64) and in the trailing partial chunk.
    #[test]
    fn zip_nonnull_spans_mask_chunks() -> VortexResult<()> {
        let len = 150usize;
        let bits: Vec<bool> = (0..len)
            .map(|row| row == 64 || row.is_multiple_of(3))
            .collect();

        let if_true = PrimitiveArray::from_iter(0..len as i64).into_array();
        let if_false = PrimitiveArray::from_iter((0..len as i64).map(|v| 1_000 + v)).into_array();

        let result = run_zip(&bits, if_true, if_false)?;
        assert!(result.is::<Primitive>());

        let reference = bits
            .iter()
            .zip(0i64..)
            .map(|(&pick, v)| if pick { v } else { 1_000 + v });
        let expected = PrimitiveArray::from_iter(reference).into_array();
        assert_arrays_eq!(result, expected);
        Ok(())
    }

    /// Nullable inputs on both sides: each row must take both its value and its validity from
    /// the side the mask selects, including across the chunk boundary.
    #[test]
    fn zip_nullable_selects_values_and_validity() -> VortexResult<()> {
        let len = 130usize;
        let bits: Vec<bool> = (0..len).map(|row| row.is_multiple_of(2)).collect();

        // `if_true` is null on every 4th row, `if_false` on every 5th.
        let if_true =
            PrimitiveArray::from_option_iter((0..len as i64).map(|v| (v % 4 != 0).then_some(v)))
                .into_array();
        let if_false = PrimitiveArray::from_option_iter(
            (0..len as i64).map(|v| (v % 5 != 0).then_some(1_000 + v)),
        )
        .into_array();

        let result = run_zip(&bits, if_true, if_false)?;
        assert!(result.is::<Primitive>());

        let reference = bits.iter().zip(0i64..).map(|(&pick, v)| {
            if pick {
                (v % 4 != 0).then_some(v)
            } else {
                (v % 5 != 0).then_some(1_000 + v)
            }
        });
        let expected = PrimitiveArray::from_option_iter(reference).into_array();
        assert_arrays_eq!(result, expected);
        Ok(())
    }
}