Skip to main content

vortex_array/arrays/varbinview/compute/zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use vortex_buffer::BufferMut;
7use vortex_error::VortexResult;
8use vortex_error::vortex_bail;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::IntoArray;
15use crate::arrays::VarBinView;
16use crate::arrays::VarBinViewArray;
17use crate::arrays::varbinview::BinaryView;
18use crate::builders::DeduplicatedBuffers;
19use crate::builders::LazyBitBufferBuilder;
20use crate::scalar_fn::fns::zip::ZipKernel;
21
22// A dedicated VarBinView zip kernel that builds the result directly by adjusting views and validity,
23// instead of routing through the generic builder (which would redo buffer lookups per mask slice).
24impl ZipKernel for VarBinView {
25    fn zip(
26        if_true: &VarBinViewArray,
27        if_false: &ArrayRef,
28        mask: &ArrayRef,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let Some(if_false) = if_false.as_opt::<VarBinView>() else {
32            return Ok(None);
33        };
34
35        if !if_true.dtype().eq_ignore_nullability(if_false.dtype()) {
36            vortex_bail!("input arrays to zip must have the same dtype");
37        }
38
39        let len = if_true.len();
40        let dtype = if_true
41            .dtype()
42            .union_nullability(if_false.dtype().nullability());
43
44        // build buffer lookup tables for both arrays, these map from the original buffer idx
45        // to the new buffer index in the result array
46        let mut buffers = DeduplicatedBuffers::default();
47        let true_lookup =
48            buffers.extend_from_iter(if_true.buffers().iter().map(|b| b.as_host().clone()));
49        let false_lookup =
50            buffers.extend_from_iter(if_false.buffers().iter().map(|b| b.as_host().clone()));
51
52        let mut views_builder = BufferMut::<BinaryView>::with_capacity(len);
53        let mut validity_builder = LazyBitBufferBuilder::new(len);
54
55        let true_validity = if_true.validity_mask()?;
56        let false_validity = if_false.validity_mask()?;
57
58        let mask = mask.try_to_mask_fill_null_false(ctx)?;
59        match mask.slices() {
60            AllOr::All => push_range(
61                if_true,
62                &true_lookup,
63                &true_validity,
64                0..len,
65                &mut views_builder,
66                &mut validity_builder,
67            ),
68            AllOr::None => push_range(
69                if_false,
70                &false_lookup,
71                &false_validity,
72                0..len,
73                &mut views_builder,
74                &mut validity_builder,
75            ),
76            AllOr::Some(slices) => {
77                let mut pos = 0;
78                for (start, end) in slices {
79                    if pos < *start {
80                        push_range(
81                            if_false,
82                            &false_lookup,
83                            &false_validity,
84                            pos..*start,
85                            &mut views_builder,
86                            &mut validity_builder,
87                        );
88                    }
89                    push_range(
90                        if_true,
91                        &true_lookup,
92                        &true_validity,
93                        *start..*end,
94                        &mut views_builder,
95                        &mut validity_builder,
96                    );
97                    pos = *end;
98                }
99                if pos < len {
100                    push_range(
101                        if_false,
102                        &false_lookup,
103                        &false_validity,
104                        pos..len,
105                        &mut views_builder,
106                        &mut validity_builder,
107                    );
108                }
109            }
110        }
111
112        let validity = validity_builder.finish_with_nullability(dtype.nullability());
113
114        // SAFETY: views are built with adjusted buffer indices, validity tracked alongside;
115        // buffers come from `DeduplicatedBuffers`, dtype/nullability preserved.
116        let array = unsafe {
117            VarBinViewArray::new_unchecked(
118                views_builder.freeze(),
119                buffers.finish(),
120                dtype,
121                validity,
122            )
123        };
124
125        Ok(Some(array.into_array()))
126    }
127}
128
129fn push_range(
130    array: &VarBinViewArray,
131    buffer_lookup: &[u32],
132    validity: &Mask,
133    range: Range<usize>,
134    views_builder: &mut BufferMut<BinaryView>,
135    validity_builder: &mut LazyBitBufferBuilder,
136) {
137    let views = array.views();
138
139    match validity.bit_buffer() {
140        AllOr::All => {
141            for idx in range {
142                push_view(
143                    views[idx],
144                    buffer_lookup,
145                    true,
146                    views_builder,
147                    validity_builder,
148                );
149            }
150        }
151        AllOr::None => {
152            for _ in range {
153                push_view(
154                    BinaryView::empty_view(),
155                    buffer_lookup,
156                    false,
157                    views_builder,
158                    validity_builder,
159                );
160            }
161        }
162        AllOr::Some(bit_buffer) => {
163            for idx in range {
164                let is_valid = bit_buffer.value(idx);
165                push_view(
166                    views[idx],
167                    buffer_lookup,
168                    is_valid,
169                    views_builder,
170                    validity_builder,
171                );
172            }
173        }
174    }
175}
176
177fn push_view(
178    view: BinaryView,
179    buffer_lookup: &[u32],
180    is_valid: bool,
181    views_builder: &mut BufferMut<BinaryView>,
182    validity_builder: &mut LazyBitBufferBuilder,
183) {
184    if !is_valid {
185        views_builder.push(BinaryView::empty_view());
186        validity_builder.append_null();
187        return;
188    }
189
190    let adjusted = if view.is_inlined() {
191        view
192    } else {
193        let view_ref = view.as_view();
194        view_ref
195            .with_buffer_and_offset(
196                buffer_lookup[view_ref.buffer_index as usize],
197                view_ref.offset,
198            )
199            .into()
200    };
201
202    views_builder.push(adjusted);
203    validity_builder.append_non_null();
204}
205
#[cfg(test)]
mod tests {
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// Materializes a VarBinView array of UTF-8 values as owned optional strings.
    fn collect_strings(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|it| {
            it.map(|v| v.map(|bytes| String::from_utf8(bytes.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    /// Two nullable inputs mixing outlined, inlined, and null values.
    fn sample_inputs() -> (VarBinViewArray, VarBinViewArray) {
        let a = VarBinViewArray::from_iter(
            [Some("aaaaaaaaaaaaaaaa_long"), None, Some("x")],
            DType::Utf8(Nullability::Nullable),
        );
        let b = VarBinViewArray::from_iter(
            [Some("bbbbbbbbbbbbbbbb_long"), Some("y"), Some("z")],
            DType::Utf8(Nullability::Nullable),
        );
        (a, b)
    }

    #[test]
    fn zip_varbinview_kernel_zips() {
        let a = VarBinViewArray::from_iter(
            [
                Some("aaaaaaaaaaaaa_long"), // outlined
                Some("short"),
                None,
                Some("bbbbbbbbbbbbbbbb_long"),
                Some("tiny"),
                Some("cccccccccccccccc_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let b = VarBinViewArray::from_iter(
            [
                Some("dddddddddddddddd_long"),
                Some("eeeeeeeeeeeeeeee_long"),
                Some("ffff"),
                Some("gggggggggggggggg_long"),
                None,
                Some("hhhhhhhhhhhhhhhh_long"),
            ],
            DType::Utf8(Nullability::Nullable),
        );

        let mask = Mask::from_iter([true, false, true, false, false, true]);

        let zipped = mask
            .clone()
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        assert_eq!(
            collect_strings(&zipped),
            vec![
                Some("aaaaaaaaaaaaa_long".to_string()),
                Some("eeeeeeeeeeeeeeee_long".to_string()),
                None,
                Some("gggggggggggggggg_long".to_string()),
                None,
                Some("cccccccccccccccc_long".to_string())
            ]
        );
        assert_eq!(zipped.len(), mask.len());
        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
    }

    #[test]
    fn zip_varbinview_all_true_mask_takes_if_true() {
        let (a, b) = sample_inputs();

        let zipped = Mask::from_iter([true, true, true])
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        assert_eq!(
            collect_strings(&zipped),
            vec![
                Some("aaaaaaaaaaaaaaaa_long".to_string()),
                None,
                Some("x".to_string())
            ]
        );
    }

    #[test]
    fn zip_varbinview_all_false_mask_takes_if_false() {
        let (a, b) = sample_inputs();

        let zipped = Mask::from_iter([false, false, false])
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        assert_eq!(
            collect_strings(&zipped),
            vec![
                Some("bbbbbbbbbbbbbbbb_long".to_string()),
                Some("y".to_string()),
                Some("z".to_string())
            ]
        );
    }

    #[test]
    fn zip_varbinview_unions_nullability() {
        let a = VarBinViewArray::from_iter(
            [Some("aaaaaaaaaaaaaaaa_long"), Some("short")],
            DType::Utf8(Nullability::NonNullable),
        );
        let b = VarBinViewArray::from_iter(
            [None, Some("bbbbbbbbbbbbbbbb_long")],
            DType::Utf8(Nullability::Nullable),
        );

        let zipped = Mask::from_iter([false, true])
            .into_array()
            .zip(a.into_array(), b.into_array())
            .unwrap()
            .to_varbinview();

        assert_eq!(zipped.dtype(), &DType::Utf8(Nullability::Nullable));
        assert_eq!(
            collect_strings(&zipped),
            vec![None, Some("short".to_string())]
        );
    }
}