// ruvector_solver/simd.rs
//! SIMD-accelerated sparse matrix-vector multiply.
//!
//! Provides [`spmv_simd`], which dispatches to an architecture-specific
//! implementation when the `simd` feature is enabled, and falls back to a
//! portable scalar loop otherwise.

use crate::types::CsrMatrix;

9/// Sparse matrix-vector multiply with optional SIMD acceleration.
10///
11/// Computes `y = A * x` where `A` is a CSR matrix of `f32` values.
12pub fn spmv_simd(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
13    assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
14    assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
15
16    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
17    {
18        if is_x86_feature_detected!("avx2") {
19            // SAFETY: we have checked for AVX2 support at runtime.
20            unsafe {
21                spmv_avx2(matrix, x, y);
22            }
23            return;
24        }
25    }
26
27    spmv_scalar(matrix, x, y);
28}
29
30/// Scalar fallback implementation of SpMV.
31pub fn spmv_scalar(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
32    for i in 0..matrix.rows {
33        let start = matrix.row_ptr[i];
34        let end = matrix.row_ptr[i + 1];
35        let mut sum = 0.0f32;
36        for idx in start..end {
37            let col = matrix.col_indices[idx];
38            sum += matrix.values[idx] * x[col];
39        }
40        y[i] = sum;
41    }
42}
43
/// AVX2-accelerated SpMV for x86_64.
///
/// # Safety
///
/// - The caller must ensure AVX2 is supported on the current CPU (checked at
///   runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd`]).
/// - The caller must ensure `x.len() >= matrix.cols` and
///   `y.len() >= matrix.rows`. These are asserted in [`spmv_simd`] before
///   dispatching here.
/// - The CSR matrix must be structurally valid: `row_ptr[i] <= row_ptr[i+1]`,
///   all `col_indices[j] < matrix.cols`, and `values.len() >= row_ptr[rows]`.
///   Use [`crate::validation::validate_csr_matrix`] before calling the solver
///   to guarantee this.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        let len = end - start;

        // Process the row in 8-wide chunks, then handle the tail scalarly.
        let mut accum = _mm256_setzero_ps();
        let chunks = len / 8;
        let remainder = len % 8;

        for chunk in 0..chunks {
            let base = start + chunk * 8;

            // SAFETY: `base + 7 < end <= values.len()` because
            // `chunk < chunks` implies `base + 8 <= start + chunks * 8 <= end`.
            let vals = _mm256_loadu_ps(matrix.values.as_ptr().add(base));

            // Manual gather of x entries through a stack buffer: the column
            // indices are not contiguous, so a vector load of x is not
            // possible directly.
            let mut x_buf = [0.0f32; 8];
            for k in 0..8 {
                // SAFETY: `base + k < end` so `col_indices[base + k]` is in
                // bounds. `col < matrix.cols <= x.len()` by the CSR structural
                // invariant (enforced by `validate_csr_matrix`).
                let col = *matrix.col_indices.get_unchecked(base + k);
                x_buf[k] = *x.get_unchecked(col);
            }
            let x_vec = _mm256_loadu_ps(x_buf.as_ptr());

            accum = _mm256_add_ps(accum, _mm256_mul_ps(vals, x_vec));
        }

        // Reduce the 8 partial sums to a single scalar.
        let mut sum = horizontal_sum_f32x8(accum);

        let tail_start = start + chunks * 8;
        for idx in tail_start..(tail_start + remainder) {
            // SAFETY: `idx < end <= values.len()` and `col < cols <= x.len()`
            // by the same CSR structural invariant.
            let col = *matrix.col_indices.get_unchecked(idx);
            sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
        }

        // SAFETY: `i < matrix.rows <= y.len()` by the assert in `spmv_simd`.
        *y.get_unchecked_mut(i) = sum;
    }
}

/// Horizontal sum of an AVX2 register (8 x f32 -> 1 x f32).
///
/// Uses a pairwise reduction tree: fold the upper 128-bit lane onto the
/// lower one, then collapse four lanes to two and two to one.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Upper 128 bits + lower 128 bits -> four partial sums.
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
    // Add odd lanes onto even lanes -> two partial sums.
    let pair = _mm_add_ps(quad, _mm_movehdup_ps(quad));
    // Add the upper pair onto the lower lane and extract it.
    _mm_cvtss_f32(_mm_add_ss(pair, _mm_movehl_ps(pair, pair)))
}

123/// Sparse matrix-vector multiply with optional SIMD acceleration for f64.
124///
125/// Computes `y = A * x` where `A` is a CSR matrix of `f64` values.
126pub fn spmv_simd_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
127    assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
128    assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
129
130    #[cfg(all(feature = "simd", target_arch = "x86_64"))]
131    {
132        if is_x86_feature_detected!("avx2") {
133            unsafe {
134                spmv_avx2_f64(matrix, x, y);
135            }
136            return;
137        }
138    }
139
140    spmv_scalar_f64(matrix, x, y);
141}
142
143/// Scalar fallback for f64 SpMV.
144pub fn spmv_scalar_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
145    for i in 0..matrix.rows {
146        let start = matrix.row_ptr[i];
147        let end = matrix.row_ptr[i + 1];
148        let mut sum = 0.0f64;
149        for idx in start..end {
150            let col = matrix.col_indices[idx];
151            sum += matrix.values[idx] * x[col];
152        }
153        y[i] = sum;
154    }
155}
156
/// AVX2-accelerated SpMV for f64 on x86_64.
///
/// Mirrors the f32 AVX2 kernel but processes 4 doubles per vector lane.
///
/// # Safety
///
/// - The caller must ensure AVX2 is supported on the current CPU (checked at
///   runtime via `is_x86_feature_detected!("avx2")` in [`spmv_simd_f64`]).
/// - The caller must ensure `x.len() >= matrix.cols` and
///   `y.len() >= matrix.rows`. These are asserted in [`spmv_simd_f64`]
///   before dispatching here.
/// - The CSR matrix must be structurally valid: `row_ptr[i] <= row_ptr[i+1]`,
///   all `col_indices[j] < matrix.cols`, and `values.len() >= row_ptr[rows]`.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        let len = end - start;

        // Process the row in 4-wide chunks, then handle the tail scalarly.
        let mut accum = _mm256_setzero_pd();
        let chunks = len / 4;
        let remainder = len % 4;

        for chunk in 0..chunks {
            let base = start + chunk * 4;
            // SAFETY: `chunk < chunks` implies `base + 4 <= end <= values.len()`.
            let vals = _mm256_loadu_pd(matrix.values.as_ptr().add(base));

            // Manual gather of x entries through a stack buffer: the column
            // indices are not contiguous.
            let mut x_buf = [0.0f64; 4];
            for k in 0..4 {
                // SAFETY: `base + k < end` bounds `col_indices`; the CSR
                // structural invariant bounds `col < cols <= x.len()`.
                let col = *matrix.col_indices.get_unchecked(base + k);
                x_buf[k] = *x.get_unchecked(col);
            }
            let x_vec = _mm256_loadu_pd(x_buf.as_ptr());
            accum = _mm256_add_pd(accum, _mm256_mul_pd(vals, x_vec));
        }

        // Reduce the 4 partial sums to a single scalar.
        let mut sum = horizontal_sum_f64x4(accum);

        let tail_start = start + chunks * 4;
        for idx in tail_start..(tail_start + remainder) {
            // SAFETY: `idx < end <= values.len()` and `col < cols <= x.len()`
            // by the same CSR structural invariant.
            let col = *matrix.col_indices.get_unchecked(idx);
            sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
        }

        // SAFETY: `i < matrix.rows <= y.len()` by the assert in
        // `spmv_simd_f64`.
        *y.get_unchecked_mut(i) = sum;
    }
}

/// Horizontal sum of an AVX2 register (4 x f64 -> 1 x f64).
///
/// Pairwise reduction: fold the upper 128-bit lane onto the lower one,
/// then the high 64-bit lane onto the low one.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
    use std::arch::x86_64::*;

    // Upper 128 bits + lower 128 bits -> two partial sums.
    let pair = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
    // High lane + low lane -> final scalar.
    _mm_cvtsd_f64(_mm_add_sd(pair, _mm_unpackhi_pd(pair, pair)))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsrMatrix;

    /// Expected product for the shared 3x3 fixture.
    const EXPECTED: [f64; 3] = [5.0, 6.0, 13.0];

    /// Builds the 3x3 fixture used by every test:
    /// [2 0 1]   [1]   [5]
    /// [0 3 0] * [2] = [6]
    /// [1 0 4]   [3]   [13]
    fn make_test_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        (mat, vec![1.0, 2.0, 3.0])
    }

    /// Same fixture as [`make_test_matrix`], in f64.
    fn make_test_matrix_f64() -> (CsrMatrix<f64>, Vec<f64>) {
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        (mat, vec![1.0, 2.0, 3.0])
    }

    /// Asserts `y` matches `EXPECTED` within single-precision tolerance.
    fn check_f32(y: &[f32]) {
        for (i, (&got, &want)) in y.iter().zip(EXPECTED.iter()).enumerate() {
            assert!(
                (f64::from(got) - want).abs() < 1e-6,
                "row {i}: got {got}, want {want}"
            );
        }
    }

    /// Asserts `y` matches `EXPECTED` within double-precision tolerance.
    fn check_f64(y: &[f64]) {
        for (i, (&got, &want)) in y.iter().zip(EXPECTED.iter()).enumerate() {
            assert!((got - want).abs() < 1e-10, "row {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn scalar_spmv_correctness() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_scalar(&mat, &x, &mut y);
        check_f32(&y);
    }

    #[test]
    fn spmv_simd_dispatch() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_simd(&mat, &x, &mut y);
        check_f32(&y);
    }

    #[test]
    fn spmv_simd_f64_correctness() {
        let (mat, x) = make_test_matrix_f64();
        let mut y = vec![0.0f64; 3];
        spmv_simd_f64(&mat, &x, &mut y);
        check_f64(&y);
    }

    #[test]
    fn scalar_spmv_f64_correctness() {
        let (mat, x) = make_test_matrix_f64();
        let mut y = vec![0.0f64; 3];
        spmv_scalar_f64(&mat, &x, &mut y);
        check_f64(&y);
    }
}