Skip to main content

ruvector_solver/
simd.rs

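//! SIMD-accelerated sparse matrix-vector multiply.
//!
//! Provides [`spmv_simd`] (`f32`) and [`spmv_simd_f64`] (`f64`), which dispatch
//! to an architecture-specific kernel when one is available and fall back to a
//! portable scalar loop ([`spmv_scalar`] / [`spmv_scalar_f64`]) otherwise:
//! AVX2 on x86_64 when the `simd` feature is enabled and the CPU supports it at
//! runtime, and NEON on aarch64 (used regardless of the `simd` feature).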
6
7use crate::types::CsrMatrix;
8
9/// Sparse matrix-vector multiply with optional SIMD acceleration.
10///
11/// Computes `y = A * x` where `A` is a CSR matrix of `f32` values.
12pub fn spmv_simd(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
13    assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
14    assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
15
16    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
17    {
18        if is_x86_feature_detected!("avx2") {
19            // SAFETY: we have checked for AVX2 support at runtime.
20            unsafe {
21                spmv_avx2(matrix, x, y);
22            }
23            return;
24        }
25    }
26
27    #[cfg(target_arch = "aarch64")]
28    {
29        unsafe {
30            spmv_neon_f32(matrix, x, y);
31        }
32        return;
33    }
34
35    #[allow(unreachable_code)]
36    spmv_scalar(matrix, x, y);
37}
38
39/// Scalar fallback implementation of SpMV.
40pub fn spmv_scalar(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
41    for i in 0..matrix.rows {
42        let start = matrix.row_ptr[i];
43        let end = matrix.row_ptr[i + 1];
44        let mut sum = 0.0f32;
45        for idx in start..end {
46            let col = matrix.col_indices[idx];
47            sum += matrix.values[idx] * x[col];
48        }
49        y[i] = sum;
50    }
51}
52
/// AVX2-accelerated SpMV for `f32` on x86_64.
///
/// Each row is processed in 8-wide `__m256` chunks. The scattered `x` entries
/// for a chunk are gathered through a stack buffer. The final `len % 8`
/// entries are finished with a scalar loop.
///
/// # Safety
///
/// The caller must ensure AVX2 is supported on the current CPU (checked at
/// runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd`]).
///
/// # Panics
///
/// Every CSR and vector access is bounds-checked, so a malformed matrix panics
/// instead of reading or writing out of bounds. That covers:
/// - a non-monotonic `row_ptr`;
/// - a `row_ptr` entry past the end of `values` / `col_indices`;
/// - a column index `>= x.len()`;
/// - a `y` shorter than `matrix.rows`.
///
/// This keeps the safe entry point [`spmv_simd`] sound even for matrices that
/// were never validated.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        // Checked subslices: a corrupt `row_ptr` panics here rather than
        // letting the raw vector loads below run past the allocation.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let mut accum = _mm256_setzero_ps();
        let mut val_chunks = row_vals.chunks_exact(8);
        let mut col_chunks = row_cols.chunks_exact(8);

        for (vals, cols) in (&mut val_chunks).zip(&mut col_chunks) {
            // SAFETY: `chunks_exact(8)` guarantees `vals` holds exactly 8
            // elements, so this unaligned 8-lane load stays inside the slice.
            let v = _mm256_loadu_ps(vals.as_ptr());

            // Scalar gather; `x[col]` is bounds-checked because column indices
            // come from caller-controlled data.
            let mut x_buf = [0.0f32; 8];
            for (slot, &col) in x_buf.iter_mut().zip(cols) {
                *slot = x[col];
            }
            let x_vec = _mm256_loadu_ps(x_buf.as_ptr());

            accum = _mm256_add_ps(accum, _mm256_mul_ps(v, x_vec));
        }

        let mut sum = horizontal_sum_f32x8(accum);

        // Remaining `len % 8` entries, in their original order.
        for (&v, &col) in val_chunks.remainder().iter().zip(col_chunks.remainder()) {
            sum += v * x[col];
        }

        y[i] = sum;
    }
}
114
/// Horizontal sum of an AVX2 register (8 x f32 -> 1 x f32).
///
/// With input lanes `a0..a7`, the reduction order is fixed:
/// 1. `s_k = a_k + a_{k+4}` for `k = 0..4` (low half + high half).
/// 2. The result is `(s0 + s1) + (s2 + s3)`.
///
/// # Safety
///
/// The caller must ensure AVX2 is available on the current CPU; the function
/// is compiled with `#[target_feature(enable = "avx2")]`.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold the upper 128-bit half onto the lower half: [s0, s1, s2, s3].
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(lo, hi);

    // Duplicate the odd lanes ([s1, s1, s3, s3]) and add:
    // lane 0 = s0 + s1, lane 2 = s2 + s3.
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // Bring lane 2 down to lane 0 and add only the lowest lanes.
    let shuf2 = _mm_movehl_ps(sums, sums);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}
131
132/// Sparse matrix-vector multiply with optional SIMD acceleration for f64.
133///
134/// Computes `y = A * x` where `A` is a CSR matrix of `f64` values.
135pub fn spmv_simd_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
136    assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
137    assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
138
139    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
140    {
141        if is_x86_feature_detected!("avx2") {
142            unsafe {
143                spmv_avx2_f64(matrix, x, y);
144            }
145            return;
146        }
147    }
148
149    #[cfg(target_arch = "aarch64")]
150    {
151        unsafe {
152            spmv_neon_f64(matrix, x, y);
153        }
154        return;
155    }
156
157    #[allow(unreachable_code)]
158    spmv_scalar_f64(matrix, x, y);
159}
160
161/// Scalar fallback for f64 SpMV.
162pub fn spmv_scalar_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
163    for i in 0..matrix.rows {
164        let start = matrix.row_ptr[i];
165        let end = matrix.row_ptr[i + 1];
166        let mut sum = 0.0f64;
167        for idx in start..end {
168            let col = matrix.col_indices[idx];
169            sum += matrix.values[idx] * x[col];
170        }
171        y[i] = sum;
172    }
173}
174
/// AVX2-accelerated SpMV for `f64` on x86_64.
///
/// Each row is processed in 4-wide `__m256d` chunks. The scattered `x` entries
/// for a chunk are gathered through a stack buffer. The final `len % 4`
/// entries are finished with a scalar loop.
///
/// # Safety
///
/// The caller must ensure AVX2 is supported on the current CPU (checked at
/// runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd_f64`]).
///
/// # Panics
///
/// Every CSR and vector access is bounds-checked, so a malformed matrix panics
/// instead of reading or writing out of bounds. That covers:
/// - a non-monotonic `row_ptr`;
/// - a `row_ptr` entry past the end of `values` / `col_indices`;
/// - a column index `>= x.len()`;
/// - a `y` shorter than `matrix.rows`.
///
/// This keeps the safe entry point [`spmv_simd_f64`] sound even for matrices
/// that were never validated.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        // Checked subslices: a corrupt `row_ptr` panics here rather than
        // letting the raw vector loads below run past the allocation.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let mut accum = _mm256_setzero_pd();
        let mut val_chunks = row_vals.chunks_exact(4);
        let mut col_chunks = row_cols.chunks_exact(4);

        for (vals, cols) in (&mut val_chunks).zip(&mut col_chunks) {
            // SAFETY: `chunks_exact(4)` guarantees `vals` holds exactly 4
            // elements, so this unaligned 4-lane load stays inside the slice.
            let v = _mm256_loadu_pd(vals.as_ptr());

            // Scalar gather; `x[col]` is bounds-checked because column indices
            // come from caller-controlled data.
            let mut x_buf = [0.0f64; 4];
            for (slot, &col) in x_buf.iter_mut().zip(cols) {
                *slot = x[col];
            }
            let x_vec = _mm256_loadu_pd(x_buf.as_ptr());

            accum = _mm256_add_pd(accum, _mm256_mul_pd(v, x_vec));
        }

        let mut sum = horizontal_sum_f64x4(accum);

        // Remaining `len % 4` entries, in their original order.
        for (&v, &col) in val_chunks.remainder().iter().zip(col_chunks.remainder()) {
            sum += v * x[col];
        }

        y[i] = sum;
    }
}
213
/// Horizontal sum of an AVX2 register (4 x f64 -> 1 x f64).
///
/// With input lanes `a0..a3`, the result is `(a0 + a2) + (a1 + a3)`: first the
/// high 128-bit half is added onto the low half, then the two remaining lanes
/// are added.
///
/// # Safety
///
/// The caller must ensure AVX2 is available on the current CPU; the function
/// is compiled with `#[target_feature(enable = "avx2")]`.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
    use std::arch::x86_64::*;
    // [a0 + a2, a1 + a3]
    let hi = _mm256_extractf128_pd(v, 1);
    let lo = _mm256_castpd256_pd128(v);
    let sum128 = _mm_add_pd(lo, hi);
    // Broadcast the upper lane and add it into lane 0 only.
    let hi64 = _mm_unpackhi_pd(sum128, sum128);
    let result = _mm_add_sd(sum128, hi64);
    _mm_cvtsd_f64(result)
}
225
226// ---------------------------------------------------------------------------
227// NEON implementations for AArch64 / Apple Silicon (M1-M4)
228// ---------------------------------------------------------------------------
229
/// NEON-accelerated SpMV for f32 on AArch64.
///
/// Each row is processed in three stages:
/// 1. 8-element steps, using two `float32x4_t` accumulators updated with fused
///    multiply-add (`vfmaq_f32`).
/// 2. At most one further 4-element step.
/// 3. A scalar tail for the last `len % 4` entries.
///
/// The scattered `x` entries are gathered through small stack buffers. No
/// explicit software prefetching is performed.
///
/// # Safety
///
/// The CPU must support NEON (a baseline feature of standard AArch64 targets).
///
/// # Panics
///
/// Every CSR and vector access is bounds-checked, so a malformed matrix panics
/// instead of reading or writing out of bounds. That covers:
/// - a non-monotonic `row_ptr`;
/// - a `row_ptr` entry past the end of `values` / `col_indices`;
/// - a column index `>= x.len()`;
/// - a `y` shorter than `matrix.rows`.
#[cfg(target_arch = "aarch64")]
unsafe fn spmv_neon_f32(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::aarch64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        // Checked subslices: a corrupt `row_ptr` panics here rather than
        // letting the raw vector loads below run past the allocation.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);

        // 2x unrolled: 8 f32 per iteration (2 NEON regs x 4 lanes).
        let mut val_chunks = row_vals.chunks_exact(8);
        let mut col_chunks = row_cols.chunks_exact(8);
        for (vals, cols) in (&mut val_chunks).zip(&mut col_chunks) {
            // SAFETY: `vals` holds exactly 8 elements, so both 4-lane loads
            // (offsets 0 and 4) stay inside the slice.
            let v0 = vld1q_f32(vals.as_ptr());
            let v1 = vld1q_f32(vals.as_ptr().add(4));

            // Scalar gather with bounds-checked `x` accesses.
            let mut xbuf0 = [0.0f32; 4];
            let mut xbuf1 = [0.0f32; 4];
            for k in 0..4 {
                xbuf0[k] = x[cols[k]];
                xbuf1[k] = x[cols[4 + k]];
            }
            let x0 = vld1q_f32(xbuf0.as_ptr());
            let x1 = vld1q_f32(xbuf1.as_ptr());

            acc0 = vfmaq_f32(acc0, v0, x0);
            acc1 = vfmaq_f32(acc1, v1, x1);
        }

        // Fewer than 8 entries remain: one optional 4-wide step, then scalars.
        let mut tail_vals = val_chunks.remainder();
        let mut tail_cols = col_chunks.remainder();
        if tail_vals.len() >= 4 {
            // SAFETY: `tail_vals.len() >= 4`, so the 4-lane load is in bounds.
            let v0 = vld1q_f32(tail_vals.as_ptr());
            let mut xbuf = [0.0f32; 4];
            for k in 0..4 {
                xbuf[k] = x[tail_cols[k]];
            }
            let x0 = vld1q_f32(xbuf.as_ptr());
            acc0 = vfmaq_f32(acc0, v0, x0);
            tail_vals = &tail_vals[4..];
            tail_cols = &tail_cols[4..];
        }

        // Combine accumulators and reduce across lanes.
        let combined = vaddq_f32(acc0, acc1);
        let mut sum = vaddvq_f32(combined);

        // Scalar tail (the last `len % 4` entries).
        for (&v, &col) in tail_vals.iter().zip(tail_cols) {
            sum += v * x[col];
        }

        y[i] = sum;
    }
}
301
/// NEON-accelerated SpMV for f64 on AArch64.
///
/// Each row is processed in 4-element steps using two `float64x2_t`
/// accumulators updated with fused multiply-add (`vfmaq_f64`). A scalar tail
/// handles the last `len % 4` entries. The scattered `x` entries are gathered
/// through small stack buffers. No explicit software prefetching is performed.
///
/// # Safety
///
/// The CPU must support NEON (a baseline feature of standard AArch64 targets).
///
/// # Panics
///
/// Every CSR and vector access is bounds-checked, so a malformed matrix panics
/// instead of reading or writing out of bounds. That covers:
/// - a non-monotonic `row_ptr`;
/// - a `row_ptr` entry past the end of `values` / `col_indices`;
/// - a column index `>= x.len()`;
/// - a `y` shorter than `matrix.rows`.
#[cfg(target_arch = "aarch64")]
unsafe fn spmv_neon_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::aarch64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        // Checked subslices: a corrupt `row_ptr` panics here rather than
        // letting the raw vector loads below run past the allocation.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let mut acc0 = vdupq_n_f64(0.0);
        let mut acc1 = vdupq_n_f64(0.0);

        // 2x unrolled: 4 f64 per iteration (2 NEON regs x 2 lanes).
        let mut val_chunks = row_vals.chunks_exact(4);
        let mut col_chunks = row_cols.chunks_exact(4);
        for (vals, cols) in (&mut val_chunks).zip(&mut col_chunks) {
            // SAFETY: `vals` holds exactly 4 elements, so both 2-lane loads
            // (offsets 0 and 2) stay inside the slice.
            let v0 = vld1q_f64(vals.as_ptr());
            let v1 = vld1q_f64(vals.as_ptr().add(2));

            // Scalar gather with bounds-checked `x` accesses.
            let mut xbuf0 = [0.0f64; 2];
            let mut xbuf1 = [0.0f64; 2];
            for k in 0..2 {
                xbuf0[k] = x[cols[k]];
                xbuf1[k] = x[cols[2 + k]];
            }
            let x0 = vld1q_f64(xbuf0.as_ptr());
            let x1 = vld1q_f64(xbuf1.as_ptr());

            acc0 = vfmaq_f64(acc0, v0, x0);
            acc1 = vfmaq_f64(acc1, v1, x1);
        }

        let combined = vaddq_f64(acc0, acc1);
        let mut sum = vgetq_lane_f64(combined, 0) + vgetq_lane_f64(combined, 1);

        // Scalar tail (the last `len % 4` entries).
        for (&v, &col) in val_chunks.remainder().iter().zip(col_chunks.remainder()) {
            sum += v * x[col];
        }

        y[i] = sum;
    }
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsrMatrix;

    fn make_test_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
        // [2 0 1]   [1]   [5]
        // [0 3 0] * [2] = [6]
        // [1 0 4]   [3]   [13]
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        (mat, x)
    }

    /// Builds a 7x20 matrix whose rows hold 0, 1, 4, 7, 8, 13 and 20 stored
    /// entries. That row mix fills the 8-, 4- and 2-wide chunk loops of the
    /// SIMD kernels and hits every remainder length; the 3x3 fixture above
    /// never fills a single vector chunk.
    ///
    /// Column indices are scattered (`(k * 7) % 20`) to exercise the gather.
    /// All values are small positive integers, so every row sum is exact in
    /// both `f32` and `f64` and the kernels' summation order / FMA use cannot
    /// change the result.
    fn make_wide_matrix<T: From<u8>>() -> (CsrMatrix<T>, Vec<T>) {
        const COLS: usize = 20;
        let row_lengths = [0usize, 1, 4, 7, 8, 13, 20];
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_ptr = vec![0];
        for &len in &row_lengths {
            for k in 0..len {
                col_indices.push((k * 7) % COLS);
                values.push(T::from((k % 5 + 1) as u8));
            }
            row_ptr.push(values.len());
        }
        let x = (0..COLS).map(|j| T::from((j % 4 + 1) as u8)).collect();
        let mat = CsrMatrix {
            values,
            col_indices,
            row_ptr,
            rows: row_lengths.len(),
            cols: COLS,
        };
        (mat, x)
    }

    #[test]
    fn scalar_spmv_correctness() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_scalar(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn spmv_simd_dispatch() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_simd(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn spmv_simd_matches_scalar_on_wide_rows() {
        let (mat, x) = make_wide_matrix::<f32>();
        let mut expected = vec![0.0f32; mat.rows];
        spmv_scalar(&mat, &x, &mut expected);
        // Hand-checked rows: empty row, single entry, one 4-chunk.
        assert_eq!(expected[0], 0.0);
        assert_eq!(expected[1], 1.0);
        assert_eq!(expected[2], 26.0);

        // NaN sentinel proves every output entry (including empty rows) is written.
        let mut y = vec![f32::NAN; mat.rows];
        spmv_simd(&mat, &x, &mut y);
        assert_eq!(y, expected);
    }

    #[test]
    fn spmv_simd_f64_correctness() {
        let mat = CsrMatrix::<f64> {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0f64; 3];
        spmv_simd_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }

    #[test]
    fn spmv_simd_f64_matches_scalar_on_wide_rows() {
        let (mat, x) = make_wide_matrix::<f64>();
        let mut expected = vec![0.0f64; mat.rows];
        spmv_scalar_f64(&mat, &x, &mut expected);
        assert_eq!(expected[0], 0.0);
        assert_eq!(expected[1], 1.0);
        assert_eq!(expected[2], 26.0);

        let mut y = vec![f64::NAN; mat.rows];
        spmv_simd_f64(&mat, &x, &mut y);
        assert_eq!(y, expected);
    }

    #[test]
    fn scalar_spmv_f64_correctness() {
        let mat = CsrMatrix::<f64> {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0f64; 3];
        spmv_scalar_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }
}