// scirs2_numpy/simd_copy.rs
1//! SIMD-accelerated copy for non-contiguous-to-contiguous coercion.
2//!
3//! When Python passes a non-contiguous (strided) NumPy array, it must be gathered
4//! into a contiguous buffer before it can be used as an `ndarray::ArrayView`.
5//! These routines provide a fast path for that gather operation.
6//!
7//! ## Dispatch strategy
8//!
9//! | Platform | Condition | Implementation |
10//! |---------------|----------------------------------|---------------------------|
11//! | x86_64 | `avx2` detected at runtime | AVX2 256-bit gather |
12//! | x86_64 | no avx2 or fallback required | scalar loop |
13//! | all others | always | scalar loop |
14//!
15//! When `stride == 1` the memory is already contiguous; [`ptr::copy_nonoverlapping`]
16//! is used for the fastest possible copy.
17//!
18//! ## Safety contract
19//!
20//! Both public functions are `unsafe` because they operate on raw pointers. The
21//! caller must guarantee:
22//!
23//! - `src` points to a valid, aligned allocation of at least
24//! `n_elements * stride * size_of::<T>()` bytes.
25//! - `dst.len() >= n_elements`.
26//! - The source and destination ranges do not overlap.
27//! - `stride * (n_elements.saturating_sub(1))` fits in `isize` (i.e., no pointer
28//! overflow on the source side).
29//! - All `n_elements` source elements are properly initialised.
30
31use std::ptr;
32
33// ── f32 ──────────────────────────────────────────────────────────────────────
34
35/// Copy `n_elements` strided `f32` values from `src` into the contiguous slice
36/// `dst`.
37///
38/// `stride` is the gap, **in elements** (not bytes), between successive source
39/// elements. A stride of 1 means the data is already contiguous.
40///
41/// # Safety
42///
43/// See the [module-level safety contract](self).
44///
45/// # Examples
46///
47/// ```
48/// use scirs2_numpy::simd_copy::copy_strided_to_contiguous_f32;
49///
50/// // Source: every second element → [1.0, 3.0, 5.0]
51/// let src = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
52/// let mut dst = vec![0.0_f32; 3];
53/// unsafe {
54/// copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 3, 2);
55/// }
56/// assert_eq!(dst, [1.0, 3.0, 5.0]);
57/// ```
58pub unsafe fn copy_strided_to_contiguous_f32(
59 src: *const f32,
60 dst: &mut [f32],
61 n_elements: usize,
62 stride: usize,
63) {
64 debug_assert!(
65 dst.len() >= n_elements,
66 "dst must have at least n_elements slots"
67 );
68
69 if stride == 1 {
70 // Already contiguous — single memcpy.
71 // SAFETY: caller guarantees non-overlap and src validity.
72 unsafe {
73 ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
74 }
75 return;
76 }
77
78 #[cfg(target_arch = "x86_64")]
79 {
80 // Guard: the index vector holds 8 lanes with offsets [0..7*stride].
81 // `7 * stride` must fit in i32 for _mm256_i32gather_ps; divide to
82 // avoid multiplication overflow before the comparison.
83 const AVX2_LANES: usize = 8;
84 if is_x86_feature_detected!("avx2") && stride <= (i32::MAX as usize) / (AVX2_LANES - 1) {
85 // SAFETY: we just checked the feature flag. The stride bound ensures
86 // that (AVX2_LANES - 1) * stride <= i32::MAX, so all lane offsets
87 // in the gather index vector fit in i32 without overflow.
88 unsafe {
89 gather_f32_avx2(src, dst, n_elements, stride);
90 }
91 return;
92 }
93 }
94
95 // Scalar fallback.
96 unsafe {
97 scalar_gather_f32(src, dst, n_elements, stride);
98 }
99}
100
101// ── f64 ──────────────────────────────────────────────────────────────────────
102
103/// Copy `n_elements` strided `f64` values from `src` into the contiguous slice
104/// `dst`.
105///
106/// `stride` is the gap, **in elements** (not bytes), between successive source
107/// elements.
108///
109/// # Safety
110///
111/// See the [module-level safety contract](self).
112///
113/// # Examples
114///
115/// ```
116/// use scirs2_numpy::simd_copy::copy_strided_to_contiguous_f64;
117///
118/// // Source: every third element → [0.0, 3.0, 6.0]
119/// let src = vec![0.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
120/// let mut dst = vec![0.0_f64; 3];
121/// unsafe {
122/// copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 3, 3);
123/// }
124/// assert_eq!(dst, [0.0, 3.0, 6.0]);
125/// ```
126pub unsafe fn copy_strided_to_contiguous_f64(
127 src: *const f64,
128 dst: &mut [f64],
129 n_elements: usize,
130 stride: usize,
131) {
132 debug_assert!(
133 dst.len() >= n_elements,
134 "dst must have at least n_elements slots"
135 );
136
137 if stride == 1 {
138 // Already contiguous — single memcpy.
139 unsafe {
140 ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
141 }
142 return;
143 }
144
145 // AVX2 gather for f64 (256-bit = 4 × f64) is available but the index vector
146 // for _mm256_i64gather_pd requires 64-bit offsets. Since stride is already
147 // usize the conversion is safe when stride fits in i64, which covers all
148 // practical cases.
149 #[cfg(target_arch = "x86_64")]
150 {
151 if is_x86_feature_detected!("avx2") {
152 unsafe {
153 gather_f64_avx2(src, dst, n_elements, stride);
154 }
155 return;
156 }
157 }
158
159 unsafe {
160 scalar_gather_f64(src, dst, n_elements, stride);
161 }
162}
163
164// ── scalar helpers ────────────────────────────────────────────────────────────
165
/// Scalar gather for f32: copies `n_elements` values spaced `stride` elements
/// apart in `src` into the front of `dst`.
///
/// # Safety
/// Caller must uphold the module-level contract (`src` valid for
/// `n_elements * stride` elements, `dst.len() >= n_elements`, no overlap).
#[inline]
unsafe fn scalar_gather_f32(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        // SAFETY: caller guarantees `i < dst.len()` and that `src + i*stride`
        // is in bounds. The explicit `unsafe` block matches the style used
        // elsewhere in this module (and edition-2024's unsafe_op_in_unsafe_fn).
        unsafe {
            *dst.get_unchecked_mut(i) = *src.add(i * stride);
        }
    }
}
177
/// Scalar gather for f64: copies `n_elements` values spaced `stride` elements
/// apart in `src` into the front of `dst`.
///
/// # Safety
/// Caller must uphold the module-level contract (`src` valid for
/// `n_elements * stride` elements, `dst.len() >= n_elements`, no overlap).
#[inline]
unsafe fn scalar_gather_f64(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        // SAFETY: caller guarantees `i < dst.len()` and that `src + i*stride`
        // is in bounds. The explicit `unsafe` block matches the style used
        // elsewhere in this module (and edition-2024's unsafe_op_in_unsafe_fn).
        unsafe {
            *dst.get_unchecked_mut(i) = *src.add(i * stride);
        }
    }
}
188
189// ── AVX2 paths ────────────────────────────────────────────────────────────────
190
/// AVX2 gather for f32 using `_mm256_i32gather_ps`.
///
/// Processes 8 elements per iteration (256 bits / 32 bits = 8 lanes); a
/// scalar loop handles the remainder.
///
/// # Safety
/// - AVX2 must be available (checked by caller via `is_x86_feature_detected!`).
/// - `7 * stride <= i32::MAX` must hold (checked by caller) so every lane
///   offset fits in the 32-bit gather index vector.
/// - All module-level pointer-validity constraints apply.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f32_avx2(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    // The whole function is already gated on x86_64 above, so the `use`
    // needs no cfg of its own.
    use std::arch::x86_64::*;

    let stride_i32 = stride as i32;

    // Constant index vector [0, stride, 2*stride, …, 7*stride]. These are
    // element-wise offsets; the intrinsic scales them by 4 bytes per f32
    // via scale=4. (_mm256_set_epi32 takes lanes in high-to-low order.)
    let vindex = _mm256_set_epi32(
        7 * stride_i32,
        6 * stride_i32,
        5 * stride_i32,
        4 * stride_i32,
        3 * stride_i32,
        2 * stride_i32,
        stride_i32,
        0,
    );

    let chunks = n_elements / 8;
    let remainder = n_elements % 8;

    let mut dst_ptr = dst.as_mut_ptr();

    for chunk in 0..chunks {
        // SAFETY: the caller guarantees src is valid for n_elements * stride
        // elements, so every gathered lane is in bounds; dst has room for 8
        // more f32s because chunk * 8 + 8 <= n_elements <= dst.len().
        unsafe {
            let chunk_src = src.add(chunk * 8 * stride);
            // Gather 8 f32s at base + vindex[i] * 4 bytes.
            let gathered = _mm256_i32gather_ps(chunk_src, vindex, 4);
            // Store unaligned — dst_ptr may not be 32-byte aligned.
            _mm256_storeu_ps(dst_ptr, gathered);
            dst_ptr = dst_ptr.add(8);
        }
    }

    // Scalar tail for the remaining n_elements % 8 values.
    // SAFETY: same bounds reasoning as the chunk loop.
    unsafe {
        let tail_src_base = src.add(chunks * 8 * stride);
        for i in 0..remainder {
            *dst_ptr.add(i) = *tail_src_base.add(i * stride);
        }
    }
}
243
/// AVX2 gather for f64 using `_mm256_i64gather_pd`.
///
/// Processes 4 elements per iteration (256 bits / 64 bits = 4 lanes); a
/// scalar loop handles the remainder.
///
/// # Safety
/// - AVX2 must be available (checked by caller via `is_x86_feature_detected!`).
/// - `3 * stride` must not overflow `i64`, so every lane offset in the gather
///   index vector is representable (the caller must bound `stride`).
/// - All module-level pointer-validity constraints apply.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f64_avx2(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    // The whole function is already gated on x86_64 above, so the `use`
    // needs no cfg of its own.
    use std::arch::x86_64::*;

    let stride_i64 = stride as i64;

    // Index vector [0, stride, 2*stride, 3*stride] — element offsets; the
    // intrinsic scales them by 8 bytes per f64 via scale=8.
    // (_mm256_set_epi64x takes lanes in high-to-low order.)
    let vindex = _mm256_set_epi64x(3 * stride_i64, 2 * stride_i64, stride_i64, 0);

    let chunks = n_elements / 4;
    let remainder = n_elements % 4;
    let mut dst_ptr = dst.as_mut_ptr();

    for chunk in 0..chunks {
        // SAFETY: the caller guarantees src is valid for n_elements * stride
        // elements, so every gathered lane is in bounds; dst has room for 4
        // more f64s because chunk * 4 + 4 <= n_elements <= dst.len().
        unsafe {
            let chunk_src = src.add(chunk * 4 * stride);
            let gathered = _mm256_i64gather_pd(chunk_src, vindex, 8);
            // Store unaligned — dst_ptr may not be 32-byte aligned.
            _mm256_storeu_pd(dst_ptr, gathered);
            dst_ptr = dst_ptr.add(4);
        }
    }

    // Scalar tail for the remaining n_elements % 4 values.
    // SAFETY: same bounds reasoning as the chunk loop.
    unsafe {
        let tail_src_base = src.add(chunks * 4 * stride);
        for i in 0..remainder {
            *dst_ptr.add(i) = *tail_src_base.add(i * stride);
        }
    }
}
279
280// ── tests ─────────────────────────────────────────────────────────────────────
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_stride1_is_memcpy() {
        let src: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let mut dst = vec![0.0_f32; 16];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f32_stride2() {
        // Every other element: [0, 2, 4, 6, 8, 10, 12, 14, 16]
        let src: Vec<f32> = (0..18).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..9).map(|x| (x * 2) as f32).collect();
        let mut dst = vec![0.0_f32; 9];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f32_stride3() {
        let src: Vec<f32> = (0..21).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..7).map(|x| (x * 3) as f32).collect();
        let mut dst = vec![0.0_f32; 7];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 7, 3);
        }
        assert_eq!(dst, expected);
    }

    /// 1003 elements = 125 full 8-lane AVX2 chunks plus a 3-element scalar
    /// tail. The small f32 tests above never run more than one SIMD chunk,
    /// so this exercises both the chunk loop and the remainder path.
    #[test]
    fn test_f32_many_chunks_with_tail() {
        let n = 1_003_usize;
        let stride = 2_usize;
        let src: Vec<f32> = (0..(n * stride)).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..n).map(|x| (x * stride) as f32).collect();
        let mut dst = vec![0.0_f32; n];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, n, stride);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f64_stride1_is_memcpy() {
        let src: Vec<f64> = (0..16).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; 16];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f64_stride2() {
        let src: Vec<f64> = (0..18).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..9).map(|x| (x * 2) as f64).collect();
        let mut dst = vec![0.0_f64; 9];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f64_stride4() {
        // 10k elements: many AVX2 chunks; keeps the test fast while still
        // covering the SIMD path thoroughly.
        let n = 10_000_usize;
        let stride = 4_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..n).map(|x| (x * stride) as f64).collect();
        let mut dst = vec![0.0_f64; n];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        assert_eq!(dst, expected);
    }

    /// Document the overhead of copying a 1M-element strided f64 array.
    /// This test is not a performance gate — it simply ensures the operation
    /// completes without error and produces the correct first/last values.
    #[test]
    fn benchmark_copy_overhead_documentation() {
        let n = 1_000_000_usize;
        let stride = 3_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; n];

        let start = std::time::Instant::now();
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        let elapsed = start.elapsed();

        // Correctness check: first and last element.
        assert_eq!(dst[0], 0.0);
        assert_eq!(dst[n - 1], ((n - 1) * stride) as f64);

        // Overhead documentation (never fails, only informs).
        eprintln!(
            "copy_strided_to_contiguous_f64: {} elements, stride={}, elapsed={:.2?}",
            n, stride, elapsed
        );
    }
}