vortex_array/arrays/patched/compute/
take.rs1use rustc_hash::FxHashMap;
5use vortex_buffer::Buffer;
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::Patched;
13use crate::arrays::PrimitiveArray;
14use crate::arrays::dict::TakeExecute;
15use crate::arrays::patched::PatchedArrayExt;
16use crate::arrays::patched::PatchedArraySlotsExt;
17use crate::arrays::primitive::PrimitiveDataParts;
18use crate::dtype::IntegerPType;
19use crate::dtype::NativePType;
20use crate::match_each_native_ptype;
21use crate::match_each_unsigned_integer_ptype;
22
23impl TakeExecute for Patched {
24 fn take(
25 array: ArrayView<'_, Self>,
26 indices: &ArrayRef,
27 ctx: &mut ExecutionCtx,
28 ) -> VortexResult<Option<ArrayRef>> {
29 if !array.dtype().is_primitive() {
31 return Ok(None);
32 }
33
34 let inner = array
36 .inner()
37 .take(indices.clone())?
38 .execute::<PrimitiveArray>(ctx)?;
39
40 let PrimitiveDataParts {
41 buffer,
42 validity,
43 ptype,
44 } = inner.into_data_parts();
45
46 let indices_ptype = indices.dtype().as_ptype();
47
48 match_each_unsigned_integer_ptype!(indices_ptype, |I| {
49 match_each_native_ptype!(ptype, |V| {
50 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
51 let lane_offsets = array
52 .lane_offsets()
53 .clone()
54 .execute::<PrimitiveArray>(ctx)?;
55 let patch_indices = array
56 .patch_indices()
57 .clone()
58 .execute::<PrimitiveArray>(ctx)?;
59 let patch_values = array
60 .patch_values()
61 .clone()
62 .execute::<PrimitiveArray>(ctx)?;
63 let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();
64 take_map(
65 output.as_mut(),
66 indices.as_slice::<I>(),
67 array.offset(),
68 array.len(),
69 array.n_lanes(),
70 lane_offsets.as_slice::<u32>(),
71 patch_indices.as_slice::<u16>(),
72 patch_values.as_slice::<V>(),
73 );
74
75 unsafe {
77 Ok(Some(
78 PrimitiveArray::new_unchecked(output.freeze(), validity).into_array(),
79 ))
80 }
81 })
82 })
83 }
84}
85
86#[allow(clippy::too_many_arguments)]
91fn take_map<I: IntegerPType, V: NativePType>(
92 output: &mut [V],
93 indices: &[I],
94 offset: usize,
95 len: usize,
96 n_lanes: usize,
97 lane_offsets: &[u32],
98 patch_index: &[u16],
99 patch_value: &[V],
100) {
101 let n_chunks = (offset + len).div_ceil(1024);
102 let mut index_map = FxHashMap::with_capacity_and_hasher(patch_index.len(), Default::default());
104 for chunk in 0..n_chunks {
105 for lane in 0..n_lanes {
106 let lane_start = lane_offsets[chunk * n_lanes + lane];
107 let lane_end = lane_offsets[chunk * n_lanes + lane + 1];
108 for i in lane_start..lane_end {
109 let patch_idx = patch_index[i as usize];
110 let patch_value = patch_value[i as usize];
111
112 let index = chunk * 1024 + patch_idx as usize;
113 if index >= offset && index < offset + len {
114 index_map.insert(index - offset, patch_value);
115 }
116 }
117 }
118 }
119
120 for (output_index, index) in indices.iter().enumerate() {
123 let index = index.as_();
124 if let Some(&patch_value) = index_map.get(&index) {
125 output[output_index] = patch_value;
126 }
127 }
128}
129
#[cfg(test)]
mod tests {
    use std::ops::Range;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::ExecutionCtx;
    use crate::IntoArray;
    use crate::arrays::Patched;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::patches::Patches;

    /// Build a patched `u16` array from `base` and the given patches, then slice it to `slice`.
    ///
    /// `patch_indices` and `patch_values` are paired positionally and refer to positions in
    /// the unsliced `base`.
    fn make_patched_array(
        base: &[u16],
        patch_indices: &[u32],
        patch_values: &[u16],
        slice: Range<usize>,
    ) -> VortexResult<ArrayRef> {
        let values = PrimitiveArray::from_iter(base.iter().copied()).into_array();
        let patches = Patches::new(
            base.len(),
            0,
            PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array(),
            PrimitiveArray::from_iter(patch_values.iter().copied()).into_array(),
            None,
        )?;

        let session = VortexSession::empty();
        let mut ctx = ExecutionCtx::new(session);

        Patched::from_array_and_patches(values, &patches, &mut ctx)?
            .into_array()
            .slice(slice)
    }

    /// Taking every position in order yields the base values with the patches applied.
    #[test]
    fn test_take_basic() -> VortexResult<()> {
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        let indices = buffer![0u32, 1, 2, 3, 4].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 10, 0, 30, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    /// After slicing to `2..10` the patch at position 1 is outside the slice and the patch at
    /// position 3 is reachable through index 1.
    #[test]
    fn test_take_sliced() -> VortexResult<()> {
        let array = make_patched_array(&[0; 10], &[1, 3], &[100, 200], 2..10)?;

        let indices = buffer![0u32, 1, 2, 3, 7].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 200, 0, 0, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    /// Indices given in descending order produce the patched values in that same order.
    #[test]
    fn test_take_out_of_order() -> VortexResult<()> {
        let array = make_patched_array(&[0; 5], &[1, 3], &[10, 30], 0..5)?;

        let indices = buffer![4u32, 3, 2, 1, 0].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([0u16, 30, 0, 10, 0]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    /// A patched position requested several times is patched at every occurrence.
    #[test]
    fn test_take_duplicates() -> VortexResult<()> {
        let array = make_patched_array(&[0; 5], &[2], &[99], 0..5)?;

        let indices = buffer![2u32, 2, 0, 2].into_array();
        let result = array.take(indices)?.to_canonical()?.into_array();

        let expected = PrimitiveArray::from_iter([99u16, 99, 0, 99]).into_array();
        assert_arrays_eq!(expected, result);

        Ok(())
    }

    /// Indices carrying a validity mask produce a result with the same validity mask.
    #[test]
    fn test_take_with_null_indices() -> VortexResult<()> {
        use crate::arrays::BoolArray;
        use crate::validity::Validity;

        let array = make_patched_array(&[0; 10], &[2, 5, 8], &[20, 50, 80], 0..10)?;

        // Positions 2, 5 and 8 of the index array are marked invalid.
        let indices = PrimitiveArray::new(
            buffer![0u32, 2, 2, 5, 8, 0, 5, 8, 3, 1],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        let result = array
            .take(indices.into_array())?
            .to_canonical()?
            .into_array();

        let expected = PrimitiveArray::new(
            buffer![0u16, 20, 0, 50, 80, 0, 50, 80, 0, 0],
            Validity::Array(
                BoolArray::from_iter([
                    true, true, false, true, true, false, true, true, false, true,
                ])
                .into_array(),
            ),
        );
        assert_arrays_eq!(expected.into_array(), result);

        Ok(())
    }
}