1use crate::types::CsrMatrix;
8
9pub fn spmv_simd(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
13 assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
14 assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
15
16 #[cfg(all(feature = "simd", target_arch = "x86_64"))]
17 {
18 if is_x86_feature_detected!("avx2") {
19 unsafe {
21 spmv_avx2(matrix, x, y);
22 }
23 return;
24 }
25 }
26
27 spmv_scalar(matrix, x, y);
28}
29
30pub fn spmv_scalar(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
32 for i in 0..matrix.rows {
33 let start = matrix.row_ptr[i];
34 let end = matrix.row_ptr[i + 1];
35 let mut sum = 0.0f32;
36 for idx in start..end {
37 let col = matrix.col_indices[idx];
38 sum += matrix.values[idx] * x[col];
39 }
40 y[i] = sum;
41 }
42}
43
/// AVX2 kernel computing `y = A * x` over an `f32` CSR matrix.
///
/// Each row's nonzeros are processed in 8-wide chunks: the contiguous value
/// run is loaded as a vector, the matching `x` entries are gathered through a
/// stack buffer, lane-wise products are accumulated, and the partial sums are
/// reduced with [`horizontal_sum_f32x8`]. The trailing `len % 8` nonzeros are
/// handled scalarly.
///
/// # Safety
/// - The caller must ensure the CPU supports AVX2 (this fn is
///   `#[target_feature(enable = "avx2")]`).
/// - The `get_unchecked` accesses are sound only for a well-formed matrix:
///   `row_ptr` entries must be non-decreasing and bounded by
///   `values.len()`/`col_indices.len()`, every column index must be
///   `< x.len()`, and `y.len()` must be at least `matrix.rows`.
///   NOTE(review): presumably `CsrMatrix` construction guarantees these
///   invariants — confirm, since nothing here re-checks them.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        // Nonzero range [start, end) for row i.
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        let len = end - start;

        let mut accum = _mm256_setzero_ps();
        let chunks = len / 8; // full 8-lane groups
        let remainder = len % 8; // leftover handled by the scalar tail below

        for chunk in 0..chunks {
            let base = start + chunk * 8;

            // CSR stores a row's values contiguously, so one unaligned load
            // fetches all 8.
            let vals = _mm256_loadu_ps(matrix.values.as_ptr().add(base));

            // Manual gather: x is read at 8 scattered column indices, staged
            // in a stack buffer, then loaded as a single vector.
            let mut x_buf = [0.0f32; 8];
            for k in 0..8 {
                let col = *matrix.col_indices.get_unchecked(base + k);
                x_buf[k] = *x.get_unchecked(col);
            }
            let x_vec = _mm256_loadu_ps(x_buf.as_ptr());

            // accum += vals * x_vec, lane-wise (explicit mul + add, not FMA).
            accum = _mm256_add_ps(accum, _mm256_mul_ps(vals, x_vec));
        }

        // Collapse the 8 lane-partial sums into one scalar.
        let mut sum = horizontal_sum_f32x8(accum);

        // Scalar tail for the final len % 8 nonzeros.
        let tail_start = start + chunks * 8;
        for idx in tail_start..(tail_start + remainder) {
            let col = *matrix.col_indices.get_unchecked(idx);
            sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
        }

        *y.get_unchecked_mut(i) = sum;
    }
}
105
/// Reduces the 8 `f32` lanes of `v` to their scalar sum.
///
/// Reduction tree: fold the high 128-bit half onto the low half, then combine
/// adjacent lane pairs twice. The addition order is fixed, so results are
/// deterministic, though they may differ from a left-to-right scalar sum by
/// floating-point rounding.
///
/// # Safety
/// Caller must ensure AVX2 is available (`#[target_feature]`).
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Add the upper 128-bit lane onto the lower: [a0+a4, a1+a5, a2+a6, a3+a7].
    let hi = _mm256_extractf128_ps(v, 1);
    let lo = _mm256_castps256_ps128(v);
    let sum128 = _mm_add_ps(lo, hi);

    // Duplicate odd lanes onto even and add: combines lane pairs (0,1), (2,3).
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    // Move the upper pair down and add; the total lands in lane 0.
    let shuf2 = _mm_movehl_ps(sums, sums);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}
122
123pub fn spmv_simd_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
127 assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
128 assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
129
130 #[cfg(all(feature = "simd", target_arch = "x86_64"))]
131 {
132 if is_x86_feature_detected!("avx2") {
133 unsafe {
134 spmv_avx2_f64(matrix, x, y);
135 }
136 return;
137 }
138 }
139
140 spmv_scalar_f64(matrix, x, y);
141}
142
143pub fn spmv_scalar_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
145 for i in 0..matrix.rows {
146 let start = matrix.row_ptr[i];
147 let end = matrix.row_ptr[i + 1];
148 let mut sum = 0.0f64;
149 for idx in start..end {
150 let col = matrix.col_indices[idx];
151 sum += matrix.values[idx] * x[col];
152 }
153 y[i] = sum;
154 }
155}
156
/// AVX2 kernel computing `y = A * x` over an `f64` CSR matrix.
///
/// Same structure as the `f32` kernel, but 4 doubles per 256-bit vector:
/// values are loaded contiguously, `x` entries gathered through a stack
/// buffer, lane products accumulated, then reduced with
/// [`horizontal_sum_f64x4`]; the `len % 4` tail is handled scalarly.
///
/// # Safety
/// - The caller must ensure the CPU supports AVX2 (this fn is
///   `#[target_feature(enable = "avx2")]`).
/// - The `get_unchecked` accesses are sound only for a well-formed matrix:
///   `row_ptr` entries must be non-decreasing and bounded by
///   `values.len()`/`col_indices.len()`, every column index must be
///   `< x.len()`, and `y.len()` must be at least `matrix.rows`.
///   NOTE(review): presumably `CsrMatrix` construction guarantees these
///   invariants — confirm, since nothing here re-checks them.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        // Nonzero range [start, end) for row i.
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];
        let len = end - start;

        let mut accum = _mm256_setzero_pd();
        let chunks = len / 4; // full 4-lane groups
        let remainder = len % 4; // leftover handled by the scalar tail below

        for chunk in 0..chunks {
            let base = start + chunk * 4;
            // CSR stores a row's values contiguously; one unaligned load.
            let vals = _mm256_loadu_pd(matrix.values.as_ptr().add(base));

            // Manual gather of x at 4 scattered column indices.
            let mut x_buf = [0.0f64; 4];
            for k in 0..4 {
                let col = *matrix.col_indices.get_unchecked(base + k);
                x_buf[k] = *x.get_unchecked(col);
            }
            let x_vec = _mm256_loadu_pd(x_buf.as_ptr());
            // accum += vals * x_vec, lane-wise (explicit mul + add, not FMA).
            accum = _mm256_add_pd(accum, _mm256_mul_pd(vals, x_vec));
        }

        // Collapse the 4 lane-partial sums into one scalar.
        let mut sum = horizontal_sum_f64x4(accum);

        // Scalar tail for the final len % 4 nonzeros.
        let tail_start = start + chunks * 4;
        for idx in tail_start..(tail_start + remainder) {
            let col = *matrix.col_indices.get_unchecked(idx);
            sum += *matrix.values.get_unchecked(idx) * *x.get_unchecked(col);
        }

        *y.get_unchecked_mut(i) = sum;
    }
}
195
/// Reduces the 4 `f64` lanes of `v` to their scalar sum.
///
/// Folds the high 128-bit half onto the low half, then adds the remaining
/// pair. The fixed addition order makes results deterministic, though they
/// may differ from a left-to-right scalar sum by floating-point rounding.
///
/// # Safety
/// Caller must ensure AVX2 is available (`#[target_feature]`).
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
    use std::arch::x86_64::*;
    // Add the upper 128-bit lane onto the lower: [a0+a2, a1+a3].
    let hi = _mm256_extractf128_pd(v, 1);
    let lo = _mm256_castpd256_pd128(v);
    let sum128 = _mm_add_pd(lo, hi);
    // Broadcast the upper lane down and add; total lands in lane 0.
    let hi64 = _mm_unpackhi_pd(sum128, sum128);
    let result = _mm_add_sd(sum128, hi64);
    _mm_cvtsd_f64(result)
}
207
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsrMatrix;

    /// Shared 3x3 f32 fixture (rows of 2, 1, 2 nonzeros):
    /// [[2, 0, 1], [0, 3, 0], [1, 0, 4]], x = [1, 2, 3] -> y = [5, 6, 13].
    fn make_test_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        (mat, x)
    }

    /// Same fixture in f64, so the f64 tests don't duplicate the literals.
    fn make_test_matrix_f64() -> (CsrMatrix<f64>, Vec<f64>) {
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        (mat, x)
    }

    /// Single dense row with `nnz` nonzeros. With nnz >= 8 this drives the
    /// AVX2 chunk loop (8-wide for f32), which the 3x3 fixture — at most 2
    /// nonzeros per row — never reaches.
    fn make_wide_row_f32(nnz: usize) -> (CsrMatrix<f32>, Vec<f32>) {
        let mat = CsrMatrix {
            values: (0..nnz).map(|i| (i + 1) as f32).collect(),
            col_indices: (0..nnz).collect(),
            row_ptr: vec![0, nnz],
            rows: 1,
            cols: nnz,
        };
        let x = (0..nnz).map(|i| ((i % 7) + 1) as f32).collect();
        (mat, x)
    }

    /// f64 counterpart of `make_wide_row_f32` (4-wide chunks need nnz >= 4).
    fn make_wide_row_f64(nnz: usize) -> (CsrMatrix<f64>, Vec<f64>) {
        let mat = CsrMatrix {
            values: (0..nnz).map(|i| (i + 1) as f64).collect(),
            col_indices: (0..nnz).collect(),
            row_ptr: vec![0, nnz],
            rows: 1,
            cols: nnz,
        };
        let x = (0..nnz).map(|i| ((i % 7) + 1) as f64).collect();
        (mat, x)
    }

    #[test]
    fn scalar_spmv_correctness() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_scalar(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn spmv_simd_dispatch() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_simd(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    /// SIMD and scalar paths must agree on a row long enough to exercise the
    /// vector loop: 19 nnz = two full 8-lane chunks plus a 3-element tail.
    #[test]
    fn simd_matches_scalar_on_wide_row() {
        let (mat, x) = make_wide_row_f32(19);
        let mut y_simd = vec![0.0f32; 1];
        let mut y_ref = vec![0.0f32; 1];
        spmv_simd(&mat, &x, &mut y_simd);
        spmv_scalar(&mat, &x, &mut y_ref);
        // Accumulation order differs between paths; compare with a relative
        // tolerance rather than exact equality.
        assert!((y_simd[0] - y_ref[0]).abs() <= 1e-4 * y_ref[0].abs());
    }

    /// f64 analogue: 11 nnz = two full 4-lane chunks plus a 3-element tail.
    #[test]
    fn simd_matches_scalar_on_wide_row_f64() {
        let (mat, x) = make_wide_row_f64(11);
        let mut y_simd = vec![0.0f64; 1];
        let mut y_ref = vec![0.0f64; 1];
        spmv_simd_f64(&mat, &x, &mut y_simd);
        spmv_scalar_f64(&mat, &x, &mut y_ref);
        assert!((y_simd[0] - y_ref[0]).abs() <= 1e-12 * y_ref[0].abs());
    }

    #[test]
    fn spmv_simd_f64_correctness() {
        let (mat, x) = make_test_matrix_f64();
        let mut y = vec![0.0f64; 3];
        spmv_simd_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }

    #[test]
    fn scalar_spmv_f64_correctness() {
        let (mat, x) = make_test_matrix_f64();
        let mut y = vec![0.0f64; 3];
        spmv_scalar_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }
}