1use crate::matrix::sparse::{COOStorage, CSRStorage};
7#[cfg(feature = "simd")]
8use crate::simd_ops::{axpy_simd, dot_product_simd, matrix_vector_multiply_simd};
9use crate::types::Precision;
10use alloc::vec::Vec;
11use core::sync::atomic::{AtomicUsize, Ordering};
12
13#[cfg(feature = "std")]
14use std::time::Instant;
15
/// Sparse matrix wrapper that pairs CSR storage with usage counters.
pub struct OptimizedSparseMatrix {
    /// CSR representation holding the matrix entries.
    storage: CSRStorage,
    /// Shape recorded at construction as `(rows, cols)`.
    dimensions: (usize, usize),
    /// Counters updated on every `multiply_vector` call.
    performance_stats: PerformanceStats,
}
22
/// Atomic counters describing the matrix-vector work performed on a matrix.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Number of `multiply_vector` calls recorded so far.
    pub matvec_count: AtomicUsize,
    /// Running byte estimate accumulated by `multiply_vector`.
    pub bytes_processed: AtomicUsize,
}
29
30impl Clone for PerformanceStats {
31 fn clone(&self) -> Self {
32 Self {
33 matvec_count: AtomicUsize::new(self.matvec_count.load(Ordering::Relaxed)),
34 bytes_processed: AtomicUsize::new(self.bytes_processed.load(Ordering::Relaxed)),
35 }
36 }
37}
38
39impl OptimizedSparseMatrix {
40 pub fn from_triplets(
42 triplets: Vec<(usize, usize, Precision)>,
43 rows: usize,
44 cols: usize,
45 ) -> Result<Self, String> {
46 let coo = COOStorage::from_triplets(triplets)
47 .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
48 let storage = CSRStorage::from_coo(&coo, rows, cols)
49 .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;
50
51 Ok(Self {
52 storage,
53 dimensions: (rows, cols),
54 performance_stats: PerformanceStats::default(),
55 })
56 }
57
58 pub fn dimensions(&self) -> (usize, usize) {
60 self.dimensions
61 }
62
63 pub fn nnz(&self) -> usize {
65 self.storage.nnz()
66 }
67
68 pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
70 assert_eq!(x.len(), self.dimensions.1);
71 assert_eq!(y.len(), self.dimensions.0);
72
73 self.performance_stats
74 .matvec_count
75 .fetch_add(1, Ordering::Relaxed);
76 let bytes = (self.storage.values.len() * 8) + (x.len() * 8) + (y.len() * 8);
77 self.performance_stats
78 .bytes_processed
79 .fetch_add(bytes, Ordering::Relaxed);
80
81 #[cfg(feature = "simd")]
82 {
83 matrix_vector_multiply_simd(
84 &self.storage.values,
85 &self.storage.col_indices,
86 &self.storage.row_ptr,
87 x,
88 y,
89 );
90 }
91 #[cfg(not(feature = "simd"))]
92 {
93 self.storage.multiply_vector(x, y);
94 }
95 }
96
97 pub fn get_performance_stats(&self) -> (usize, usize) {
99 (
100 self.performance_stats.matvec_count.load(Ordering::Relaxed),
101 self.performance_stats
102 .bytes_processed
103 .load(Ordering::Relaxed),
104 )
105 }
106
107 pub fn reset_stats(&self) {
109 self.performance_stats
110 .matvec_count
111 .store(0, Ordering::Relaxed);
112 self.performance_stats
113 .bytes_processed
114 .store(0, Ordering::Relaxed);
115 }
116}
117
/// Tunable parameters for the conjugate-gradient solver.
#[derive(Debug, Clone)]
pub struct OptimizedSolverConfig {
    /// Upper bound on the number of iterations `solve` performs.
    pub max_iterations: usize,
    /// Residual-norm threshold; `solve` compares the squared residual against
    /// the square of this value.
    pub tolerance: Precision,
    /// Profiling switch. NOTE(review): not read by the solver code in this
    /// section — confirm whether other code consumes it.
    pub enable_profiling: bool,
}
128
129impl Default for OptimizedSolverConfig {
130 fn default() -> Self {
131 Self {
132 max_iterations: 1000,
133 tolerance: 1e-6,
134 enable_profiling: false,
135 }
136 }
137}
138
/// Outcome of a single solve.
#[derive(Debug, Clone)]
pub struct OptimizedSolverResult {
    /// Computed solution vector.
    pub solution: Vec<Precision>,
    /// Square root of the last squared residual tracked by the iteration.
    pub residual_norm: Precision,
    /// Number of completed iterations.
    pub iterations: usize,
    /// Whether the solver flagged the residual as within tolerance.
    pub converged: bool,
    /// Elapsed solve time in milliseconds (available with the `std` feature).
    #[cfg(feature = "std")]
    pub computation_time_ms: f64,
    /// Time field for builds without `std`, where `solve` takes no timing.
    #[cfg(not(feature = "std"))]
    pub computation_time_ms: u64,
    /// Operation counts and throughput estimates gathered during the solve.
    pub performance_stats: OptimizedSolverStats,
}
158
/// Operation counters and derived throughput figures for one solve.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverStats {
    /// Matrix-vector products performed by `solve`.
    pub matvec_count: usize,
    /// Calls to the solver's dot-product helper.
    pub dot_product_count: usize,
    /// Calls to the solver's axpy helper.
    pub axpy_count: usize,
    /// Floating-point operation estimate computed at the end of `solve`.
    pub total_flops: usize,
    /// Bandwidth estimate in GB/s; stays 0 when the measured time is zero.
    pub average_bandwidth_gbs: f64,
    /// Throughput estimate in GFLOP/s; stays 0 when the measured time is zero.
    pub average_gflops: f64,
}
175
/// Conjugate-gradient solver that records operation statistics per solve.
pub struct OptimizedConjugateGradientSolver {
    /// Iteration limit and tolerance used by `solve`.
    config: OptimizedSolverConfig,
    /// Statistics of the most recent solve; reset at the start of each one.
    stats: OptimizedSolverStats,
}
181
182impl OptimizedConjugateGradientSolver {
183 pub fn new(config: OptimizedSolverConfig) -> Self {
185 Self {
186 config,
187 stats: OptimizedSolverStats::default(),
188 }
189 }
190
191 pub fn solve(
193 &mut self,
194 matrix: &OptimizedSparseMatrix,
195 b: &[Precision],
196 ) -> Result<OptimizedSolverResult, String> {
197 let (rows, cols) = matrix.dimensions();
198 if rows != cols {
199 return Err("Matrix must be square".to_string());
200 }
201 if b.len() != rows {
202 return Err("Right-hand side vector length must match matrix size".to_string());
203 }
204
205 #[cfg(feature = "std")]
206 let start_time = Instant::now();
207
208 self.stats = OptimizedSolverStats::default();
210
211 let mut x = vec![0.0; rows];
213 let mut r = vec![0.0; rows];
214 let mut p = vec![0.0; rows];
215 let mut ap = vec![0.0; rows];
216
217 r.copy_from_slice(b);
219
220 let mut iteration = 0;
221 let tolerance_sq = self.config.tolerance * self.config.tolerance;
222 let mut converged = false;
223
224 let mut rsold = self.dot_product(&r, &r);
227 p.copy_from_slice(&r);
228
229 while iteration < self.config.max_iterations {
230 if rsold <= tolerance_sq {
231 converged = true;
232 break;
233 }
234
235 matrix.multiply_vector(&p, &mut ap);
237 self.stats.matvec_count += 1;
238
239 let pap = self.dot_product(&p, &ap);
241
242 if pap.abs() < 1e-16 {
243 break; }
245
246 let alpha = rsold / pap;
247
248 self.axpy(alpha, &p, &mut x);
250
251 self.axpy(-alpha, &ap, &mut r);
253
254 let rsnew = self.dot_product(&r, &r);
255
256 let beta = rsnew / rsold;
257
258 for (pi, &ri) in p.iter_mut().zip(r.iter()) {
260 *pi = ri + beta * *pi;
261 }
262
263 rsold = rsnew;
264 iteration += 1;
265 }
266
267 #[cfg(feature = "std")]
268 let computation_time_ms = start_time.elapsed().as_millis() as f64;
269 #[cfg(not(feature = "std"))]
270 let computation_time_ms = 0.0;
271
272 let final_residual_norm = rsold.sqrt();
274
275 self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 + iteration * rows * 6; if computation_time_ms > 0.0 {
279 let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
280 self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
281 self.stats.average_gflops =
282 (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
283 }
284
285 Ok(OptimizedSolverResult {
286 solution: x,
287 residual_norm: final_residual_norm,
288 iterations: iteration,
289 converged,
290 computation_time_ms,
291 performance_stats: self.stats.clone(),
292 })
293 }
294
295 fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
297 self.stats.dot_product_count += 1;
298 #[cfg(feature = "simd")]
299 {
300 dot_product_simd(x, y)
301 }
302 #[cfg(not(feature = "simd"))]
303 {
304 x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
305 }
306 }
307
308 fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
310 self.stats.axpy_count += 1;
311 #[cfg(feature = "simd")]
312 {
313 axpy_simd(alpha, x, y);
314 }
315 #[cfg(not(feature = "simd"))]
316 {
317 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
318 *yi += alpha * xi;
319 }
320 }
321 }
322
323 fn l2_norm(&self, x: &[Precision]) -> Precision {
325 x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
326 }
327
328 pub fn get_last_iteration_count(&self) -> usize {
330 self.stats.matvec_count
331 }
332
333 pub fn solve_with_callback<F>(
335 &mut self,
336 matrix: &OptimizedSparseMatrix,
337 b: &[Precision],
338 _chunk_size: usize,
339 mut _callback: F,
340 ) -> Result<OptimizedSolverResult, String>
341 where
342 F: FnMut(&OptimizedSolverStats),
343 {
344 self.solve(matrix, b)
347 }
348}
349
350impl OptimizedSolverResult {
351 pub fn data(&self) -> &[Precision] {
353 &self.solution
354 }
355}
356
/// Option flags for solver instrumentation.
///
/// NOTE(review): neither flag is read by the code in this section — confirm
/// where they are consumed.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Presumably enables performance tracking — TODO confirm.
    pub track_performance: bool,
    /// Presumably enables memory tracking — TODO confirm.
    pub track_memory: bool,
}
365
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Builds the 2x2 system [[4, 1], [1, 3]] shared by every test.
    fn create_test_matrix() -> OptimizedSparseMatrix {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        OptimizedSparseMatrix::from_triplets(entries, 2, 2).unwrap()
    }

    /// Runs a default-configured solver on `matrix` with right-hand side `rhs`.
    fn solve_with_defaults(
        matrix: &OptimizedSparseMatrix,
        rhs: &[Precision],
    ) -> OptimizedSolverResult {
        let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
        solver.solve(matrix, rhs).unwrap()
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let matrix = create_test_matrix();

        assert_eq!(matrix.dimensions(), (2, 2));
        assert_eq!(matrix.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let matrix = create_test_matrix();
        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];

        matrix.multiply_vector(&input, &mut output);

        // Row 0: 4*1 + 1*2 = 6; row 1: 1*1 + 3*2 = 7.
        assert_eq!(output, vec![6.0, 7.0]);
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let result = solve_with_defaults(&matrix, &b);

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Check the returned solution against the original system.
        let mut ax = vec![0.0; 2];
        matrix.multiply_vector(&result.solution, &mut ax);

        let d0 = ax[0] - b[0];
        let d1 = ax[1] - b[1];
        let error = (d0.powi(2) + d1.powi(2)).sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let stats = solve_with_defaults(&matrix, &b).performance_stats;

        assert!(stats.matvec_count > 0);
        assert!(stats.dot_product_count > 0);
        assert!(stats.total_flops > 0);
    }
}