Skip to main content

vortex_array/arrays/varbinview/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::VarBinViewArray;
16use crate::arrays::VarBinViewVTable;
17use crate::arrays::varbinview::BinaryView;
18use crate::builders::DeduplicatedBuffers;
19use crate::builders::LazyBitBufferBuilder;
20use crate::scalar_fn::fns::zip::ZipKernel;
21
22// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
23// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
24impl ZipKernel for VarBinViewVTable {
25    fn zip(
26        if_true: &VarBinViewArray,
27        if_false: &ArrayRef,
28        mask: &ArrayRef,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let Some(if_false) = if_false.as_opt::<VarBinViewVTable>() else {
32            return Ok(None);
33        };
34
35        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
36            vortex_bail!("input arrays to zip must have the same dtype");
37        }
38
39        let len = if_true.len();
40        let dtype = if_true
41            .dtype()
42            .union_nullability(if_false.dtype().nullability());
43
44        // build buffer lookup tables for both arrays, these map from the original buffer idx
45        // to the new buffer index in the result array
46        let mut buffers = DeduplicatedBuffers::default();
47        let true_lookup =
48            buffers.extend_from_iter(if_true.buffers().iter().map(|b| b.as_host().clone()));
49        let false_lookup =
50            buffers.extend_from_iter(if_false.buffers().iter().map(|b| b.as_host().clone()));
51
52        let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
53        let mut validity_builder = LazyBitBufferBuilder::new(len);
54
55        let true_validity = if_true.validity_mask()?;
56        let false_validity = if_false.validity_mask()?;
57
58        let mask = mask.try_to_mask_fill_null_false(ctx)?;
59        match mask.slices() {
60            AllOr::All => push_range(
61                if_true,
62                &true_lookup,
63                &true_validity,
64                0..len,
65                &mut views_builder,
66                &mut validity_builder,
67            ),
68            AllOr::None => push_range(
69                if_false,
70                &false_lookup,
71                &false_validity,
72                0..len,
73                &mut views_builder,
74                &mut validity_builder,
75            ),
76            AllOr::Some(slices) => {
77                let mut pos = 0;
78                for (start, end) in slices {
79                    if pos < *start {
80                        push_range(
81                            if_false,
82                            &false_lookup,
83                            &false_validity,
84                            pos..*start,
85                            &mut views_builder,
86                            &mut validity_builder,
87                        );
88                    }
89                    push_range(
90                        if_true,
91                        &true_lookup,
92                        &true_validity,
93                        *start..*end,
94                        &mut views_builder,
95                        &mut validity_builder,
96                    );
97                    pos = *end;
98                }
99                if pos < len {
100                    push_range(
101                        if_false,
102                        &false_lookup,
103                        &false_validity,
104                        pos..len,
105                        &mut views_builder,
106                        &mut validity_builder,
107                    );
108                }
109            }
110        }
111
112        let validity = validity_builder.finish_with_nullability(dtype.nullability());
113
114        // SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
115        // buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
116        let array = unsafe {
117            VarBinViewArray::new_unchecked(
118                views_builder.freeze(),
119                buffers.finish(),
120                dtype,
121                validity,
122            )
123        };
124
125        Ok(Some(array.into_array()))
126    }
127}
128
129fn push_range(
130    array: &VarBinViewArray,
131    buffer_lookup: &[u32],
132    validity: &Mask,
133    range: Range<usize>,
134    views_builder: &mut BufferMut<BinaryView>,
135    validity_builder: &mut LazyBitBufferBuilder,
136) {
137    let views = array.views();
138
139    match validity.bit_buffer() {
140        AllOr::All => {
141            for idx in range {
142                push_view(
143                    views[idx],
144                    buffer_lookup,
145                    true,
146                    views_builder,
147                    validity_builder,
148                );
149            }
150        }
151        AllOr::None => {
152            for _ in range {
153                push_view(
154                    BinaryView::empty_view(),
155                    buffer_lookup,
156                    false,
157                    views_builder,
158                    validity_builder,
159                );
160            }
161        }
162        AllOr::Some(bit_buffer) => {
163            for idx in range {
164                let is_valid = bit_buffer.value(idx);
165                push_view(
166                    views[idx],
167                    buffer_lookup,
168                    is_valid,
169                    views_builder,
170                    validity_builder,
171                );
172            }
173        }
174    }
175}
176
177#[inline]
178fn push_view(
179    view: BinaryView,
180    buffer_lookup: &[u32],
181    is_valid: bool,
182    views_builder: &mut BufferMut<BinaryView>,
183    validity_builder: &mut LazyBitBufferBuilder,
184) {
185    if !is_valid {
186        views_builder.push(BinaryView::empty_view());
187        validity_builder.append_null();
188        return;
189    }
190
191    let adjusted = if view.is_inlined() {
192        view
193    } else {
194        let view_ref = view.as_view();
195        view_ref
196            .with_buffer_and_offset(
197                buffer_lookup[view_ref.buffer_index as usize],
198                view_ref.offset,
199            )
200            .into()
201    };
202
203    views_builder.push(adjusted);
204    validity_builder.append_non_null();
205}
206
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// Zips two nullable Utf8 arrays that mix long, short and null entries, then checks the
    /// selected values, the length and the dtype of the result.
    #[test]
    fn zip_varbinview_kernel_zips() {
        let nullable_utf8 = DType::Utf8(Nullability::Nullable);

        let when_true = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaa_long"), // outlined
                Some("short"),
                None,
                Some("bbbbbbbbbbbbbbbb_long"),
                Some("tiny"),
                Some("cccccccccccccccc_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let when_false = VarBinViewArray::from_iter(
            [
                Some("dddddddddddddddd_long"),
                Some("eeeeeeeeeeeeeeee_long"),
                Some("ffff"),
                Some("gggggggggggggggg_long"),
                None,
                Some("hhhhhhhhhhhhhhhh_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let mask = Mask::from_iter([true, false, true, false, false, true]);
        // Record the length first so the mask itself can be consumed by `into_array`.
        let expected_len = mask.len();

        let zipped = mask
            .into_array()
            .zip(when_true.into_array(), when_false.into_array())
            .unwrap()
            .to_varbinview();

        let actual = zipped.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        });

        // Positions 0, 2 and 5 come from the first array, the rest from the second.
        let expected: Vec<Option<String>> = [
            Some("aaaaaaaaaaaaa_long"),
            Some("eeeeeeeeeeeeeeee_long"),
            None,
            Some("gggggggggggggggg_long"),
            None,
            Some("cccccccccccccccc_long"),
        ]
        .into_iter()
        .map(|value| value.map(String::from))
        .collect();

        assert_eq!(actual, expected);
        assert_eq!(zipped.len(), expected_len);
        assert_eq!(zipped.dtype(), &nullable_utf8);
    }
}