Skip to main content

vortex_array/compute/
zip.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_dtype::DType;
5use vortex_error::VortexResult;
6use vortex_mask::AllOr;
7use vortex_mask::Mask;
8
9use crate::Array;
10use crate::ArrayRef;
11use crate::IntoArray;
12use crate::builders::ArrayBuilder;
13use crate::builders::builder_with_capacity;
14use crate::builtins::ArrayBuiltins;
15
16/// Performs element-wise conditional selection between two arrays based on a mask.
17///
18/// Returns a new array where `result[i] = if_true[i]` when `mask[i]` is true,
19/// otherwise `result[i] = if_false[i]`.
20///
21/// Null values in the mask are treated as false (selecting `if_false`). This follows
22/// SQL semantics (DuckDB, Trino) where a null condition falls through to the ELSE branch,
23/// rather than Arrow's `if_else` which propagates null conditions to the output.
24pub fn zip(if_true: &dyn Array, if_false: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
25    if_true
26        .to_array()
27        .zip(if_false.to_array(), mask.clone().into_array())
28}
29
30pub(crate) fn zip_return_dtype(if_true: &dyn Array, if_false: &dyn Array) -> DType {
31    if_true
32        .dtype()
33        .union_nullability(if_false.dtype().nullability())
34}
35
36pub(crate) fn zip_impl(
37    if_true: &dyn Array,
38    if_false: &dyn Array,
39    mask: &Mask,
40) -> VortexResult<ArrayRef> {
41    assert_eq!(
42        if_true.len(),
43        if_false.len(),
44        "zip requires arrays to have the same size"
45    );
46
47    let return_type = zip_return_dtype(if_true, if_false);
48    zip_impl_with_builder(
49        if_true,
50        if_false,
51        mask,
52        builder_with_capacity(&return_type, if_true.len()),
53    )
54}
55
56fn zip_impl_with_builder(
57    if_true: &dyn Array,
58    if_false: &dyn Array,
59    mask: &Mask,
60    mut builder: Box<dyn ArrayBuilder>,
61) -> VortexResult<ArrayRef> {
62    match mask.slices() {
63        AllOr::All => Ok(if_true.to_array()),
64        AllOr::None => Ok(if_false.to_array()),
65        AllOr::Some(slices) => {
66            for (start, end) in slices {
67                builder.extend_from_array(&if_false.slice(builder.len()..*start)?);
68                builder.extend_from_array(&if_true.slice(*start..*end)?);
69            }
70            if builder.len() < if_false.len() {
71                builder.extend_from_array(&if_false.slice(builder.len()..if_false.len())?);
72            }
73            Ok(builder.finish())
74        }
75    }
76}
77
#[cfg(test)]
mod tests {
    use arrow_array::cast::AsArray;
    use arrow_select::zip::zip as arrow_zip;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_mask::Mask;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::VarBinViewVTable;
    use crate::arrow::IntoArrowArray;
    use crate::assert_arrays_eq;
    use crate::builders::ArrayBuilder;
    use crate::builders::BufferGrowthStrategy;
    use crate::builders::VarBinViewBuilder;
    use crate::compute::zip;
    use crate::scalar::Scalar;

    /// A mixed mask picks from `if_true` at set positions and from `if_false` elsewhere.
    #[test]
    fn test_zip_basic() {
        let if_true = buffer![10, 20, 30, 40, 50].into_array();
        let if_false = buffer![1, 2, 3, 4, 5].into_array();
        let mask = Mask::from_iter([true, false, false, true, false]);

        let expected = buffer![10, 2, 3, 40, 5].into_array();
        let result = zip(&if_true, &if_false, &mask).unwrap();

        assert_arrays_eq!(result, expected);
    }

    /// An all-true mask yields the `if_true` values with the nullable dtype of `if_false`.
    #[test]
    fn test_zip_all_true() {
        let if_true = buffer![10, 20, 30, 40].into_array();
        let if_false =
            PrimitiveArray::from_option_iter([Some(1), Some(2), Some(3), None]).into_array();
        let mask = Mask::new_true(4);

        let expected =
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(30), Some(40)]).into_array();
        let result = zip(&if_true, &if_false, &mask).unwrap();

        assert_arrays_eq!(result, expected);

        // The result must be nullable even if `if_true` was not.
        assert_eq!(result.dtype(), if_false.dtype());
    }

    /// Inputs of different lengths are rejected.
    #[test]
    #[should_panic]
    fn test_invalid_lengths() {
        let if_true = buffer![10, 20, 30].into_array();
        let if_false = buffer![1, 2, 3, 4].into_array();
        let mask = Mask::new_false(4);

        zip(&if_true, &if_false, &mask).unwrap();
    }

    /// A rapidly alternating mask must not multiply the data buffers of the result.
    #[test]
    fn test_fragmentation() {
        let len = 100;

        let [const1, const2] = [
            "hello_this_is_a_longer_string",
            "world_this_is_another_string",
        ]
        .map(|text| ConstantArray::new(Scalar::utf8(text, Nullability::Nullable), len).to_array());

        // Create a mask that alternates frequently to cause fragmentation
        // Pattern: take from const1 at even indices, const2 at odd indices
        let mask = Mask::from_indices(len, (0..len).step_by(2).collect());

        let result = zip(&const1, &const2, &mask).unwrap();

        insta::assert_snapshot!(result.display_tree(), @r"
        root: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
          metadata: EmptyMetadata
          buffer: buffer_0 host 29 B (align=1) (1.75%)
          buffer: buffer_1 host 28 B (align=1) (1.69%)
          buffer: views host 1.60 kB (align=16) (96.56%)
        ");

        // test wrapped in a struct
        let [wrapped1, wrapped2] = [const1, const2].map(|child| {
            StructArray::try_from_iter([("nested", child)])
                .unwrap()
                .to_array()
        });

        let wrapped_result = zip(&wrapped1, &wrapped2, &mask).unwrap();
        insta::assert_snapshot!(wrapped_result.display_tree(), @r"
        root: vortex.struct({nested=utf8?}, len=100) nbytes=1.66 kB (100.00%)
          metadata: EmptyMetadata
          nested: vortex.varbinview(utf8?, len=100) nbytes=1.66 kB (100.00%) [all_valid]
            metadata: EmptyMetadata
            buffer: buffer_0 host 29 B (align=1) (1.75%)
            buffer: buffer_1 host 28 B (align=1) (1.69%)
            buffer: views host 1.60 kB (align=16) (96.56%)
        ");
    }

    /// Zipping two view arrays keeps two data buffers and agrees with Arrow's `zip`.
    #[test]
    fn test_varbinview_zip() {
        // Builds a 200-element utf8 array alternating a short and a long string.
        let build_views = |short: &str, long: &str| {
            let mut builder = VarBinViewBuilder::new(
                DType::Utf8(Nullability::NonNullable),
                10,
                Default::default(),
                BufferGrowthStrategy::fixed(64 * 1024),
                0.0,
            );
            for _ in 0..100 {
                builder.append_value(short);
                builder.append_value(long);
            }
            builder.finish()
        };

        let if_true = build_views(
            "Hello",
            "Hello this is a long string that won't be inlined.",
        );
        let if_false = build_views(
            "Hello2",
            "Hello2 this is a long string that won't be inlined.",
        );

        // [1,2,4,5,7,8,..]
        let mask = Mask::from_indices(200, (0..100).filter(|i| i % 3 != 0).collect());

        let result = zip(&if_true, &if_false, &mask).unwrap();
        let zipped = result.as_opt::<VarBinViewVTable>().unwrap();
        assert_eq!(zipped.nbuffers(), 2);

        // assert the result is the same as arrow
        let arrow_mask = mask.into_array().into_arrow_preferred().unwrap();
        let arrow_true = if_true.into_arrow_preferred().unwrap();
        let arrow_false = if_false.into_arrow_preferred().unwrap();
        let expected = arrow_zip(arrow_mask.as_boolean(), &arrow_true, &arrow_false).unwrap();

        let actual = zipped.clone().into_array().into_arrow_preferred().unwrap();
        assert_eq!(actual.as_ref(), expected.as_ref());
    }
}