vortex_array/arrays/patched/compute/
take.rs1use rustc_hash::FxHashMap;
5use vortex_buffer::Buffer;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Patched;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::patched::PatchedArrayExt;
16use crate::arrays::patched::PatchedArraySlotsExt;
17use crate::arrays::primitive::PrimitiveDataParts;
18use crate::dtype::IntegerPType;
19use crate::dtype::NativePType;
20use crate::match_each_native_ptype;
21use crate::match_each_unsigned_integer_ptype;
22
23impl TakeExecute for Patched {
24 fn take(
25 array: ArrayView<'_, Self>,
26 indices: &ArrayRef,
27 ctx: &mut ExecutionCtx,
28 ) -> VortexResult<Option<ArrayRef>> {
29 if !array.dtype().is_primitive() {
31 return Ok(None);
32 }
33
34 let inner = array
36 .inner()
37 .take(indices.clone())?
38 .execute::<PrimitiveArray>(ctx)?;
39
40 let PrimitiveDataParts {
41 buffer,
42 validity,
43 ptype,
44 } = inner.into_data_parts();
45
46 let indices_ptype = indices.dtype().as_ptype();
47
48 match_each_unsigned_integer_ptype!(indices_ptype, |I| {
49 match_each_native_ptype!(ptype, |V| {
50 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
51 let lane_offsets = array
52 .lane_offsets()
53 .clone()
54 .execute::<PrimitiveArray>(ctx)?;
55 let patch_indices = array
56 .patch_indices()
57 .clone()
58 .execute::<PrimitiveArray>(ctx)?;
59 let patch_values = array
60 .patch_values()
61 .clone()
62 .execute::<PrimitiveArray>(ctx)?;
63 let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
64 take_map(
65 output.as_mut(),
66 indices.as_slice::<I>(),
67 array.offset(),
68 array.len(),
69 array.n_lanes(),
70 lane_offsets.as_slice::<u32>(),
71 patch_indices.as_slice::<u16>(),
72 patch_values.as_slice::<V>(),
73 );
74
75 unsafe {
77 Ok(Some(
78 PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
79 ))
80 }
81 })
82 })
83 }
84}
85
86#[expect(clippy::too_many_arguments)]
91fn take_map<I: IntegerPType, V: NativePType>(
92 output: &mut [V],
93 indices: &[I],
94 offset: usize,
95 len: usize,
96 n_lanes: usize,
97 lane_offsets: &[u32],
98 patch_index: &[u16],
99 patch_value: &[V],
100) {
101 let n_chunks = (offset + len).div_ceil(1024);
102 let mut index_map = FxHashMap::with_capacity_and_hasher(patch_index.len(), Default::default());
104 for chunk in 0..n_chunks {
105 for lane in 0..n_lanes {
106 let lane_start = lane_offsets[chunk * n_lanes + lane];
107 let lane_end = lane_offsets[chunk * n_lanes + lane + 1];
108 for i in lane_start..lane_end {
109 let patch_idx = patch_index[i as usize];
110 let patch_value = patch_value[i as usize];
111
112 let index = chunk * 1024 + patch_idx as usize;
113 if index >= offset && index < offset + len {
114 index_map.insert(index - offset, patch_value);
115 }
116 }
117 }
118 }
119
120 for (output_index, index) in indices.iter().enumerate() {
123 let index = index.as_();
124 if let Some(&patch_value) = index_map.get(&index) {
125 output[output_index] = patch_value;
126 }
127 }
128}
129
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::patches::Patches;
    use crate::validity::Validity;

    /// Builds a `u16` [`Patched`] array from `base` with the given patches applied,
    /// then slices it to `slice`.
    fn make_patched_array(
        base: &[u16],
        patch_indices: &[u32],
        patch_values: &[u16],
        slice: Range<usize>,
    ) -> VortexResult<ArrayRef> {
        let values = PrimitiveArray::from_iter(base.iter().copied()).into_array();
        let patches = Patches::new(
            base.len(),
            0,
            PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array(),
            PrimitiveArray::from_iter(patch_values.iter().copied()).into_array(),
            None,
        )?;

        let session = VortexSession::empty();
        let mut ctx = session.create_execution_ctx();

        Patched::from_array_and_patches(values, &patches, &mut ctx)?
            .into_array()
            .slice(slice)
    }

    /// Takes `indices` from `array` and canonicalizes the result for comparison.
    #[expect(deprecated)]
    fn take_canonical(array: &ArrayRef, indices: ArrayRef) -> VortexResult<ArrayRef> {
        Ok(array.take(indices)?.to_canonical()?.into_array())
    }

    #[test]
    fn test_take_basic() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        let indices = buffer![0u32, 1, 2, 3, 4].into_array();
        let result = take_canonical(&array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 10, 0, 30, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_sliced() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // The patch at original position 1 lies before the slice and must not be visible.
        let array = make_patched_array(&[0; 10], &[1, 3], &[100, 200], 2..10)?;

        let indices = buffer![0u32, 1, 2, 3, 7].into_array();
        let result = take_canonical(&array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 200, 0, 0, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_out_of_order() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        let indices = buffer![4u32, 3, 2, 1, 0].into_array();
        let result = take_canonical(&array, indices)?;

        let expected = PrimitiveArray::from_iter([0u16, 30, 0, 10, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_duplicates() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = make_patched_array(&[0; 5], &[2], &[99], 0..5)?;

        let indices = buffer![2u32, 2, 0, 2].into_array();
        let result = take_canonical(&array, indices)?;

        let expected = PrimitiveArray::from_iter([99u16, 99, 0, 99]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_u8_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        let indices = buffer![3u8, 1, 0].into_array();
        let result = take_canonical(&array, indices)?;

        let expected = PrimitiveArray::from_iter([30u16, 10, 0]).into_array();
        assert_arrays_eq!(expected, result, &mut ctx);

        Ok(())
    }

    #[test]
    fn test_take_with_null_indices() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let array = make_patched_array(&[0; 10], &[2, 5, 8], &[20, 50, 80], 0..10)?;

        // Null index slots must stay null in the output, whatever their stored value.
        let mask = [
            true, true, false, true, true, false, true, true, false, true,
        ];

        let indices = PrimitiveArray::new(
            buffer![0u32, 2, 2, 5, 8, 0, 5, 8, 3, 1],
            Validity::Array(BoolArray::from_iter(mask).into_array()),
        );
        let result = take_canonical(&array, indices.into_array())?;

        let expected = PrimitiveArray::new(
            buffer![0u16, 20, 0, 50, 80, 0, 50, 80, 0, 0],
            Validity::Array(BoolArray::from_iter(mask).into_array()),
        );
        assert_arrays_eq!(expected.into_array(), result, &mut ctx);

        Ok(())
    }
}