Skip to main content

vortex_array/arrays/chunked/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5use vortex_mask::Mask;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::arrays::ChunkedArray;
10use crate::arrays::ChunkedVTable;
11use crate::compute::zip;
12use crate::expr::ZipReduce;
13
14// Push down the zip call to the chunks. Without this rule
15// the default implementation canonicalises the chunked array
16// then zips once.
17impl ZipReduce for ChunkedVTable {
18    fn zip(
19        if_true: &ChunkedArray,
20        if_false: &dyn Array,
21        mask: &Mask,
22    ) -> VortexResult<Option<ArrayRef>> {
23        let Some(if_false) = if_false.as_opt::<ChunkedVTable>() else {
24            return Ok(None);
25        };
26        let dtype = if_true
27            .dtype()
28            .union_nullability(if_false.dtype().nullability());
29        let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
30
31        let mut lhs_idx = 0;
32        let mut rhs_idx = 0;
33        let mut lhs_offset = 0;
34        let mut rhs_offset = 0;
35        let mut pos = 0;
36        let total_len = if_true.len();
37
38        while pos < total_len {
39            let lhs_chunk = if_true.chunk(lhs_idx);
40            let rhs_chunk = if_false.chunk(rhs_idx);
41
42            let lhs_rem = lhs_chunk.len() - lhs_offset;
43            let rhs_rem = rhs_chunk.len() - rhs_offset;
44            let take_until = lhs_rem.min(rhs_rem);
45
46            let mask_slice = mask.slice(pos..pos + take_until);
47            let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until)?;
48            let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until)?;
49
50            out_chunks.push(zip(lhs_slice.as_ref(), rhs_slice.as_ref(), &mask_slice)?);
51
52            pos += take_until;
53            lhs_offset += take_until;
54            rhs_offset += take_until;
55
56            if lhs_offset == lhs_chunk.len() {
57                lhs_idx += 1;
58                lhs_offset = 0;
59            }
60            if rhs_offset == rhs_chunk.len() {
61                rhs_idx += 1;
62                rhs_offset = 0;
63            }
64        }
65
66        // SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
67        let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
68        Ok(Some(chunked.to_array()))
69    }
70}
71
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ChunkedVTable;
    use crate::compute::zip;

    /// Zipping two chunked arrays whose chunk boundaries do not line up must keep the
    /// chunked encoding, split the output at the union of both sets of boundaries, and
    /// pick each value from the side selected by the mask.
    #[test]
    fn test_chunked_zip_aligns_across_boundaries() {
        let i32_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let chunked = |chunks| ChunkedArray::try_new(chunks, i32_dtype.clone()).unwrap();

        // Chunk boundaries after rows 2, 3 and 5.
        let if_true = chunked(vec![
            buffer![1i32, 2].into_array(),
            buffer![3i32].into_array(),
            buffer![4i32, 5].into_array(),
        ]);

        // Chunk boundaries after rows 1, 3 and 5.
        let if_false = chunked(vec![
            buffer![10i32].into_array(),
            buffer![11i32, 12].into_array(),
            buffer![13i32, 14].into_array(),
        ]);

        let mask = Mask::from_iter([true, false, true, false, true]);

        let result = zip(if_true.as_ref(), if_false.as_ref(), &mask).unwrap();
        let result = result
            .as_opt::<ChunkedVTable>()
            .expect("zip should keep chunked encoding");

        // Union of the boundaries {1, 2, 3, 5} yields four output chunks.
        assert_eq!(result.nchunks(), 4);

        let mut flattened: Vec<i32> = Vec::new();
        for chunk in result.chunks() {
            let primitive = chunk.to_primitive();
            flattened.extend(primitive.as_slice::<i32>().iter().copied());
        }
        assert_eq!(flattened, vec![1, 11, 3, 13, 5]);
    }
}