vortex_array/arrays/varbinview/compute/zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11use vortex_vector::binaryview::BinaryView;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::arrays::VarBinViewArray;
16use crate::arrays::VarBinViewVTable;
17use crate::builders::DeduplicatedBuffers;
18use crate::builders::LazyBitBufferBuilder;
19use crate::compute::ZipKernel;
20use crate::compute::ZipKernelAdapter;
21use crate::register_kernel;
22
23// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
24// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
25impl ZipKernel for VarBinViewVTable {
26    fn zip(
27        &self,
28        if_true: &VarBinViewArray,
29        if_false: &dyn Array,
30        mask: &Mask,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let Some(if_false) = if_false.as_opt::<VarBinViewVTable>() else {
33            return Ok(None);
34        };
35
36        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
37            vortex_bail!("input arrays to zip must have the same dtype");
38        }
39
40        // compute fn already asserts if_true.len() == if_false.len()
41        let len = if_true.len();
42        let dtype = if_true
43            .dtype()
44            .union_nullability(if_false.dtype().nullability());
45
46        // build buffer lookup tables for both arrays, these map from the original buffer idx
47        // to the new buffer index in the result array
48        let mut buffers = DeduplicatedBuffers::default();
49        let true_lookup = buffers.extend_from_slice(if_true.buffers());
50        let false_lookup = buffers.extend_from_slice(if_false.buffers());
51
52        let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
53        let mut validity_builder = LazyBitBufferBuilder::new(len);
54
55        let true_validity = if_true.validity_mask();
56        let false_validity = if_false.validity_mask();
57
58        match mask.slices() {
59            AllOr::All => push_range(
60                if_true,
61                &true_lookup,
62                &true_validity,
63                0..len,
64                &mut views_builder,
65                &mut validity_builder,
66            ),
67            AllOr::None => push_range(
68                if_false,
69                &false_lookup,
70                &false_validity,
71                0..len,
72                &mut views_builder,
73                &mut validity_builder,
74            ),
75            AllOr::Some(slices) => {
76                let mut pos = 0;
77                for (start, end) in slices {
78                    if pos < *start {
79                        push_range(
80                            if_false,
81                            &false_lookup,
82                            &false_validity,
83                            pos..*start,
84                            &mut views_builder,
85                            &mut validity_builder,
86                        );
87                    }
88                    push_range(
89                        if_true,
90                        &true_lookup,
91                        &true_validity,
92                        *start..*end,
93                        &mut views_builder,
94                        &mut validity_builder,
95                    );
96                    pos = *end;
97                }
98                if pos < len {
99                    push_range(
100                        if_false,
101                        &false_lookup,
102                        &false_validity,
103                        pos..len,
104                        &mut views_builder,
105                        &mut validity_builder,
106                    );
107                }
108            }
109        }
110
111        let validity = validity_builder.finish_with_nullability(dtype.nullability());
112
113        // SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
114        // buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
115        let array = unsafe {
116            VarBinViewArray::new_unchecked(
117                views_builder.freeze(),
118                buffers.finish(),
119                dtype,
120                validity,
121            )
122        };
123
124        Ok(Some(array.to_array()))
125    }
126}
127
128fn push_range(
129    array: &VarBinViewArray,
130    buffer_lookup: &[u32],
131    validity: &Mask,
132    range: Range<usize>,
133    views_builder: &mut BufferMut<BinaryView>,
134    validity_builder: &mut LazyBitBufferBuilder,
135) {
136    let views = array.views();
137
138    match validity.bit_buffer() {
139        AllOr::All => {
140            for idx in range {
141                push_view(
142                    views[idx],
143                    buffer_lookup,
144                    true,
145                    views_builder,
146                    validity_builder,
147                );
148            }
149        }
150        AllOr::None => {
151            for _ in range {
152                push_view(
153                    BinaryView::empty_view(),
154                    buffer_lookup,
155                    false,
156                    views_builder,
157                    validity_builder,
158                );
159            }
160        }
161        AllOr::Some(bit_buffer) => {
162            for idx in range {
163                let is_valid = bit_buffer.value(idx);
164                push_view(
165                    views[idx],
166                    buffer_lookup,
167                    is_valid,
168                    views_builder,
169                    validity_builder,
170                );
171            }
172        }
173    }
174}
175
176#[inline]
177fn push_view(
178    view: BinaryView,
179    buffer_lookup: &[u32],
180    is_valid: bool,
181    views_builder: &mut BufferMut<BinaryView>,
182    validity_builder: &mut LazyBitBufferBuilder,
183) {
184    if !is_valid {
185        views_builder.push(BinaryView::empty_view());
186        validity_builder.append_null();
187        return;
188    }
189
190    let adjusted = if view.is_inlined() {
191        view
192    } else {
193        let view_ref = view.as_view();
194        view_ref
195            .with_buffer_and_offset(
196                buffer_lookup[view_ref.buffer_index as usize],
197                view_ref.offset,
198            )
199            .into()
200    };
201
202    views_builder.push(adjusted);
203    validity_builder.append_non_null();
204}
205
// Register the `ZipKernel` implementation above for `VarBinViewVTable`, wrapped in
// `ZipKernelAdapter`, via `register_kernel!`.
register_kernel!(ZipKernelAdapter(VarBinViewVTable).lift());
207
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::zip;

    /// Decodes every element of `array` into an owned UTF-8 string, keeping nulls as `None`.
    fn collect_strings(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    /// Two nullable inputs of equal length, mixing short and long (outlined) strings and
    /// containing a null each.
    fn nullable_inputs() -> (VarBinViewArray, VarBinViewArray) {
        let a = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaa_long"), // outlined
                Some("short"),
                None,
                Some("bbbbbbbbbbbbbbbb_long"),
                Some("tiny"),
                Some("cccccccccccccccc_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let b = VarBinViewArray::from_iter(
            [
                Some("dddddddddddddddd_long"),
                Some("eeeeeeeeeeeeeeee_long"),
                Some("ffff"),
                Some("gggggggggggggggg_long"),
                None,
                Some("hhhhhhhhhhhhhhhh_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        (a, b)
    }

    #[test]
    fn zip_varbinview_kernel_zips() {
        let (a, b) = nullable_inputs();
        let mask = Mask::from_iter([true, false, true, false, false, true]);

        let zipped = zip(a.as_ref(), b.as_ref(), &mask).unwrap().to_varbinview();

        assert_eq!(
            collect_strings(&zipped),
            vec![
                Some("aaaaaaaaaaaaa_long".to_string()),
                Some("eeeeeeeeeeeeeeee_long".to_string()),
                None,
                Some("gggggggggggggggg_long".to_string()),
                None,
                Some("cccccccccccccccc_long".to_string())
            ]
        );
        assert_eq!(zipped.len(), mask.len());
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    #[test]
    fn zip_all_true_mask_selects_if_true() {
        let (a, b) = nullable_inputs();
        let mask = Mask::new_true(a.len());

        let zipped = zip(a.as_ref(), b.as_ref(), &mask).unwrap().to_varbinview();

        assert_eq!(collect_strings(&zipped), collect_strings(&a));
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    #[test]
    fn zip_all_false_mask_selects_if_false() {
        let (a, b) = nullable_inputs();
        let mask = Mask::new_false(a.len());

        let zipped = zip(a.as_ref(), b.as_ref(), &mask).unwrap().to_varbinview();

        assert_eq!(collect_strings(&zipped), collect_strings(&b));
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    #[test]
    fn zip_unions_nullability() {
        let a = VarBinViewArray::from_iter(
            [Some("short"), Some("iiiiiiiiiiiiiiii_long")],
            DType::Utf8(Nullability::NonNullable),
        );
        let b = VarBinViewArray::from_iter(
            [None, Some("jjjjjjjjjjjjjjjj_long")],
            DType::Utf8(Nullability::Nullable),
        );
        let mask = Mask::from_iter([false, true]);

        let zipped = zip(a.as_ref(), b.as_ref(), &mask).unwrap().to_varbinview();

        assert_eq!(
            collect_strings(&zipped),
            vec![None, Some("iiiiiiiiiiiiiiii_long".to_string())]
        );
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }
}