Skip to main content

vortex_array/arrays/varbinview/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::VarBinView;
17use crate::arrays::VarBinViewArray;
18use crate::arrays::varbinview::BinaryView;
19use crate::arrays::varbinview::VarBinViewArrayExt;
20use crate::builders::DeduplicatedBuffers;
21use crate::builders::LazyBitBufferBuilder;
22use crate::scalar_fn::fns::zip::ZipKernel;
23
24// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
25// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
26impl ZipKernel for VarBinView {
27    fn zip(
28        if_true: ArrayView<'_, VarBinView>,
29        if_false: &ArrayRef,
30        mask: &ArrayRef,
31        ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        let Some(if_false) = if_false.as_opt::<VarBinView>() else {
34            return Ok(None);
35        };
36
37        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
38            vortex_bail!("input arrays to zip must have the same dtype");
39        }
40
41        let len = if_true.len();
42        let dtype = if_true
43            .dtype()
44            .union_nullability(if_false.dtype().nullability());
45
46        // build buffer lookup tables for both arrays, these map from the original buffer idx
47        // to the new buffer index in the result array
48        let mut buffers = DeduplicatedBuffers::default();
49        let true_lookup =
50            buffers.extend_from_iter(if_true.data_buffers().iter().map(|b| b.as_host().clone()));
51        let false_lookup =
52            buffers.extend_from_iter(if_false.data_buffers().iter().map(|b| b.as_host().clone()));
53
54        let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
55        let mut validity_builder = LazyBitBufferBuilder::new(len);
56
57        let true_validity = if_true.varbinview_validity().execute_mask(len, ctx)?;
58        let false_validity = if_false.varbinview_validity().execute_mask(len, ctx)?;
59
60        let mask = mask.try_to_mask_fill_null_false(ctx)?;
61        let if_false_view = if_false;
62        match mask.slices() {
63            AllOr::All => push_range(
64                if_true,
65                &true_lookup,
66                &true_validity,
67                0..len,
68                &mut views_builder,
69                &mut validity_builder,
70            ),
71            AllOr::None => push_range(
72                if_false_view,
73                &false_lookup,
74                &false_validity,
75                0..len,
76                &mut views_builder,
77                &mut validity_builder,
78            ),
79            AllOr::Some(slices) => {
80                let mut pos = 0;
81                for (start, end) in slices {
82                    if pos < *start {
83                        push_range(
84                            if_false_view,
85                            &false_lookup,
86                            &false_validity,
87                            pos..*start,
88                            &mut views_builder,
89                            &mut validity_builder,
90                        );
91                    }
92                    push_range(
93                        if_true,
94                        &true_lookup,
95                        &true_validity,
96                        *start..*end,
97                        &mut views_builder,
98                        &mut validity_builder,
99                    );
100                    pos = *end;
101                }
102                if pos < len {
103                    push_range(
104                        if_false_view,
105                        &false_lookup,
106                        &false_validity,
107                        pos..len,
108                        &mut views_builder,
109                        &mut validity_builder,
110                    );
111                }
112            }
113        }
114
115        let validity = validity_builder.finish_with_nullability(dtype.nullability());
116
117        // SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
118        // buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
119        let array = unsafe {
120            VarBinViewArray::new_unchecked(
121                views_builder.freeze(),
122                buffers.finish(),
123                dtype,
124                validity,
125            )
126        };
127
128        Ok(Some(array.into_array()))
129    }
130}
131
132fn push_range(
133    array: ArrayView<'_, VarBinView>,
134    buffer_lookup: &[u32],
135    validity: &Mask,
136    range: Range<usize>,
137    views_builder: &mut BufferMut<BinaryView>,
138    validity_builder: &mut LazyBitBufferBuilder,
139) {
140    let views = array.views();
141
142    match validity.bit_buffer() {
143        AllOr::All => {
144            for idx in range {
145                push_view(
146                    views[idx],
147                    buffer_lookup,
148                    true,
149                    views_builder,
150                    validity_builder,
151                );
152            }
153        }
154        AllOr::None => {
155            for _ in range {
156                push_view(
157                    BinaryView::empty_view(),
158                    buffer_lookup,
159                    false,
160                    views_builder,
161                    validity_builder,
162                );
163            }
164        }
165        AllOr::Some(bit_buffer) => {
166            for idx in range {
167                let is_valid = bit_buffer.value(idx);
168                push_view(
169                    views[idx],
170                    buffer_lookup,
171                    is_valid,
172                    views_builder,
173                    validity_builder,
174                );
175            }
176        }
177    }
178}
179
180#[inline]
181fn push_view(
182    view: BinaryView,
183    buffer_lookup: &[u32],
184    is_valid: bool,
185    views_builder: &mut BufferMut<BinaryView>,
186    validity_builder: &mut LazyBitBufferBuilder,
187) {
188    if !is_valid {
189        views_builder.push(BinaryView::empty_view());
190        validity_builder.append_null();
191        return;
192    }
193
194    let adjusted = if view.is_inlined() {
195        view
196    } else {
197        let view_ref = view.as_view();
198        view_ref
199            .with_buffer_and_offset(
200                buffer_lookup[view_ref.buffer_index as usize],
201                view_ref.offset,
202            )
203            .into()
204    };
205
206    views_builder.push(adjusted);
207    validity_builder.append_non_null();
208}
209
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    #[expect(deprecated)]
    use crate::canonical::ToCanonical as _;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// Two nullable inputs of length 4 that mix inlined, outlined and null values.
    fn sample_inputs() -> (VarBinViewArray, VarBinViewArray) {
        let a = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaaaaa_long"),
                None,
                Some("short"),
                Some("bbbbbbbbbbbbbbbb_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        let b = VarBinViewArray::from_iter(
            [
                Some("cccccccccccccccc_long"),
                Some("dddddddddddddddd_long"),
                None,
                Some("e"),
            ],
            DType::Utf8(Nullability::Nullable),
        );
        (a, b)
    }

    /// Zips `a`/`b` under `mask` and decodes the result into owned strings.
    fn zip_to_strings(
        mask: Mask,
        a: VarBinViewArray,
        b: VarBinViewArray,
    ) -> Vec<Option<String>> {
        #[expect(deprecated)]
        let zipped = mask
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        zipped.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    #[test]
    fn zip_varbinview_kernel_zips() {
        let a = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaa_long"), // outlined
                Some("short"),
                None,
                Some("bbbbbbbbbbbbbbbb_long"),
                Some("tiny"),
                Some("cccccccccccccccc_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let b = VarBinViewArray::from_iter(
            [
                Some("dddddddddddddddd_long"),
                Some("eeeeeeeeeeeeeeee_long"),
                Some("ffff"),
                Some("gggggggggggggggg_long"),
                None,
                Some("hhhhhhhhhhhhhhhh_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let mask = Mask::from_iter([true, false, true, false, false, true]);

        #[expect(deprecated)]
        let zipped = mask
            .clone()
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        let values = zipped.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        });

        assert_eq!(
            values,
            vec![
                Some("aaaaaaaaaaaaa_long".to_string()),
                Some("eeeeeeeeeeeeeeee_long".to_string()),
                None,
                Some("gggggggggggggggg_long".to_string()),
                None,
                Some("cccccccccccccccc_long".to_string())
            ]
        );
        assert_eq!(zipped.len(), mask.len());
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    // Covers the leading `if_false` gap before the first true run and the trailing
    // `pos < len` branch after the last one.
    #[test]
    fn zip_varbinview_kernel_leading_and_trailing_false() {
        let (a, b) = sample_inputs();
        let values = zip_to_strings(Mask::from_iter([false, true, true, false]), a, b);

        assert_eq!(
            values,
            vec![
                Some("cccccccccccccccc_long".to_string()),
                None,
                Some("short".to_string()),
                Some("e".to_string()),
            ]
        );
    }

    #[test]
    fn zip_varbinview_kernel_all_true_mask() {
        let (a, b) = sample_inputs();
        let values = zip_to_strings(Mask::from_iter([true, true, true, true]), a, b);

        assert_eq!(
            values,
            vec![
                Some("aaaaaaaaaaaaaaaa_long".to_string()),
                None,
                Some("short".to_string()),
                Some("bbbbbbbbbbbbbbbb_long".to_string()),
            ]
        );
    }

    #[test]
    fn zip_varbinview_kernel_all_false_mask() {
        let (a, b) = sample_inputs();
        let values = zip_to_strings(Mask::from_iter([false, false, false, false]), a, b);

        assert_eq!(
            values,
            vec![
                Some("cccccccccccccccc_long".to_string()),
                Some("dddddddddddddddd_long".to_string()),
                None,
                Some("e".to_string()),
            ]
        );
    }
}