1use crate::types::CsrMatrix;
8
9pub fn spmv_simd(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
13 assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
14 assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
15
16 #[cfg(all(feature = "simd", target_arch = "x86_64"))]
17 {
18 if is_x86_feature_detected!("avx2") {
19 unsafe {
21 spmv_avx2(matrix, x, y);
22 }
23 return;
24 }
25 }
26
27 #[cfg(target_arch = "aarch64")]
28 {
29 unsafe {
30 spmv_neon_f32(matrix, x, y);
31 }
32 return;
33 }
34
35 #[allow(unreachable_code)]
36 spmv_scalar(matrix, x, y);
37}
38
39pub fn spmv_scalar(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
41 for i in 0..matrix.rows {
42 let start = matrix.row_ptr[i];
43 let end = matrix.row_ptr[i + 1];
44 let mut sum = 0.0f32;
45 for idx in start..end {
46 let col = matrix.col_indices[idx];
47 sum += matrix.values[idx] * x[col];
48 }
49 y[i] = sum;
50 }
51}
52
/// AVX2 kernel for `y = A * x` on an `f32` CSR matrix.
///
/// Each row is processed in blocks of 8 nonzeros: the 8 values are loaded
/// directly, the 8 matching `x` entries are gathered through a small stack
/// buffer, and the products are accumulated lane-wise. The leftover
/// (fewer than 8) nonzeros of the row are added with scalar arithmetic after
/// the vector accumulator has been reduced.
///
/// All slice accesses are bounds-checked: the row range is sliced out of
/// `values` / `col_indices` with checked indexing and every gathered column
/// index is checked against `x`. A structurally inconsistent matrix therefore
/// panics (as the scalar implementation does) instead of reading or writing
/// out of bounds.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `row_ptr` has fewer than `matrix.rows + 1` entries, if a row
/// range is decreasing or exceeds `values` / `col_indices`, if a column index
/// is `>= x.len()`, or if `y` has fewer than `matrix.rows` elements.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];

        // Checked slicing validates the row range once per row.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let val_chunks = row_vals.chunks_exact(8);
        let col_chunks = row_cols.chunks_exact(8);
        let tail_vals = val_chunks.remainder();
        let tail_cols = col_chunks.remainder();

        let mut accum = _mm256_setzero_ps();
        for (vals8, cols8) in val_chunks.zip(col_chunks) {
            // SAFETY: `chunks_exact(8)` guarantees `vals8` holds exactly 8
            // contiguous f32 values; the load is unaligned-tolerant.
            let vals = _mm256_loadu_ps(vals8.as_ptr());

            // Gather the x entries through a stack buffer, bounds-checked.
            let mut x_buf = [0.0f32; 8];
            for (slot, &col) in x_buf.iter_mut().zip(cols8) {
                *slot = x[col];
            }
            // SAFETY: `x_buf` is a local array of exactly 8 f32 values.
            let x_vec = _mm256_loadu_ps(x_buf.as_ptr());

            accum = _mm256_add_ps(accum, _mm256_mul_ps(vals, x_vec));
        }

        let mut sum = horizontal_sum_f32x8(accum);

        // Scalar tail: the row's nonzeros that did not fill a block of 8.
        for (&val, &col) in tail_vals.iter().zip(tail_cols) {
            sum += val * x[col];
        }

        y[i] = sum;
    }
}
114
/// Adds the eight `f32` lanes of `v` and returns the total.
///
/// The reduction first adds the upper 128-bit half onto the lower half, then
/// adds each odd lane onto its even neighbour, and finally adds lane 2 onto
/// lane 0.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (the function is
/// compiled with that target feature enabled).
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f32x8(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // 8 lanes -> 4 lanes: low half + high half.
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));

    // Lanes 0 and 2 now hold q0 + q1 and q2 + q3 respectively.
    let pairs = _mm_add_ps(quad, _mm_movehdup_ps(quad));

    // Bring lane 2 down to lane 0 and add; only lane 0 of the result matters.
    let total = _mm_add_ss(pairs, _mm_movehl_ps(pairs, pairs));

    _mm_cvtss_f32(total)
}
131
132pub fn spmv_simd_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
136 assert_eq!(x.len(), matrix.cols, "x length must equal matrix.cols");
137 assert_eq!(y.len(), matrix.rows, "y length must equal matrix.rows");
138
139 #[cfg(all(feature = "simd", target_arch = "x86_64"))]
140 {
141 if is_x86_feature_detected!("avx2") {
142 unsafe {
143 spmv_avx2_f64(matrix, x, y);
144 }
145 return;
146 }
147 }
148
149 #[cfg(target_arch = "aarch64")]
150 {
151 unsafe {
152 spmv_neon_f64(matrix, x, y);
153 }
154 return;
155 }
156
157 #[allow(unreachable_code)]
158 spmv_scalar_f64(matrix, x, y);
159}
160
161pub fn spmv_scalar_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
163 for i in 0..matrix.rows {
164 let start = matrix.row_ptr[i];
165 let end = matrix.row_ptr[i + 1];
166 let mut sum = 0.0f64;
167 for idx in start..end {
168 let col = matrix.col_indices[idx];
169 sum += matrix.values[idx] * x[col];
170 }
171 y[i] = sum;
172 }
173}
174
/// AVX2 kernel for `y = A * x` on an `f64` CSR matrix.
///
/// Each row is processed in blocks of 4 nonzeros: the 4 values are loaded
/// directly, the 4 matching `x` entries are gathered through a small stack
/// buffer, and the products are accumulated lane-wise. The leftover
/// (fewer than 4) nonzeros of the row are added with scalar arithmetic after
/// the vector accumulator has been reduced.
///
/// All slice accesses are bounds-checked: the row range is sliced out of
/// `values` / `col_indices` with checked indexing and every gathered column
/// index is checked against `x`. A structurally inconsistent matrix therefore
/// panics (as the scalar implementation does) instead of reading or writing
/// out of bounds.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
///
/// # Panics
///
/// Panics if `row_ptr` has fewer than `matrix.rows + 1` entries, if a row
/// range is decreasing or exceeds `values` / `col_indices`, if a column index
/// is `>= x.len()`, or if `y` has fewer than `matrix.rows` elements.
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn spmv_avx2_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::x86_64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];

        // Checked slicing validates the row range once per row.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let val_chunks = row_vals.chunks_exact(4);
        let col_chunks = row_cols.chunks_exact(4);
        let tail_vals = val_chunks.remainder();
        let tail_cols = col_chunks.remainder();

        let mut accum = _mm256_setzero_pd();
        for (vals4, cols4) in val_chunks.zip(col_chunks) {
            // SAFETY: `chunks_exact(4)` guarantees `vals4` holds exactly 4
            // contiguous f64 values; the load is unaligned-tolerant.
            let vals = _mm256_loadu_pd(vals4.as_ptr());

            // Gather the x entries through a stack buffer, bounds-checked.
            let mut x_buf = [0.0f64; 4];
            for (slot, &col) in x_buf.iter_mut().zip(cols4) {
                *slot = x[col];
            }
            // SAFETY: `x_buf` is a local array of exactly 4 f64 values.
            let x_vec = _mm256_loadu_pd(x_buf.as_ptr());

            accum = _mm256_add_pd(accum, _mm256_mul_pd(vals, x_vec));
        }

        let mut sum = horizontal_sum_f64x4(accum);

        // Scalar tail: the row's nonzeros that did not fill a block of 4.
        for (&val, &col) in tail_vals.iter().zip(tail_cols) {
            sum += val * x[col];
        }

        y[i] = sum;
    }
}
213
/// Adds the four `f64` lanes of `v` and returns the total.
///
/// The reduction first adds the upper 128-bit half onto the lower half, then
/// adds the remaining upper lane onto the lower lane.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 (the function is
/// compiled with that target feature enabled).
#[cfg(all(feature = "simd", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn horizontal_sum_f64x4(v: std::arch::x86_64::__m256d) -> f64 {
    use std::arch::x86_64::*;

    // 4 lanes -> 2 lanes: low half + high half.
    let pair = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));

    // Broadcast the upper lane and add it onto lane 0.
    let total = _mm_add_sd(pair, _mm_unpackhi_pd(pair, pair));

    _mm_cvtsd_f64(total)
}
225
/// NEON kernel for `y = A * x` on an `f32` CSR matrix.
///
/// Each row is processed in three stages:
/// 1. blocks of 8 nonzeros, as two 4-lane fused multiply-adds into two
///    independent accumulators;
/// 2. at most one further block of 4 nonzeros, accumulated into the first
///    accumulator;
/// 3. the remaining (fewer than 4) nonzeros, added with scalar arithmetic
///    after the two accumulators have been combined and reduced.
///
/// All slice accesses are bounds-checked: the row range is sliced out of
/// `values` / `col_indices` with checked indexing and every gathered column
/// index is checked against `x`. A structurally inconsistent matrix therefore
/// panics (as the scalar implementation does) instead of reading or writing
/// out of bounds.
///
/// # Safety
///
/// The function calls NEON intrinsics, so NEON instructions must be available
/// on the executing CPU.
/// NOTE(review): presumably always the case for the aarch64 targets this is
/// built for — confirm.
///
/// # Panics
///
/// Panics if `row_ptr` has fewer than `matrix.rows + 1` entries, if a row
/// range is decreasing or exceeds `values` / `col_indices`, if a column index
/// is `>= x.len()`, or if `y` has fewer than `matrix.rows` elements.
#[cfg(target_arch = "aarch64")]
unsafe fn spmv_neon_f32(matrix: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
    use std::arch::aarch64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];

        // Checked slicing validates the row range once per row.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let val_chunks = row_vals.chunks_exact(8);
        let col_chunks = row_cols.chunks_exact(8);
        let mut rest_vals = val_chunks.remainder();
        let mut rest_cols = col_chunks.remainder();

        let mut acc0 = vdupq_n_f32(0.0);
        let mut acc1 = vdupq_n_f32(0.0);

        for (vals8, cols8) in val_chunks.zip(col_chunks) {
            // SAFETY: `chunks_exact(8)` guarantees `vals8` holds exactly 8
            // contiguous f32 values, so both 4-lane loads are in bounds.
            let v0 = vld1q_f32(vals8.as_ptr());
            let v1 = vld1q_f32(vals8.as_ptr().add(4));

            // Gather the x entries through stack buffers, bounds-checked.
            let mut xbuf0 = [0.0f32; 4];
            let mut xbuf1 = [0.0f32; 4];
            for (slot, &col) in xbuf0.iter_mut().zip(&cols8[..4]) {
                *slot = x[col];
            }
            for (slot, &col) in xbuf1.iter_mut().zip(&cols8[4..]) {
                *slot = x[col];
            }
            // SAFETY: both buffers are local arrays of exactly 4 f32 values.
            let x0 = vld1q_f32(xbuf0.as_ptr());
            let x1 = vld1q_f32(xbuf1.as_ptr());

            acc0 = vfmaq_f32(acc0, v0, x0);
            acc1 = vfmaq_f32(acc1, v1, x1);
        }

        // The remainder has fewer than 8 entries, so it contains at most one
        // further full block of 4.
        if rest_vals.len() >= 4 {
            let (vals4, vals_tail) = rest_vals.split_at(4);
            let (cols4, cols_tail) = rest_cols.split_at(4);

            // SAFETY: `split_at(4)` guarantees `vals4` holds exactly 4 values.
            let v0 = vld1q_f32(vals4.as_ptr());
            let mut xbuf = [0.0f32; 4];
            for (slot, &col) in xbuf.iter_mut().zip(cols4) {
                *slot = x[col];
            }
            // SAFETY: `xbuf` is a local array of exactly 4 f32 values.
            let x0 = vld1q_f32(xbuf.as_ptr());
            acc0 = vfmaq_f32(acc0, v0, x0);

            rest_vals = vals_tail;
            rest_cols = cols_tail;
        }

        let combined = vaddq_f32(acc0, acc1);
        let mut sum = vaddvq_f32(combined);

        // Scalar tail: the row's last (fewer than 4) nonzeros.
        for (&val, &col) in rest_vals.iter().zip(rest_cols) {
            sum += val * x[col];
        }

        y[i] = sum;
    }
}
301
/// NEON kernel for `y = A * x` on an `f64` CSR matrix.
///
/// Each row is processed in blocks of 4 nonzeros, as two 2-lane fused
/// multiply-adds into two independent accumulators. The leftover (fewer than
/// 4) nonzeros of the row are added with scalar arithmetic after the two
/// accumulators have been combined and reduced.
///
/// All slice accesses are bounds-checked: the row range is sliced out of
/// `values` / `col_indices` with checked indexing and every gathered column
/// index is checked against `x`. A structurally inconsistent matrix therefore
/// panics (as the scalar implementation does) instead of reading or writing
/// out of bounds.
///
/// # Safety
///
/// The function calls NEON intrinsics, so NEON instructions must be available
/// on the executing CPU.
/// NOTE(review): presumably always the case for the aarch64 targets this is
/// built for — confirm.
///
/// # Panics
///
/// Panics if `row_ptr` has fewer than `matrix.rows + 1` entries, if a row
/// range is decreasing or exceeds `values` / `col_indices`, if a column index
/// is `>= x.len()`, or if `y` has fewer than `matrix.rows` elements.
#[cfg(target_arch = "aarch64")]
unsafe fn spmv_neon_f64(matrix: &CsrMatrix<f64>, x: &[f64], y: &mut [f64]) {
    use std::arch::aarch64::*;

    for i in 0..matrix.rows {
        let start = matrix.row_ptr[i];
        let end = matrix.row_ptr[i + 1];

        // Checked slicing validates the row range once per row.
        let row_vals = &matrix.values[start..end];
        let row_cols = &matrix.col_indices[start..end];

        let val_chunks = row_vals.chunks_exact(4);
        let col_chunks = row_cols.chunks_exact(4);
        let tail_vals = val_chunks.remainder();
        let tail_cols = col_chunks.remainder();

        let mut acc0 = vdupq_n_f64(0.0);
        let mut acc1 = vdupq_n_f64(0.0);

        for (vals4, cols4) in val_chunks.zip(col_chunks) {
            // SAFETY: `chunks_exact(4)` guarantees `vals4` holds exactly 4
            // contiguous f64 values, so both 2-lane loads are in bounds.
            let v0 = vld1q_f64(vals4.as_ptr());
            let v1 = vld1q_f64(vals4.as_ptr().add(2));

            // Gather the x entries through stack buffers, bounds-checked.
            let mut xbuf0 = [0.0f64; 2];
            let mut xbuf1 = [0.0f64; 2];
            for (slot, &col) in xbuf0.iter_mut().zip(&cols4[..2]) {
                *slot = x[col];
            }
            for (slot, &col) in xbuf1.iter_mut().zip(&cols4[2..]) {
                *slot = x[col];
            }
            // SAFETY: both buffers are local arrays of exactly 2 f64 values.
            let x0 = vld1q_f64(xbuf0.as_ptr());
            let x1 = vld1q_f64(xbuf1.as_ptr());

            acc0 = vfmaq_f64(acc0, v0, x0);
            acc1 = vfmaq_f64(acc1, v1, x1);
        }

        let combined = vaddq_f64(acc0, acc1);
        let mut sum = vgetq_lane_f64(combined, 0) + vgetq_lane_f64(combined, 1);

        // Scalar tail: the row's nonzeros that did not fill a block of 4.
        for (&val, &col) in tail_vals.iter().zip(tail_cols) {
            sum += val * x[col];
        }

        y[i] = sum;
    }
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::CsrMatrix;

    /// Small 3x3 fixture together with an input vector:
    ///
    /// ```text
    /// [2 0 1]   [1]   [ 5]
    /// [0 3 0] * [2] = [ 6]
    /// [1 0 4]   [3]   [13]
    /// ```
    fn make_test_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
        let mat = CsrMatrix {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        (mat, x)
    }

    /// Fixture whose rows hold 0, 1, 3, 4, 5, 7, 8, 9, 12, 13, 16 and 17
    /// nonzeros, so that the 8-wide and 4-wide vector loops, the NEON
    /// mid-block and every scalar-tail length are exercised. The 3x3 fixture
    /// never reaches them because its rows hold at most 2 nonzeros.
    ///
    /// All values are small integers, so every product and partial sum is
    /// exactly representable regardless of summation order or FMA use.
    fn make_varied_matrix() -> (CsrMatrix<f32>, Vec<f32>) {
        const COLS: usize = 18;
        let row_lens = [0usize, 1, 3, 4, 5, 7, 8, 9, 12, 13, 16, 17];

        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_ptr = vec![0usize];
        for (r, &len) in row_lens.iter().enumerate() {
            for k in 0..len {
                // Strictly increasing columns, shifted by one on odd rows.
                col_indices.push(k + r % 2);
                values.push(((r + k) % 5) as f32 + 1.0);
            }
            row_ptr.push(values.len());
        }

        let mat = CsrMatrix {
            values,
            col_indices,
            row_ptr,
            rows: row_lens.len(),
            cols: COLS,
        };
        let x = (0..COLS).map(|c| (c % 7) as f32 - 3.0).collect();
        (mat, x)
    }

    #[test]
    fn scalar_spmv_correctness() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_scalar(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn spmv_simd_dispatch() {
        let (mat, x) = make_test_matrix();
        let mut y = vec![0.0f32; 3];
        spmv_simd(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-6);
        assert!((y[1] - 6.0).abs() < 1e-6);
        assert!((y[2] - 13.0).abs() < 1e-6);
    }

    #[test]
    fn spmv_simd_matches_scalar_on_long_rows() {
        let (mat, x) = make_varied_matrix();
        let mut expected = vec![0.0f32; mat.rows];
        // NaN-filled so a row the kernel fails to write is detected.
        let mut actual = vec![f32::NAN; mat.rows];
        spmv_scalar(&mat, &x, &mut expected);
        spmv_simd(&mat, &x, &mut actual);
        for (row, (a, e)) in actual.iter().zip(&expected).enumerate() {
            assert!(
                (a - e).abs() < 1e-6,
                "row {}: simd = {}, scalar = {}",
                row,
                a,
                e
            );
        }
    }

    #[test]
    fn spmv_simd_f64_correctness() {
        let mat = CsrMatrix::<f64> {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0f64; 3];
        spmv_simd_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }

    #[test]
    fn spmv_simd_f64_matches_scalar_on_long_rows() {
        let (mat32, x32) = make_varied_matrix();
        let mat = CsrMatrix::<f64> {
            values: mat32.values.iter().map(|&v| f64::from(v)).collect(),
            col_indices: mat32.col_indices.clone(),
            row_ptr: mat32.row_ptr.clone(),
            rows: mat32.rows,
            cols: mat32.cols,
        };
        let x: Vec<f64> = x32.iter().map(|&v| f64::from(v)).collect();

        let mut expected = vec![0.0f64; mat.rows];
        // NaN-filled so a row the kernel fails to write is detected.
        let mut actual = vec![f64::NAN; mat.rows];
        spmv_scalar_f64(&mat, &x, &mut expected);
        spmv_simd_f64(&mat, &x, &mut actual);
        for (row, (a, e)) in actual.iter().zip(&expected).enumerate() {
            assert!(
                (a - e).abs() < 1e-10,
                "row {}: simd = {}, scalar = {}",
                row,
                a,
                e
            );
        }
    }

    #[test]
    fn scalar_spmv_f64_correctness() {
        let mat = CsrMatrix::<f64> {
            values: vec![2.0, 1.0, 3.0, 1.0, 4.0],
            col_indices: vec![0, 2, 1, 0, 2],
            row_ptr: vec![0, 2, 3, 5],
            rows: 3,
            cols: 3,
        };
        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0f64; 3];
        spmv_scalar_f64(&mat, &x, &mut y);
        assert!((y[0] - 5.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 13.0).abs() < 1e-10);
    }
}