Skip to main content

vortex_array/arrays/chunked/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::Mask;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::arrays::ChunkedArray;
11use crate::arrays::ChunkedVTable;
12use crate::builtins::ArrayBuiltins;
13use crate::scalar_fn::fns::zip::ZipReduce;
14
15// Push down the zip call to the chunks. Without this rule
16// the default implementation canonicalises the chunked array
17// then zips once.
18impl ZipReduce for ChunkedVTable {
19    fn zip(
20        if_true: &ChunkedArray,
21        if_false: &dyn Array,
22        mask: &Mask,
23    ) -> VortexResult<Option<ArrayRef>> {
24        let Some(if_false) = if_false.as_opt::<ChunkedVTable>() else {
25            return Ok(None);
26        };
27        let dtype = if_true
28            .dtype()
29            .union_nullability(if_false.dtype().nullability());
30        let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
31
32        let mut lhs_idx = 0;
33        let mut rhs_idx = 0;
34        let mut lhs_offset = 0;
35        let mut rhs_offset = 0;
36        let mut pos = 0;
37        let total_len = if_true.len();
38
39        while pos < total_len {
40            let lhs_chunk = if_true.chunk(lhs_idx);
41            let rhs_chunk = if_false.chunk(rhs_idx);
42
43            let lhs_rem = lhs_chunk.len() - lhs_offset;
44            let rhs_rem = rhs_chunk.len() - rhs_offset;
45            let take_until = lhs_rem.min(rhs_rem);
46
47            let mask_slice = mask.slice(pos..pos + take_until);
48            let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until)?;
49            let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until)?;
50
51            out_chunks.push(lhs_slice.zip(rhs_slice, mask_slice.into_array())?);
52
53            pos += take_until;
54            lhs_offset += take_until;
55            rhs_offset += take_until;
56
57            if lhs_offset == lhs_chunk.len() {
58                lhs_idx += 1;
59                lhs_offset = 0;
60            }
61            if rhs_offset == rhs_chunk.len() {
62                rhs_idx += 1;
63                rhs_offset = 0;
64            }
65        }
66
67        // SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
68        let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
69        Ok(Some(chunked.to_array()))
70    }
71}
72
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ChunkedVTable;
    #[expect(deprecated)]
    use crate::compute::zip;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Zipping two chunked arrays whose chunk boundaries do not line up must
    /// keep the chunked encoding and split the output at every boundary of
    /// either input.
    #[test]
    fn test_chunked_zip_aligns_across_boundaries() {
        let i32_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);

        // Chunk lengths 2 / 1 / 2: boundaries after positions 2 and 3.
        let true_chunks = vec![
            buffer![1i32, 2].into_array(),
            buffer![3i32].into_array(),
            buffer![4i32, 5].into_array(),
        ];
        let if_true = ChunkedArray::try_new(true_chunks, i32_dtype.clone()).unwrap();

        // Chunk lengths 1 / 2 / 2: boundaries after positions 1 and 3.
        let false_chunks = vec![
            buffer![10i32].into_array(),
            buffer![11i32, 12].into_array(),
            buffer![13i32, 14].into_array(),
        ];
        let if_false = ChunkedArray::try_new(false_chunks, i32_dtype).unwrap();

        let mask = Mask::from_iter([true, false, true, false, true]);

        #[expect(deprecated)]
        let result = zip(if_true.as_ref(), if_false.as_ref(), &mask).unwrap();
        let result = result
            .as_opt::<ChunkedVTable>()
            .expect("zip should keep chunked encoding");

        // Combined boundaries at 1, 2 and 3 give four output chunks.
        assert_eq!(result.nchunks(), 4);

        let mut flattened: Vec<i32> = Vec::new();
        for chunk in result.chunks() {
            flattened.extend(chunk.to_primitive().as_slice::<i32>().iter().copied());
        }
        assert_eq!(flattened, vec![1, 11, 3, 13, 5]);
    }
}