Skip to main content

vortex_array/arrays/chunked/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::arrays::ChunkedArray;
10use crate::arrays::ChunkedVTable;
11use crate::builtins::ArrayBuiltins;
12use crate::scalar_fn::fns::zip::ZipKernel;
13
14// Push down the zip call to the chunks. Without this rule
15// the default implementation canonicalises the chunked array
16// then zips once.
17impl ZipKernel for ChunkedVTable {
18    fn zip(
19        if_true: &ChunkedArray,
20        if_false: &ArrayRef,
21        mask: &ArrayRef,
22        _ctx: &mut ExecutionCtx,
23    ) -> VortexResult<Option<ArrayRef>> {
24        let Some(if_false) = if_false.as_opt::<ChunkedVTable>() else {
25            return Ok(None);
26        };
27        let dtype = if_true
28            .dtype()
29            .union_nullability(if_false.dtype().nullability());
30        let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
31
32        let mut lhs_idx = 0;
33        let mut rhs_idx = 0;
34        let mut lhs_offset = 0;
35        let mut rhs_offset = 0;
36        let mut pos = 0;
37        let total_len = if_true.len();
38
39        while pos < total_len {
40            let lhs_chunk = if_true.chunk(lhs_idx);
41            let rhs_chunk = if_false.chunk(rhs_idx);
42
43            let lhs_rem = lhs_chunk.len() - lhs_offset;
44            let rhs_rem = rhs_chunk.len() - rhs_offset;
45            let take_until = lhs_rem.min(rhs_rem);
46
47            let mask_slice = mask.slice(pos..pos + take_until)?;
48            let lhs_slice = lhs_chunk.slice(lhs_offset..lhs_offset + take_until)?;
49            let rhs_slice = rhs_chunk.slice(rhs_offset..rhs_offset + take_until)?;
50
51            out_chunks.push(mask_slice.zip(lhs_slice, rhs_slice)?);
52
53            pos += take_until;
54            lhs_offset += take_until;
55            rhs_offset += take_until;
56
57            if lhs_offset == lhs_chunk.len() {
58                lhs_idx += 1;
59                lhs_offset = 0;
60            }
61            if rhs_offset == rhs_chunk.len() {
62                rhs_idx += 1;
63                rhs_offset = 0;
64            }
65        }
66
67        // SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
68        let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
69        Ok(Some(chunked.to_array()))
70    }
71}
72
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::ChunkedVTable;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    // Zips two chunked arrays whose chunk boundaries do not line up and checks
    // that the result is still chunked, split at every boundary of either input.
    #[test]
    fn test_chunked_zip_aligns_across_boundaries() {
        let i32_dtype = || DType::Primitive(PType::I32, Nullability::NonNullable);

        // Chunk lengths 2, 1, 2.
        let if_true = ChunkedArray::try_new(
            vec![
                buffer![1i32, 2].into_array(),
                buffer![3i32].into_array(),
                buffer![4i32, 5].into_array(),
            ],
            i32_dtype(),
        )
        .unwrap();

        // Chunk lengths 1, 2, 2.
        let if_false = ChunkedArray::try_new(
            vec![
                buffer![10i32].into_array(),
                buffer![11i32, 12].into_array(),
                buffer![13i32, 14].into_array(),
            ],
            i32_dtype(),
        )
        .unwrap();

        let mask = Mask::from_iter([true, false, true, false, true]).into_array();
        let lazy_zip = mask
            .zip(if_true.to_array(), if_false.to_array())
            .unwrap();

        // One step of execution will push down the zip.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let executed = lazy_zip.execute::<ArrayRef>(&mut ctx).unwrap();
        let chunked = executed
            .as_opt::<ChunkedVTable>()
            .expect("zip should keep chunked encoding");

        // Input boundaries fall at rows 1, 2 and 3, giving four aligned pieces.
        assert_eq!(chunked.nchunks(), 4);

        let mut values = Vec::<i32>::new();
        for chunk in chunked.chunks() {
            values.extend_from_slice(chunk.to_primitive().as_slice::<i32>());
        }
        assert_eq!(values, vec![1, 11, 3, 13, 5]);
    }
}