Skip to main content

vortex_array/arrays/patched/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use rustc_hash::FxHashMap;
5use vortex_buffer::Buffer;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Patched;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::patched::PatchedArrayExt;
16use crate::arrays::patched::PatchedArraySlotsExt;
17use crate::arrays::primitive::PrimitiveDataParts;
18use crate::dtype::IntegerPType;
19use crate::dtype::NativePType;
20use crate::match_each_native_ptype;
21use crate::match_each_unsigned_integer_ptype;
22
23impl TakeExecute for Patched {
24    fn take(
25        array: ArrayView<'_, Self>,
26        indices: &ArrayRef,
27        ctx: &mut ExecutionCtx,
28    ) -> VortexResult<Option<ArrayRef>> {
29        // Only pushdown take when we have primitive types.
30        if !array.dtype().is_primitive() {
31            return Ok(None);
32        }
33
34        // Perform take on the inner array, including the placeholders.
35        let inner = array
36            .inner()
37            .take(indices.clone())?
38            .execute::<PrimitiveArray>(ctx)?;
39
40        let PrimitiveDataParts {
41            buffer,
42            validity,
43            ptype,
44        } = inner.into_data_parts();
45
46        let indices_ptype = indices.dtype().as_ptype();
47
48        match_each_unsigned_integer_ptype!(indices_ptype, |I| {
49            match_each_native_ptype!(ptype, |V| {
50                let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
51                let lane_offsets = array
52                    .lane_offsets()
53                    .clone()
54                    .execute::<PrimitiveArray>(ctx)?;
55                let patch_indices = array
56                    .patch_indices()
57                    .clone()
58                    .execute::<PrimitiveArray>(ctx)?;
59                let patch_values = array
60                    .patch_values()
61                    .clone()
62                    .execute::<PrimitiveArray>(ctx)?;
63                let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
64                take_map(
65                    output.as_mut(),
66                    indices.as_slice::<I>(),
67                    array.offset(),
68                    array.len(),
69                    array.n_lanes(),
70                    lane_offsets.as_slice::<u32>(),
71                    patch_indices.as_slice::<u16>(),
72                    patch_values.as_slice::<V>(),
73                );
74
75                // SAFETY: output and validity still have same length after take_map returns.
76                unsafe {
77                    Ok(Some(
78                        PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
79                    ))
80                }
81            })
82        })
83    }
84}
85
86/// Take patches for the given `indices` and apply them onto an `output` using a hash map.
87///
88/// First, builds a hashmap from index to patch value, then uses the hashmap in a loop to collect
89/// the values.
90#[allow(clippy::too_many_arguments)]
91fn take_map<I: IntegerPType, V: NativePType>(
92    output: &mut [V],
93    indices: &[I],
94    offset: usize,
95    len: usize,
96    n_lanes: usize,
97    lane_offsets: &[u32],
98    patch_index: &[u16],
99    patch_value: &[V],
100) {
101    let n_chunks = (offset + len).div_ceil(1024);
102    // Build a hashmap of patch_index -> values.
103    let mut index_map = FxHashMap::with_capacity_and_hasher(patch_index.len(), Default::default());
104    for chunk in 0..n_chunks {
105        for lane in 0..n_lanes {
106            let lane_start = lane_offsets[chunk * n_lanes + lane];
107            let lane_end = lane_offsets[chunk * n_lanes + lane + 1];
108            for i in lane_start..lane_end {
109                let patch_idx = patch_index[i as usize];
110                let patch_value = patch_value[i as usize];
111
112                let index = chunk * 1024 + patch_idx as usize;
113                if index >= offset && index < offset + len {
114                    index_map.insert(index - offset, patch_value);
115                }
116            }
117        }
118    }
119
120    // Now, iterate the take indices using the prebuilt hashmap.
121    // Undefined/null indices will miss the hash map, which we can ignore.
122    for (output_index, index) in indices.iter().enumerate() {
123        let index = index.as_();
124        if let Some(&patch_value) = index_map.get(&index) {
125            output[output_index] = patch_value;
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::patches::Patches;
    use crate::validity::Validity;

    /// Builds a `u16` [`Patched`] array from `base` with `patch_values` written at
    /// `patch_indices`, then slices the result to `slice`.
    fn make_patched_array(
        base: &[u16],
        patch_indices: &[u32],
        patch_values: &[u16],
        slice: Range<usize>,
    ) -> VortexResult<ArrayRef> {
        let values = PrimitiveArray::from_iter(base.iter().copied()).into_array();
        let patches = Patches::new(
            base.len(),
            0,
            PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array(),
            PrimitiveArray::from_iter(patch_values.iter().copied()).into_array(),
            None,
        )?;

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        Patched::from_array_and_patches(values, &patches, &mut ctx)?
            .into_array()
            .slice(slice)
    }

    #[test]
    fn test_take_basic() -> VortexResult<()> {
        // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        // Take indices [0, 1, 2, 3, 4] - should get [0, 10, 0, 30, 0]
        let indices = buffer![0u32, 1, 2, 3, 4].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 10, 0, 30, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_sliced() -> VortexResult<()> {
        // 10 zeros patched at [1, 3] with [100, 200], sliced to 2..10: the patch at absolute
        // index 1 falls outside the slice, the one at 3 lands at relative index 1.
        let array = make_patched_array(&[0; 10], &[1, 3], &[100, 200], 2..10)?;

        let indices = buffer![0u32, 1, 2, 3, 7].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 200, 0, 0, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_out_of_order() -> VortexResult<()> {
        // Array with base values [0, 0, 0, 0, 0] patched at indices [1, 3] with values [10, 30]
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        // Take indices in reverse order
        let indices = buffer![4u32, 3, 2, 1, 0].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 30, 0, 10, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_duplicates() -> VortexResult<()> {
        // Array with base values [0, 0, 0, 0, 0] patched at index [2] with value [99]
        let array = make_patched_array(&[0; 5], &[2], &[99], 0..5)?;

        // Take the same patched index multiple times
        let indices = buffer![2u32, 2, 0, 2].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([99u16, 99, 0, 99]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_across_chunks() -> VortexResult<()> {
        // Two 1024-element chunks: one patch in the first chunk, one in the second.
        let array = make_patched_array(&[0; 2048], &[5, 1030], &[7, 9], 0..2048)?;

        let indices = buffer![1030u32, 5, 1029, 0].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([9u16, 7, 0, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_sliced_across_chunks() -> VortexResult<()> {
        // Slice starting near the end of the first chunk: the patch at absolute index 5 is
        // excluded, the one at absolute 1030 lands at relative index 30.
        let array = make_patched_array(&[0; 2048], &[5, 1030], &[7, 9], 1000..2048)?;

        let indices = buffer![30u32, 0, 1047].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([9u16, 0, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    #[test]
    fn test_take_with_null_indices() -> VortexResult<()> {
        // Array: 10 elements, base value 0, patches at indices 2, 5, 8 with values 20, 50, 80
        let array = make_patched_array(&[0; 10], &[2, 5, 8], &[20, 50, 80], 0..10)?;

        // Take 10 indices, with nulls at positions 2, 5, 8
        // Indices: [0, 2, 2, 5, 8, 0, 5, 8, 3, 1]
        // Nulls:   [ ,  , N,  ,  , N,  ,  , N,  ]
        // Position 2 (index=2, patched) is null
        // Position 5 (index=0, unpatched) is null
        // Position 8 (index=3, unpatched) is null
        let indices = PrimitiveArray::new(
            buffer![0u32, 2, 2, 5, 8, 0, 5, 8, 3, 1],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        let result = array
            .take(indices.into_array())?
            .to_canonical()?
            .into_array();

        // Expected: [0, 20, null, 50, 80, null, 50, 80, null, 0]
        let expected = PrimitiveArray::new(
            buffer![0u16, 20, 0, 50, 80, 0, 50, 80, 0, 0],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        assert_arrays_eq!(expected.into_array(), result);

        Ok(())
    }
}