Skip to main content

vortex_array/arrays/varbinview/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::VarBinView;
17use crate::arrays::VarBinViewArray;
18use crate::arrays::varbinview::BinaryView;
19use crate::arrays::varbinview::VarBinViewArrayExt;
20use crate::builders::DeduplicatedBuffers;
21use crate::builders::LazyBitBufferBuilder;
22use crate::scalar_fn::fns::zip::ZipKernel;
23
24// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
25// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
26impl ZipKernel for VarBinView {
27    fn zip(
28        if_true: ArrayView<'_, VarBinView>,
29        if_false: &ArrayRef,
30        mask: &ArrayRef,
31        ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        let Some(if_false) = if_false.as_opt::<VarBinView>() else {
34            return Ok(None);
35        };
36
37        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
38            vortex_bail!("input arrays to zip must have the same dtype");
39        }
40
41        let len = if_true.len();
42        let dtype = if_true
43            .dtype()
44            .union_nullability(if_false.dtype().nullability());
45
46        // build buffer lookup tables for both arrays, these map from the original buffer idx
47        // to the new buffer index in the result array
48        let mut buffers = DeduplicatedBuffers::default();
49        let true_lookup =
50            buffers.extend_from_iter(if_true.data_buffers().iter().map(|b| b.as_host().clone()));
51        let false_lookup =
52            buffers.extend_from_iter(if_false.data_buffers().iter().map(|b| b.as_host().clone()));
53
54        let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
55        let mut validity_builder = LazyBitBufferBuilder::new(len);
56
57        let true_validity = if_true.varbinview_validity_mask();
58        let false_validity = if_false.varbinview_validity_mask();
59
60        let mask = mask.try_to_mask_fill_null_false(ctx)?;
61        let if_false_view = if_false;
62        match mask.slices() {
63            AllOr::All => push_range(
64                if_true,
65                &true_lookup,
66                &true_validity,
67                0..len,
68                &mut views_builder,
69                &mut validity_builder,
70            ),
71            AllOr::None => push_range(
72                if_false_view,
73                &false_lookup,
74                &false_validity,
75                0..len,
76                &mut views_builder,
77                &mut validity_builder,
78            ),
79            AllOr::Some(slices) => {
80                let mut pos = 0;
81                for (start, end) in slices {
82                    if pos < *start {
83                        push_range(
84                            if_false_view,
85                            &false_lookup,
86                            &false_validity,
87                            pos..*start,
88                            &mut views_builder,
89                            &mut validity_builder,
90                        );
91                    }
92                    push_range(
93                        if_true,
94                        &true_lookup,
95                        &true_validity,
96                        *start..*end,
97                        &mut views_builder,
98                        &mut validity_builder,
99                    );
100                    pos = *end;
101                }
102                if pos < len {
103                    push_range(
104                        if_false_view,
105                        &false_lookup,
106                        &false_validity,
107                        pos..len,
108                        &mut views_builder,
109                        &mut validity_builder,
110                    );
111                }
112            }
113        }
114
115        let validity = validity_builder.finish_with_nullability(dtype.nullability());
116
117        // SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
118        // buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
119        let array = unsafe {
120            VarBinViewArray::new_unchecked(
121                views_builder.freeze(),
122                buffers.finish(),
123                dtype,
124                validity,
125            )
126        };
127
128        Ok(Some(array.into_array()))
129    }
130}
131
132fn push_range(
133    array: ArrayView<'_, VarBinView>,
134    buffer_lookup: &[u32],
135    validity: &Mask,
136    range: Range<usize>,
137    views_builder: &mut BufferMut<BinaryView>,
138    validity_builder: &mut LazyBitBufferBuilder,
139) {
140    let views = array.views();
141
142    match validity.bit_buffer() {
143        AllOr::All => {
144            for idx in range {
145                push_view(
146                    views[idx],
147                    buffer_lookup,
148                    true,
149                    views_builder,
150                    validity_builder,
151                );
152            }
153        }
154        AllOr::None => {
155            for _ in range {
156                push_view(
157                    BinaryView::empty_view(),
158                    buffer_lookup,
159                    false,
160                    views_builder,
161                    validity_builder,
162                );
163            }
164        }
165        AllOr::Some(bit_buffer) => {
166            for idx in range {
167                let is_valid = bit_buffer.value(idx);
168                push_view(
169                    views[idx],
170                    buffer_lookup,
171                    is_valid,
172                    views_builder,
173                    validity_builder,
174                );
175            }
176        }
177    }
178}
179
180fn push_view(
181    view: BinaryView,
182    buffer_lookup: &[u32],
183    is_valid: bool,
184    views_builder: &mut BufferMut<BinaryView>,
185    validity_builder: &mut LazyBitBufferBuilder,
186) {
187    if !is_valid {
188        views_builder.push(BinaryView::empty_view());
189        validity_builder.append_null();
190        return;
191    }
192
193    let adjusted = if view.is_inlined() {
194        view
195    } else {
196        let view_ref = view.as_view();
197        view_ref
198            .with_buffer_and_offset(
199                buffer_lookup[view_ref.buffer_index as usize],
200                view_ref.offset,
201            )
202            .into()
203    };
204
205    views_builder.push(adjusted);
206    validity_builder.append_non_null();
207}
208
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// `if_true` input: a mix of outlined (long), inlined (short) and null values.
    const TRUE_VALUES: [Option<&str>; 6] = [
        Some("aaaaaaaaaaaaa_long"), // outlined
        Some("short"),
        None,
        Some("bbbbbbbbbbbbbbbb_long"),
        Some("tiny"),
        Some("cccccccccccccccc_long"),
    ];

    /// `if_false` input: a mix of outlined (long), inlined (short) and null values.
    const FALSE_VALUES: [Option<&str>; 6] = [
        Some("dddddddddddddddd_long"),
        Some("eeeeeeeeeeeeeeee_long"),
        Some("ffff"),
        Some("gggggggggggggggg_long"),
        None,
        Some("hhhhhhhhhhhhhhhh_long"),
    ];

    /// Zips `TRUE_VALUES` / `FALSE_VALUES` under `mask`, checks the result's length and dtype,
    /// and decodes it into owned strings.
    fn zip_to_strings(mask: &Mask) -> Vec<Option<String>> {
        let if_true = VarBinViewArray::from_iter(TRUE_VALUES, DType::Utf8(Nullability::Nullable));
        let if_false =
            VarBinViewArray::from_iter(FALSE_VALUES, DType::Utf8(Nullability::Nullable));

        let zipped = mask
            .clone()
            .into_array()
            .zip(if_true.into_array(), if_false.into_array())
            .unwrap()
            .to_varbinview();

        assert_eq!(zipped.len(), mask.len());
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));

        zipped.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    /// Expected zip output: per position, `TRUE_VALUES[i]` if `selection[i]` else `FALSE_VALUES[i]`.
    fn select(selection: [bool; 6]) -> Vec<Option<String>> {
        selection
            .iter()
            .zip(TRUE_VALUES.iter().zip(FALSE_VALUES.iter()))
            .map(|(&take_true, (t, f))| if take_true { *t } else { *f })
            .map(|v| v.map(str::to_string))
            .collect()
    }

    #[test]
    fn zip_varbinview_kernel_zips() {
        let mask = Mask::from_iter([true, false, true, false, false, true]);

        assert_eq!(
            zip_to_strings(&mask),
            vec![
                Some("aaaaaaaaaaaaa_long".to_string()),
                Some("eeeeeeeeeeeeeeee_long".to_string()),
                None,
                Some("gggggggggggggggg_long".to_string()),
                None,
                Some("cccccccccccccccc_long".to_string())
            ]
        );
    }

    #[test]
    fn zip_varbinview_kernel_leading_and_trailing_false() {
        // Exercises the gap before the first set run and after the last one.
        let selection = [false, true, true, false, true, false];
        assert_eq!(zip_to_strings(&Mask::from_iter(selection)), select(selection));
    }

    #[test]
    fn zip_varbinview_kernel_all_true() {
        let selection = [true; 6];
        assert_eq!(zip_to_strings(&Mask::from_iter(selection)), select(selection));
    }

    #[test]
    fn zip_varbinview_kernel_all_false() {
        let selection = [false; 6];
        assert_eq!(zip_to_strings(&Mask::from_iter(selection)), select(selection));
    }
}