Skip to main content

vortex_array/arrays/patched/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use rustc_hash::FxHashMap;
5use vortex_buffer::Buffer;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Patched;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::patched::PatchedArrayExt;
16use crate::arrays::patched::PatchedArraySlotsExt;
17use crate::arrays::primitive::PrimitiveDataParts;
18use crate::dtype::IntegerPType;
19use crate::dtype::NativePType;
20use crate::match_each_native_ptype;
21use crate::match_each_unsigned_integer_ptype;
22
23impl TakeExecute for Patched {
24    fn take(
25        array: ArrayView<'_, Self>,
26        indices: &ArrayRef,
27        ctx: &mut ExecutionCtx,
28    ) -> VortexResult<Option<ArrayRef>> {
29        // Only pushdown take when we have primitive types.
30        if !array.dtype().is_primitive() {
31            return Ok(None);
32        }
33
34        // Perform take on the inner array, including the placeholders.
35        let inner = array
36            .inner()
37            .take(indices.clone())?
38            .execute::<PrimitiveArray>(ctx)?;
39
40        let PrimitiveDataParts {
41            buffer,
42            validity,
43            ptype,
44        } = inner.into_data_parts();
45
46        let indices_ptype = indices.dtype().as_ptype();
47
48        match_each_unsigned_integer_ptype!(indices_ptype, |I| {
49            match_each_native_ptype!(ptype, |V| {
50                let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
51                let lane_offsets = array
52                    .lane_offsets()
53                    .clone()
54                    .execute::<PrimitiveArray>(ctx)?;
55                let patch_indices = array
56                    .patch_indices()
57                    .clone()
58                    .execute::<PrimitiveArray>(ctx)?;
59                let patch_values = array
60                    .patch_values()
61                    .clone()
62                    .execute::<PrimitiveArray>(ctx)?;
63                let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
64                take_map(
65                    output.as_mut(),
66                    indices.as_slice::<I>(),
67                    array.offset(),
68                    array.len(),
69                    array.n_lanes(),
70                    lane_offsets.as_slice::<u32>(),
71                    patch_indices.as_slice::<u16>(),
72                    patch_values.as_slice::<V>(),
73                );
74
75                // SAFETY: output and validity still have same length after take_map returns.
76                unsafe {
77                    Ok(Some(
78                        PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
79                    ))
80                }
81            })
82        })
83    }
84}
85
86/// Take patches for the given `indices` and apply them onto an `output` using a hash map.
87///
88/// First, builds a hashmap from index to patch value, then uses the hashmap in a loop to collect
89/// the values.
90#[expect(clippy::too_many_arguments)]
91fn take_map<I: IntegerPType, V: NativePType>(
92    output: &mut [V],
93    indices: &[I],
94    offset: usize,
95    len: usize,
96    n_lanes: usize,
97    lane_offsets: &[u32],
98    patch_index: &[u16],
99    patch_value: &[V],
100) {
101    let n_chunks = (offset + len).div_ceil(1024);
102    // Build a hashmap of patch_index -> values.
103    let mut index_map = FxHashMap::with_capacity_and_hasher(patch_index.len(), Default::default());
104    for chunk in 0..n_chunks {
105        for lane in 0..n_lanes {
106            let lane_start = lane_offsets[chunk * n_lanes + lane];
107            let lane_end = lane_offsets[chunk * n_lanes + lane + 1];
108            for i in lane_start..lane_end {
109                let patch_idx = patch_index[i as usize];
110                let patch_value = patch_value[i as usize];
111
112                let index = chunk * 1024 + patch_idx as usize;
113                if index >= offset && index < offset + len {
114                    index_map.insert(index - offset, patch_value);
115                }
116            }
117        }
118    }
119
120    // Now, iterate the take indices using the prebuilt hashmap.
121    // Undefined/null indices will miss the hash map, which we can ignore.
122    for (output_index, index) in indices.iter().enumerate() {
123        let index = index.as_();
124        if let Some(&patch_value) = index_map.get(&index) {
125            output[output_index] = patch_value;
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::patches::Patches;
    use crate::validity::Validity;

    /// Builds a `Patched` array over `base` with the given patches, then slices it to `slice`.
    ///
    /// `patch_indices` and `patch_values` are parallel: the value at `patch_indices[i]` is
    /// replaced by `patch_values[i]`.
    fn make_patched_array(
        base: &[u16],
        patch_indices: &[u32],
        patch_values: &[u16],
        slice: Range<usize>,
    ) -> VortexResult<ArrayRef> {
        let values = PrimitiveArray::from_iter(base.iter().copied()).into_array();
        let patches = Patches::new(
            base.len(),
            0,
            PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array(),
            PrimitiveArray::from_iter(patch_values.iter().copied()).into_array(),
            None,
        )?;

        let session = VortexSession::empty();
        let mut ctx = session.create_execution_ctx();

        Patched::from_array_and_patches(values, &patches, &mut ctx)?
            .into_array()
            .slice(slice)
    }

    /// Takes `indices` from `array` and canonicalizes the result.
    ///
    /// Keeps the deprecated call chain shared by every test in a single place.
    #[expect(deprecated)]
    fn take_canonical(array: ArrayRef, indices: ArrayRef) -> VortexResult<ArrayRef> {
        Ok(array.take(indices)?.to_canonical()?.into_array())
    }

    #[test]
    fn test_take_basic() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        // Take indices [0, 1, 2, 3, 4] - should get [0, 10, 0, 30, 0]
        let indices = buffer![0u32, 1, 2, 3, 4].into_array();
        let result = take_canonical(array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 10, 0, 30, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_sliced() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // 10 zeros patched at indices [1, 3] with values [100, 200], then sliced to 2..10.
        // The patch at 1 falls before the slice; the patch at 3 lands at sliced position 1.
        let array = make_patched_array(&[0; 10], &[1, 3], &[100, 200], 2..10)?;

        let indices = buffer![0u32, 1, 2, 3, 7].into_array();
        let result = take_canonical(array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 200, 0, 0, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_out_of_order() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        // Take indices in reverse order
        let indices = buffer![4u32, 3, 2, 1, 0].into_array();
        let result = take_canonical(array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 30, 0, 10, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_duplicates() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // Array with base values [0, 0, 0, 0, 0] patched at index [2] with value [99]
        let array = make_patched_array(&[0; 5], &[2], &[99], 0..5)?;

        // Take the same patched index multiple times
        let indices = buffer![2u32, 2, 0, 2].into_array();
        let result = take_canonical(array, indices)?;

        // Canonicalizing the already-canonicalized result again must still succeed.
        #[expect(deprecated)]
        let _canonical = result.to_canonical()?.into_primitive();

        let expected = PrimitiveArray::from_iter([99u16, 99, 0, 99]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_with_null_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();

        // Array: 10 elements, base value 0, patches at indices 2, 5, 8 with values 20, 50, 80
        let array = make_patched_array(&[0; 10], &[2, 5, 8], &[20, 50, 80], 0..10)?;

        // Take 10 indices, with nulls at positions 2, 5, 8
        // Indices: [0, 2, 2, 5, 8, 0, 5, 8, 3, 1]
        // Nulls:   [ ,  , N,  ,  , N,  ,  , N,  ]
        // Position 2 (index=2, patched) is null
        // Position 5 (index=0, unpatched) is null
        // Position 8 (index=3, unpatched) is null
        let indices = PrimitiveArray::new(
            buffer![0u32, 2, 2, 5, 8, 0, 5, 8, 3, 1],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        let result = take_canonical(array, indices.into_array())?;

        // Expected: [0, 20, null, 50, 80, null, 50, 80, null, 0]
        let expected = PrimitiveArray::new(
            buffer![0u16, 20, 0, 50, 80, 0, 50, 80, 0, 0],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        assert_arrays_eq!(expected.into_array(), result, &mut ctx);

        Ok(())
    }
}