
scirs2_numpy/simd_copy.rs

//! SIMD-accelerated copy for non-contiguous-to-contiguous coercion.
//!
//! When Python passes a non-contiguous (strided) NumPy array, it must be gathered
//! into a contiguous buffer before it can be used as an `ndarray::ArrayView`.
//! These routines provide a fast path for that gather operation.
//!
//! ## Dispatch strategy
//!
//! | Platform      | Condition                                | Implementation       |
//! |---------------|------------------------------------------|----------------------|
//! | x86_64        | `avx2` detected at runtime               | AVX2 256-bit gather  |
//! | x86_64        | no `avx2`, or the f32 stride guard fails | scalar loop          |
//! | all others    | always                                   | scalar loop          |
//!
//! When `stride == 1` the memory is already contiguous; [`ptr::copy_nonoverlapping`]
//! is used for the fastest possible copy.
//!
//! ## Safety contract
//!
//! Both public functions are `unsafe` because they operate on raw pointers. The
//! caller must guarantee:
//!
//! - `src` points to a valid, aligned allocation of at least
//!   `n_elements * stride * size_of::<T>()` bytes.
//! - `dst.len() >= n_elements`.
//! - The source and destination ranges do not overlap.
//! - `stride * (n_elements.saturating_sub(1))` fits in `isize` (i.e., no pointer
//!   overflow on the source side).
//! - All `n_elements` source elements are properly initialised.
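//!
//! ## From byte strides to element strides
//!
//! Both functions take `stride` **in elements**. NumPy reports strides in bytes,
//! so a typical call site divides by the element size first. A minimal sketch
//! (the `byte_stride` value here is made up for illustration):
//!
//! ```
//! use scirs2_numpy::simd_copy::copy_strided_to_contiguous_f64;
//!
//! // Hypothetical byte stride reported for an f64 array: 24 bytes = 3 elements.
//! let byte_stride: usize = 24;
//! let stride = byte_stride / std::mem::size_of::<f64>();
//!
//! let src: Vec<f64> = (0..(4 * stride)).map(|x| x as f64).collect();
//! let mut dst = vec![0.0_f64; 4];
//! unsafe {
//!     copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 4, stride);
//! }
//! assert_eq!(dst, [0.0, 3.0, 6.0, 9.0]);
//! ```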
use std::ptr;

// ── f32 ──────────────────────────────────────────────────────────────────────

/// Copy `n_elements` strided `f32` values from `src` into the contiguous slice
/// `dst`.
///
/// `stride` is the gap, **in elements** (not bytes), between successive source
/// elements. A stride of 1 means the data is already contiguous.
///
/// # Safety
///
/// See the [module-level safety contract](self).
///
/// # Examples
///
/// ```
/// use scirs2_numpy::simd_copy::copy_strided_to_contiguous_f32;
///
/// // Source: every second element → [1.0, 3.0, 5.0]
/// let src = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
/// let mut dst = vec![0.0_f32; 3];
/// unsafe {
///     copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 3, 2);
/// }
/// assert_eq!(dst, [1.0, 3.0, 5.0]);
/// ```
pub unsafe fn copy_strided_to_contiguous_f32(
    src: *const f32,
    dst: &mut [f32],
    n_elements: usize,
    stride: usize,
) {
    debug_assert!(
        dst.len() >= n_elements,
        "dst must have at least n_elements slots"
    );

    if stride == 1 {
        // Already contiguous — single memcpy.
        // SAFETY: caller guarantees non-overlap and src validity.
        unsafe {
            ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
        }
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Guard: the index vector holds 8 lanes with offsets [0, stride, …, 7*stride].
        // `7 * stride` must fit in i32 for `_mm256_i32gather_ps`, so compare against
        // `i32::MAX / 7` to keep the check itself free of multiplication overflow.
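        // For reference, i32::MAX / 7 = 306_783_378, so the guard only rejects
        // strides of roughly 3 × 10^8 elements or more (over a gigabyte between
        // consecutive f32 values).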
        const AVX2_LANES: usize = 8;
        if is_x86_feature_detected!("avx2") && stride <= (i32::MAX as usize) / (AVX2_LANES - 1) {
            // SAFETY: we just checked the feature flag. The stride bound ensures
            // that (AVX2_LANES - 1) * stride <= i32::MAX, so all lane offsets
            // in the gather index vector fit in i32 without overflow.
            unsafe {
                gather_f32_avx2(src, dst, n_elements, stride);
            }
            return;
        }
    }

    // Scalar fallback.
    unsafe {
        scalar_gather_f32(src, dst, n_elements, stride);
    }
}

// ── f64 ──────────────────────────────────────────────────────────────────────

/// Copy `n_elements` strided `f64` values from `src` into the contiguous slice
/// `dst`.
///
/// `stride` is the gap, **in elements** (not bytes), between successive source
/// elements.
///
/// # Safety
///
/// See the [module-level safety contract](self).
///
/// # Examples
///
/// ```
/// use scirs2_numpy::simd_copy::copy_strided_to_contiguous_f64;
///
/// // Source: every third element → [0.0, 3.0, 6.0]
/// let src = vec![0.0_f64, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
/// let mut dst = vec![0.0_f64; 3];
/// unsafe {
///     copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 3, 3);
/// }
/// assert_eq!(dst, [0.0, 3.0, 6.0]);
/// ```
pub unsafe fn copy_strided_to_contiguous_f64(
    src: *const f64,
    dst: &mut [f64],
    n_elements: usize,
    stride: usize,
) {
    debug_assert!(
        dst.len() >= n_elements,
        "dst must have at least n_elements slots"
    );

    if stride == 1 {
        // Already contiguous — single memcpy.
        unsafe {
            ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n_elements);
        }
        return;
    }

    // AVX2 gather for f64 (256 bits = 4 × f64) uses `_mm256_i64gather_pd`, whose
    // index vector holds 64-bit lane offsets, so every `stride` permitted by the
    // module-level safety contract fits without the 32-bit guard used on the f32
    // path.
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe {
                gather_f64_avx2(src, dst, n_elements, stride);
            }
            return;
        }
    }

    unsafe {
        scalar_gather_f64(src, dst, n_elements, stride);
    }
}

// ── scalar helpers ────────────────────────────────────────────────────────────

/// Scalar gather for f32.
///
/// # Safety
/// Caller must uphold the module-level contract.
#[inline]
unsafe fn scalar_gather_f32(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        // SAFETY: caller guarantees src validity for stride*n_elements elements.
        *dst.get_unchecked_mut(i) = *src.add(i * stride);
    }
}

/// Scalar gather for f64.
///
/// # Safety
/// Caller must uphold the module-level contract.
#[inline]
unsafe fn scalar_gather_f64(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    for i in 0..n_elements {
        *dst.get_unchecked_mut(i) = *src.add(i * stride);
    }
}

// ── AVX2 paths ────────────────────────────────────────────────────────────────

/// AVX2 gather for f32 using `_mm256_i32gather_ps`.
///
/// Processes 8 elements per iteration (256 bits / 32 bits = 8 lanes).
/// A tail scalar loop handles the remainder.
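/// For example, with `stride = 3` the index vector is `[0, 3, 6, 9, 12, 15, 18, 21]`,
/// so each gather touches eight source elements spaced three apart.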
///
/// # Safety
/// - AVX2 must be available (checked by caller via `is_x86_feature_detected!`).
/// - `7 * stride <= i32::MAX` must hold (the caller checks `stride <= i32::MAX / 7`).
/// - All module-level pointer-validity constraints apply.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f32_avx2(src: *const f32, dst: &mut [f32], n_elements: usize, stride: usize) {
    use std::arch::x86_64::*;

    let stride_i32 = stride as i32;

    // Build a constant index vector [0, stride, 2*stride, …, 7*stride].
    // These are element-wise offsets; the intrinsic multiplies each one by the
    // scale argument (4 bytes per f32) to form the byte address.
    let vindex = _mm256_set_epi32(
        7 * stride_i32,
        6 * stride_i32,
        5 * stride_i32,
        4 * stride_i32,
        3 * stride_i32,
        2 * stride_i32,
        stride_i32,
        0,
    );

    let chunks = n_elements / 8;
    let remainder = n_elements % 8;

    let mut dst_ptr = dst.as_mut_ptr();

    for chunk in 0..chunks {
        let chunk_src = src.add(chunk * 8 * stride);
        // _mm256_i32gather_ps: gather 8 f32s at base + vindex[i] * scale.
        // scale=4 because we provide element offsets and each f32 is 4 bytes.
        let gathered = _mm256_i32gather_ps(chunk_src, vindex, 4);
        // Store unaligned — dst_ptr may not be 32-byte aligned.
        _mm256_storeu_ps(dst_ptr, gathered);
        dst_ptr = dst_ptr.add(8);
    }

    // Scalar tail.
    let tail_src_base = src.add(chunks * 8 * stride);
    for i in 0..remainder {
        *dst_ptr.add(i) = *tail_src_base.add(i * stride);
    }
}

/// AVX2 gather for f64 using `_mm256_i64gather_pd`.
///
/// Processes 4 elements per iteration (256 bits / 64 bits = 4 lanes).
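/// With `stride = 3`, for example, the index vector is `[0, 3, 6, 9]`.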
///
/// # Safety
/// Same as [`gather_f32_avx2`], except that no upper bound on `stride` is needed
/// because the lane offsets here are 64-bit.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn gather_f64_avx2(src: *const f64, dst: &mut [f64], n_elements: usize, stride: usize) {
    use std::arch::x86_64::*;

    // `stride` is usize; the cast is lossless because the module-level contract
    // requires the source allocation to span `n_elements * stride` elements,
    // which caps `stride` far below `i64::MAX`.
    let stride_i64 = stride as i64;

    let vindex = _mm256_set_epi64x(3 * stride_i64, 2 * stride_i64, stride_i64, 0);

    let chunks = n_elements / 4;
    let remainder = n_elements % 4;
    let mut dst_ptr = dst.as_mut_ptr();

    for chunk in 0..chunks {
        let chunk_src = src.add(chunk * 4 * stride);
        // scale=8: each f64 occupies 8 bytes; vindex gives element offsets.
        let gathered = _mm256_i64gather_pd(chunk_src, vindex, 8);
        _mm256_storeu_pd(dst_ptr, gathered);
        dst_ptr = dst_ptr.add(4);
    }

    let tail_src_base = src.add(chunks * 4 * stride);
    for i in 0..remainder {
        *dst_ptr.add(i) = *tail_src_base.add(i * stride);
    }
}

// ── tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_stride1_is_memcpy() {
        let src: Vec<f32> = (0..16).map(|x| x as f32).collect();
        let mut dst = vec![0.0_f32; 16];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f32_stride2() {
        // Every other element: [0, 2, 4, 6, 8, 10, 12, 14, 16]
        let src: Vec<f32> = (0..18).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..9).map(|x| (x * 2) as f32).collect();
        let mut dst = vec![0.0_f32; 9];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f32_stride3() {
        let src: Vec<f32> = (0..21).map(|x| x as f32).collect();
        let expected: Vec<f32> = (0..7).map(|x| (x * 3) as f32).collect();
        let mut dst = vec![0.0_f32; 7];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut dst, 7, 3);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f64_stride1_is_memcpy() {
        let src: Vec<f64> = (0..16).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; 16];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 16, 1);
        }
        assert_eq!(dst, src);
    }

    #[test]
    fn test_f64_stride2() {
        let src: Vec<f64> = (0..18).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..9).map(|x| (x * 2) as f64).collect();
        let mut dst = vec![0.0_f64; 9];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, 9, 2);
        }
        assert_eq!(dst, expected);
    }

    #[test]
    fn test_f64_stride4() {
        // A smaller cousin of the 1M-element benchmark below: large enough to
        // exercise many full SIMD chunks while keeping the test fast.
        let n = 10_000_usize;
        let stride = 4_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let expected: Vec<f64> = (0..n).map(|x| (x * stride) as f64).collect();
        let mut dst = vec![0.0_f64; n];
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        assert_eq!(dst, expected);
    }
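    /// Cross-check (a simple consistency sketch): whichever path the dispatcher
    /// selects on the current machine (AVX2 or scalar), the result should match
    /// the plain scalar gather. The size is chosen so the f32 SIMD path runs one
    /// full 8-lane chunk plus a 5-element tail.
    #[test]
    fn test_dispatch_matches_scalar_gather() {
        let n = 13_usize;
        let stride = 5_usize;
        let src: Vec<f32> = (0..(n * stride)).map(|x| x as f32).collect();

        let mut via_dispatch = vec![0.0_f32; n];
        let mut via_scalar = vec![0.0_f32; n];
        unsafe {
            copy_strided_to_contiguous_f32(src.as_ptr(), &mut via_dispatch, n, stride);
            scalar_gather_f32(src.as_ptr(), &mut via_scalar, n, stride);
        }
        assert_eq!(via_dispatch, via_scalar);
    }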
    /// Document the overhead of copying a 1M-element strided f64 array.
    /// This test is not a performance gate — it simply ensures the operation
    /// completes without error and produces the correct first/last values.
    #[test]
    fn benchmark_copy_overhead_documentation() {
        let n = 1_000_000_usize;
        let stride = 3_usize;
        let src: Vec<f64> = (0..(n * stride)).map(|x| x as f64).collect();
        let mut dst = vec![0.0_f64; n];

        let start = std::time::Instant::now();
        unsafe {
            copy_strided_to_contiguous_f64(src.as_ptr(), &mut dst, n, stride);
        }
        let elapsed = start.elapsed();

        // Correctness check: first and last element.
        assert_eq!(dst[0], 0.0);
        assert_eq!(dst[n - 1], ((n - 1) * stride) as f64);

        // Overhead documentation (never fails, only informs).
        eprintln!(
            "copy_strided_to_contiguous_f64: {} elements, stride={}, elapsed={:.2?}",
            n, stride, elapsed
        );
    }
}