Skip to main content

vortex_array/arrays/chunked/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::Chunked;
11use crate::arrays::ChunkedArray;
12use crate::arrays::chunked::ChunkedArrayExt;
13use crate::arrays::chunked::paired_chunks::PairedChunksExt;
14use crate::builtins::ArrayBuiltins;
15use crate::scalar_fn::fns::zip::ZipKernel;
16
17// Push down the zip call to the chunks. Without this rule
18// the default implementation canonicalises the chunked array
19// then zips once.
20impl ZipKernel for Chunked {
21    fn zip(
22        if_true: ArrayView<'_, Chunked>,
23        if_false: &ArrayRef,
24        mask: &ArrayRef,
25        _ctx: &mut ExecutionCtx,
26    ) -> VortexResult<Option<ArrayRef>> {
27        let Some(if_false) = if_false.as_opt::<Chunked>() else {
28            return Ok(None);
29        };
30        let dtype = if_true
31            .dtype()
32            .union_nullability(if_false.dtype().nullability());
33        let mut out_chunks = Vec::with_capacity(if_true.nchunks() + if_false.nchunks());
34
35        for pair in if_true.paired_chunks(&if_false) {
36            let pair = pair?;
37            let mask_slice = mask.slice(pair.pos)?;
38            out_chunks.push(mask_slice.zip(pair.left, pair.right)?);
39        }
40
41        // SAFETY: chunks originate from zipping slices of inputs that share dtype/nullability.
42        let chunked = unsafe { ChunkedArray::new_unchecked(out_chunks, dtype) };
43        Ok(Some(chunked.into_array()))
44    }
45}
46
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_mask::Mask;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::Chunked;
    use crate::arrays::ChunkedArray;
    use crate::arrays::chunked::ChunkedArrayExt;
    use crate::builtins::ArrayBuiltins;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    // Builds a non-nullable `i32` chunked array from the given chunks.
    fn non_nullable_i32_chunked(chunks: Vec<ArrayRef>) -> ChunkedArray {
        ChunkedArray::try_new(
            chunks,
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )
        .unwrap()
    }

    // Both inputs hold five rows but are split at different offsets (2/1/2 vs 1/2/2),
    // so the zip has to line the chunks up across mismatched boundaries.
    #[test]
    fn test_chunked_zip_aligns_across_boundaries() {
        let lhs = non_nullable_i32_chunked(vec![
            buffer![1i32, 2].into_array(),
            buffer![3i32].into_array(),
            buffer![4i32, 5].into_array(),
        ]);
        let rhs = non_nullable_i32_chunked(vec![
            buffer![10i32].into_array(),
            buffer![11i32, 12].into_array(),
            buffer![13i32, 14].into_array(),
        ]);

        let selector = Mask::from_iter([true, false, true, false, true]);

        let lazy_zip = selector
            .into_array()
            .zip(lhs.into_array(), rhs.into_array())
            .unwrap();

        // One step of execution will push down the zip.
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let executed = lazy_zip.clone().execute::<ArrayRef>(&mut ctx).unwrap();
        let chunked = executed
            .as_opt::<Chunked>()
            .expect("zip should keep chunked encoding");

        assert_eq!(chunked.nchunks(), 4);

        // Flatten the output chunks back into a single vector of values.
        let mut flattened: Vec<i32> = Vec::new();
        for chunk in chunked.chunks() {
            let primitive = chunk.to_primitive();
            flattened.extend_from_slice(primitive.as_slice::<i32>());
        }
        assert_eq!(flattened, vec![1, 11, 3, 13, 5]);
    }
}