use crate::error::{SparseError, SparseResult};
use scirs2_core::gpu::{GpuBackend, GpuContext, GpuDataType};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;

/// GPU-accelerated sparse matrix-vector multiplication (SpMV) engine for CSR
/// matrices. Selects the best available backend at construction time and
/// falls back to an optimized CPU implementation when no GPU is available.
pub struct GpuSpMV {
    #[allow(dead_code)]
    context: GpuContext,
    backend: GpuBackend,
}

impl GpuSpMV {
    /// Create a new SpMV engine using the best available backend.
    pub fn new() -> SparseResult<Self> {
        let (context, backend) = Self::initialize_best_backend()?;

        Ok(Self { context, backend })
    }

    /// Create a new SpMV engine on an explicitly chosen backend.
    pub fn with_backend(backend: GpuBackend) -> SparseResult<Self> {
        let context = GpuContext::new(backend).map_err(|e| {
            SparseError::ComputationError(format!("Failed to initialize GPU context: {e}"))
        })?;

        Ok(Self { context, backend })
    }

    /// Try backends in order of preference and return the first that initializes.
    fn initialize_best_backend() -> SparseResult<(GpuContext, GpuBackend)> {
        let backends_to_try = [
            GpuBackend::Cuda,
            GpuBackend::Metal,
            GpuBackend::OpenCL,
            GpuBackend::Cpu,
        ];

        for &backend in &backends_to_try {
            if let Ok(context) = GpuContext::new(backend) {
                return Ok((context, backend));
            }
        }

        Err(SparseError::ComputationError(
            "No GPU backend available".to_string(),
        ))
    }

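    /// Compute the sparse matrix-vector product `y = A * x` for a CSR matrix
    /// described by `indptr`, `indices`, and `data`, dispatching to the
    /// backend chosen at construction time.
    ///
    /// A minimal usage sketch (marked `ignore` because the enclosing crate
    /// path is not shown here):
    ///
    /// ```ignore
    /// let engine = GpuSpMV::new()?;
    /// // 2x2 CSR matrix [[1, 2], [0, 3]] times x = [1, 1]
    /// let y = engine.spmv(2, 2, &[0, 2, 3], &[0, 1, 1], &[1.0, 2.0, 3.0], &[1.0, 1.0])?;
    /// assert_eq!(y, vec![3.0, 3.0]);
    /// ```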
    #[allow(clippy::too_many_arguments)]
    pub fn spmv<T>(
        &self,
        rows: usize,
        cols: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<Vec<T>>
    where
        T: Float
            + SparseElement
            + Debug
            + Copy
            + Default
            + GpuDataType
            + Send
            + Sync
            + 'static
            + NumAssign
            + SimdUnifiedOps
            + std::iter::Sum,
    {
        self.validate_spmv_inputs(rows, cols, indptr, indices, data, x)?;

        match self.backend {
            GpuBackend::Cuda => self.spmv_cuda(rows, indptr, indices, data, x),
            GpuBackend::OpenCL => self.spmv_opencl(rows, indptr, indices, data, x),
            GpuBackend::Metal => self.spmv_metal(rows, indptr, indices, data, x),
            GpuBackend::Cpu => self.spmv_cpu_optimized(rows, indptr, indices, data, x),
            // No dedicated kernels for ROCm/WebGPU yet; use the CPU path.
            GpuBackend::Rocm | GpuBackend::Wgpu => {
                self.spmv_cpu_optimized(rows, indptr, indices, data, x)
            }
        }
    }

    /// Validate CSR inputs before dispatching to a backend kernel.
    fn validate_spmv_inputs<T>(
        &self,
        rows: usize,
        cols: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<()>
    where
        T: Float + SparseElement + Debug,
    {
        if indptr.len() != rows + 1 {
            return Err(SparseError::InvalidFormat(format!(
                "indptr length {} does not match rows + 1 = {}",
                indptr.len(),
                rows + 1
            )));
        }

        if indices.len() != data.len() {
            return Err(SparseError::InvalidFormat(format!(
                "indices length {} does not match data length {}",
                indices.len(),
                data.len()
            )));
        }

        if x.len() != cols {
            return Err(SparseError::InvalidFormat(format!(
                "x length {} does not match cols {}",
                x.len(),
                cols
            )));
        }

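        // Added guard (not in the original checks): every indptr offset must
        // stay within nnz, otherwise a malformed indptr would index past the
        // end of `data` in the kernels below.
        if indptr.iter().any(|&p| p > data.len()) {
            return Err(SparseError::InvalidFormat(format!(
                "indptr contains an offset beyond nnz = {}",
                data.len()
            )));
        }
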
        for &idx in indices {
            if idx >= cols {
                return Err(SparseError::InvalidFormat(format!(
                    "Column index {idx} exceeds cols {cols}"
                )));
            }
        }

        Ok(())
    }

    /// CUDA path: build a `CsrArray` and delegate to the shared GPU handler,
    /// which manages device buffers and kernel launches internally.
    fn spmv_cuda<T>(
        &self,
        rows: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<Vec<T>>
    where
        T: Float
            + SparseElement
            + Debug
            + Copy
            + Default
            + GpuDataType
            + Send
            + Sync
            + 'static
            + NumAssign
            + SimdUnifiedOps
            + std::iter::Sum,
    {
        #[cfg(feature = "gpu")]
        {
            use crate::csr_array::CsrArray;
            use crate::gpu::GpuSpMatVec;

            let csr_matrix = CsrArray::new(
                data.to_vec().into(),
                indices.to_vec().into(),
                indptr.to_vec().into(),
                (rows, x.len()),
            )?;

            let gpu_handler = GpuSpMatVec::with_backend(self.backend)?;
            let result = gpu_handler.spmv(
                &csr_matrix,
                &scirs2_core::ndarray::ArrayView1::from(x),
                None,
            )?;
            Ok(result.to_vec())
        }

        #[cfg(not(feature = "gpu"))]
        {
            self.spmv_cpu_optimized(rows, indptr, indices, data, x)
        }
    }

    /// OpenCL path: identical delegation to the shared GPU handler as the
    /// CUDA path.
    fn spmv_opencl<T>(
        &self,
        rows: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<Vec<T>>
    where
        T: Float
            + SparseElement
            + Debug
            + Copy
            + Default
            + GpuDataType
            + Send
            + Sync
            + 'static
            + NumAssign
            + SimdUnifiedOps
            + std::iter::Sum,
    {
        #[cfg(feature = "gpu")]
        {
            use crate::csr_array::CsrArray;
            use crate::gpu::GpuSpMatVec;

            let csr_matrix = CsrArray::new(
                data.to_vec().into(),
                indices.to_vec().into(),
                indptr.to_vec().into(),
                (rows, x.len()),
            )?;

            let gpu_handler = GpuSpMatVec::with_backend(self.backend)?;
            let result = gpu_handler.spmv(
                &csr_matrix,
                &scirs2_core::ndarray::ArrayView1::from(x),
                None,
            )?;
            Ok(result.to_vec())
        }

        #[cfg(not(feature = "gpu"))]
        {
            self.spmv_cpu_optimized(rows, indptr, indices, data, x)
        }
    }

    /// Metal path: identical delegation to the shared GPU handler as the
    /// CUDA path.
    fn spmv_metal<T>(
        &self,
        rows: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<Vec<T>>
    where
        T: Float
            + SparseElement
            + Debug
            + Copy
            + Default
            + GpuDataType
            + Send
            + Sync
            + 'static
            + NumAssign
            + SimdUnifiedOps
            + std::iter::Sum,
    {
        #[cfg(feature = "gpu")]
        {
            use crate::csr_array::CsrArray;
            use crate::gpu::GpuSpMatVec;

            let csr_matrix = CsrArray::new(
                data.to_vec().into(),
                indices.to_vec().into(),
                indptr.to_vec().into(),
                (rows, x.len()),
            )?;

            let gpu_handler = GpuSpMatVec::with_backend(self.backend)?;
            let result = gpu_handler.spmv(
                &csr_matrix,
                &scirs2_core::ndarray::ArrayView1::from(x),
                None,
            )?;
            Ok(result.to_vec())
        }

        #[cfg(not(feature = "gpu"))]
        {
            self.spmv_cpu_optimized(rows, indptr, indices, data, x)
        }
    }

    /// CPU fallback: CSR SpMV using the parallel implementation when the
    /// `parallel` feature is enabled, otherwise a straightforward serial loop.
    fn spmv_cpu_optimized<T>(
        &self,
        rows: usize,
        indptr: &[usize],
        indices: &[usize],
        data: &[T],
        x: &[T],
    ) -> SparseResult<Vec<T>>
    where
        T: Float
            + SparseElement
            + Debug
            + Copy
            + Default
            + Send
            + Sync
            + NumAssign
            + SimdUnifiedOps,
    {
        let mut y = vec![T::sparse_zero(); rows];

        #[cfg(feature = "parallel")]
        {
            use crate::parallel_vector_ops::parallel_sparse_matvec_csr;
            parallel_sparse_matvec_csr(&mut y, rows, indptr, indices, data, x, None);
        }

        #[cfg(not(feature = "parallel"))]
        {
            for row in 0..rows {
                let mut sum = T::sparse_zero();
                let start = indptr[row];
                let end = indptr[row + 1];

                for idx in start..end {
                    let col = indices[idx];
                    sum += data[idx] * x[col];
                }
                y[row] = sum;
            }
        }

        Ok(y)
    }

    #[allow(dead_code)]
    fn get_cuda_spmv_kernel_source(&self) -> String {
        r#"
        extern "C" __global__ void spmv_csr_kernel(
            int rows,
            const int* __restrict__ indptr,
            const int* __restrict__ indices,
            const float* __restrict__ data,
            const float* __restrict__ x,
            float* __restrict__ y
        ) {
            int row = blockIdx.x * blockDim.x + threadIdx.x;
            if (row >= rows) return;

            float sum = 0.0f;
            int start = indptr[row];
            int end = indptr[row + 1];

            // One thread per row; data/indices are read contiguously.
            for (int j = start; j < end; j++) {
                sum += data[j] * x[indices[j]];
            }

            y[row] = sum;
        }
        "#
        .to_string()
    }

404 #[allow(dead_code)]
406 fn get_opencl_spmv_kernel_source(&self) -> String {
407 r#"
408 _kernel void spmv_csr_kernel(
409 const int rowsglobal const int* restrict indptr_global const int* restrict indices_global const float* restrict data_global const float* restrict x_global float* restrict y
410 ) {
411 int row = get_global_id(0);
412 if (row >= rows) return;
413
414 float sum = 0.0f;
415 int start = indptr[row];
416 int end = indptr[row + 1];
417
418 // Vectorized loop with memory coalescing
419 for (int j = start; j < end; j++) {
420 sum += data[j] * x[indices[j]];
421 }
422
423 y[row] = sum;
424 }
425 "#
426 .to_string()
427 }
428
    #[allow(dead_code)]
    fn get_metal_spmv_kernel_source(&self) -> String {
        r#"
        #include <metal_stdlib>
        using namespace metal;

        kernel void spmv_csr_kernel(
            constant int& rows [[buffer(0)]],
            device const int* indptr [[buffer(1)]],
            device const int* indices [[buffer(2)]],
            device const float* data [[buffer(3)]],
            device const float* x [[buffer(4)]],
            device float* y [[buffer(5)]],
            uint row [[thread_position_in_grid]]
        ) {
            if (row >= uint(rows)) return;

            float sum = 0.0f;
            int start = indptr[row];
            int end = indptr[row + 1];

            // One thread per row; matrix data lives in the device address
            // space (the constant address space has tight size limits).
            for (int j = start; j < end; j++) {
                sum += data[j] * x[indices[j]];
            }

            y[row] = sum;
        }
        "#
        .to_string()
    }

    /// Return the active backend together with a human-readable name.
    pub fn backend_info(&self) -> (GpuBackend, String) {
        let backend_name = match self.backend {
            GpuBackend::Cuda => "NVIDIA CUDA",
            GpuBackend::OpenCL => "OpenCL",
            GpuBackend::Metal => "Apple Metal",
            GpuBackend::Cpu => "CPU Fallback",
            GpuBackend::Rocm => "AMD ROCm",
            GpuBackend::Wgpu => "WebGPU",
        };

        (self.backend, backend_name.to_string())
    }
}

impl Default for GpuSpMV {
    fn default() -> Self {
        Self::new().unwrap_or_else(|_| Self {
            context: GpuContext::new(GpuBackend::Cpu)
                .expect("CPU backend should always be available"),
            backend: GpuBackend::Cpu,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_spmv_creation() {
        let gpu_spmv = GpuSpMV::new();
        assert!(
            gpu_spmv.is_ok(),
            "Should be able to create GPU SpMV instance"
        );
    }

    #[test]
    fn test_cpu_fallback_spmv() {
        let gpu_spmv = GpuSpMV::with_backend(GpuBackend::Cpu).unwrap();

        // 2x2 CSR matrix [[1, 2], [0, 3]] times x = [1, 1].
        let indptr = vec![0, 2, 3];
        let indices = vec![0, 1, 1];
        let data = vec![1.0, 2.0, 3.0];
        let x = vec![1.0, 1.0];

        let result = gpu_spmv.spmv(2, 2, &indptr, &indices, &data, &x).unwrap();
        assert_eq!(result, vec![3.0, 3.0]);
    }

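    // Added test (sketch): malformed inputs should be rejected by
    // `validate_spmv_inputs` rather than panicking inside a kernel.
    #[test]
    fn test_spmv_input_validation() {
        let gpu_spmv = GpuSpMV::with_backend(GpuBackend::Cpu).unwrap();

        let indptr = vec![0, 2, 3];
        let indices = vec![0, 1, 1];
        let data = vec![1.0, 2.0, 3.0];

        // x has the wrong length for a 2-column matrix.
        assert!(gpu_spmv
            .spmv(2, 2, &indptr, &indices, &data, &[1.0])
            .is_err());
        // indptr is too short for rows = 3.
        assert!(gpu_spmv
            .spmv(3, 2, &indptr, &indices, &data, &[1.0, 1.0])
            .is_err());
        // A column index (2) is out of range for cols = 2.
        assert!(gpu_spmv
            .spmv(2, 2, &indptr, &[0, 2, 1], &data, &[1.0, 1.0])
            .is_err());
    }
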
    #[test]
    fn test_backend_info() {
        let gpu_spmv = GpuSpMV::default();
        let (_backend, name) = gpu_spmv.backend_info();
        assert!(!name.is_empty(), "Backend name should not be empty");
    }
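
    // Added test (sketch): a rectangular matrix exercises the CPU path and
    // checks that `cols`/`x` are handled independently of `rows`.
    #[test]
    fn test_cpu_fallback_spmv_rectangular() {
        let gpu_spmv = GpuSpMV::with_backend(GpuBackend::Cpu).unwrap();

        // 2x3 CSR matrix [[1, 0, 2], [0, 3, 0]] times x = [1, 2, 3].
        let indptr = vec![0, 2, 3];
        let indices = vec![0, 2, 1];
        let data = vec![1.0, 2.0, 3.0];
        let x = vec![1.0, 2.0, 3.0];

        let result = gpu_spmv.spmv(2, 3, &indptr, &indices, &data, &x).unwrap();
        // Row 0: 1*1 + 2*3 = 7; row 1: 3*2 = 6.
        assert_eq!(result, vec![7.0, 6.0]);
    }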
}