use crate::types::Precision;
use crate::matrix::sparse::{COOStorage, CSRStorage};
#[cfg(feature = "simd")]
use crate::simd_ops::{axpy_simd, dot_product_simd, matrix_vector_multiply_simd};
// `vec!`, `format!`, `String`, and `ToString` come from `alloc` in no_std builds.
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use core::sync::atomic::{AtomicUsize, Ordering};

#[cfg(feature = "std")]
use std::time::Instant;

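/// CSR sparse matrix with throughput accounting and optional SIMD kernels.
///
/// Illustrative usage sketch (assumes `Precision = f64`; import paths may
/// differ in your crate layout):
///
/// ```ignore
/// let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
/// let a = OptimizedSparseMatrix::from_triplets(triplets, 2, 2).unwrap();
/// let mut y = vec![0.0; 2];
/// a.multiply_vector(&[1.0, 2.0], &mut y); // y == [6.0, 7.0]
/// ```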
pub struct OptimizedSparseMatrix {
    storage: CSRStorage,
    dimensions: (usize, usize),
    performance_stats: PerformanceStats,
}

/// Counters are atomic so `multiply_vector` can record statistics through `&self`.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    pub matvec_count: AtomicUsize,
    pub bytes_processed: AtomicUsize,
}

impl Clone for PerformanceStats {
    fn clone(&self) -> Self {
        // `AtomicUsize` is not `Clone`; snapshot the current counter values instead.
        Self {
            matvec_count: AtomicUsize::new(self.matvec_count.load(Ordering::Relaxed)),
            bytes_processed: AtomicUsize::new(self.bytes_processed.load(Ordering::Relaxed)),
        }
    }
}

impl OptimizedSparseMatrix {
    /// Builds a CSR matrix from `(row, col, value)` triplets via an
    /// intermediate COO representation.
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
    ) -> Result<Self, String> {
        let coo = COOStorage::from_triplets(triplets)
            .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
        let storage = CSRStorage::from_coo(&coo, rows, cols)
            .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;

        Ok(Self {
            storage,
            dimensions: (rows, cols),
            performance_stats: PerformanceStats::default(),
        })
    }

    pub fn dimensions(&self) -> (usize, usize) {
        self.dimensions
    }

    pub fn nnz(&self) -> usize {
        self.storage.nnz()
    }

    /// Computes `y = A * x`, recording throughput statistics as a side effect.
    ///
    /// Panics if `x.len() != cols` or `y.len() != rows`.
    pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
        assert_eq!(x.len(), self.dimensions.1);
        assert_eq!(y.len(), self.dimensions.0);

        self.performance_stats.matvec_count.fetch_add(1, Ordering::Relaxed);
        // Approximate traffic: matrix values plus both vectors (index arrays excluded).
        let elem = core::mem::size_of::<Precision>();
        let bytes = (self.storage.values.len() + x.len() + y.len()) * elem;
        self.performance_stats.bytes_processed.fetch_add(bytes, Ordering::Relaxed);

        #[cfg(feature = "simd")]
        {
            matrix_vector_multiply_simd(
                &self.storage.values,
                &self.storage.col_indices,
                &self.storage.row_ptr,
                x,
                y,
            );
        }
        #[cfg(not(feature = "simd"))]
        {
            self.storage.multiply_vector(x, y);
        }
    }

    pub fn get_performance_stats(&self) -> (usize, usize) {
        (
            self.performance_stats.matvec_count.load(Ordering::Relaxed),
            self.performance_stats.bytes_processed.load(Ordering::Relaxed),
        )
    }

    pub fn reset_stats(&self) {
        self.performance_stats.matvec_count.store(0, Ordering::Relaxed);
        self.performance_stats.bytes_processed.store(0, Ordering::Relaxed);
    }
}

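/// Configuration for [`OptimizedConjugateGradientSolver`].
///
/// Illustrative usage sketch: override individual fields via struct update syntax.
///
/// ```ignore
/// let config = OptimizedSolverConfig {
///     tolerance: 1e-10,
///     ..OptimizedSolverConfig::default()
/// };
/// ```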
#[derive(Debug, Clone)]
pub struct OptimizedSolverConfig {
    pub max_iterations: usize,
    pub tolerance: Precision,
    pub enable_profiling: bool,
}

impl Default for OptimizedSolverConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            enable_profiling: false,
        }
    }
}

#[derive(Debug, Clone)]
pub struct OptimizedSolverResult {
    pub solution: Vec<Precision>,
    pub residual_norm: Precision,
    pub iterations: usize,
    pub converged: bool,
    /// Wall-clock solve time in milliseconds (always 0 without `std`, since
    /// no clock is available there).
    #[cfg(feature = "std")]
    pub computation_time_ms: f64,
    #[cfg(not(feature = "std"))]
    pub computation_time_ms: u64,
    pub performance_stats: OptimizedSolverStats,
}

#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverStats {
    pub matvec_count: usize,
    pub dot_product_count: usize,
    pub axpy_count: usize,
    pub total_flops: usize,
    pub average_bandwidth_gbs: f64,
    pub average_gflops: f64,
}

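/// Unpreconditioned conjugate gradient solver for symmetric positive definite
/// systems, built on [`OptimizedSparseMatrix`].
///
/// Illustrative usage sketch:
///
/// ```ignore
/// let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
/// let result = solver.solve(&matrix, &b)?;
/// assert!(result.converged);
/// ```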
pub struct OptimizedConjugateGradientSolver {
    config: OptimizedSolverConfig,
    stats: OptimizedSolverStats,
}

impl OptimizedConjugateGradientSolver {
    pub fn new(config: OptimizedSolverConfig) -> Self {
        Self {
            config,
            stats: OptimizedSolverStats::default(),
        }
    }

    /// Solves `A x = b` for symmetric positive definite `A` using the
    /// conjugate gradient method, starting from the zero vector.
    pub fn solve(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
    ) -> Result<OptimizedSolverResult, String> {
        let (rows, cols) = matrix.dimensions();
        if rows != cols {
            return Err("Matrix must be square".to_string());
        }
        if b.len() != rows {
            return Err("Right-hand side vector length must match matrix size".to_string());
        }

        #[cfg(feature = "std")]
        let start_time = Instant::now();

        self.stats = OptimizedSolverStats::default();

        let mut x = vec![0.0; rows];
        let mut r = vec![0.0; rows];
        let mut p = vec![0.0; rows];
        let mut ap = vec![0.0; rows];

        // With the zero vector as initial guess, the initial residual is b itself.
        r.copy_from_slice(b);

        let mut iteration = 0;
        // Track the squared residual norm so no square root is needed per iteration.
        let tolerance_sq = self.config.tolerance * self.config.tolerance;
        let mut converged = false;

        let mut rsold = self.dot_product(&r, &r);
        p.copy_from_slice(&r);

        while iteration < self.config.max_iterations {
            if rsold <= tolerance_sq {
                converged = true;
                break;
            }

            matrix.multiply_vector(&p, &mut ap);
            self.stats.matvec_count += 1;

            let pap = self.dot_product(&p, &ap);

            // Guard against breakdown: a vanishing curvature term p' A p would
            // make alpha blow up (e.g. if the matrix is not positive definite).
            if pap.abs() < 1e-16 {
                break;
            }

            let alpha = rsold / pap;

            // x += alpha * p and r -= alpha * A p
            self.axpy(alpha, &p, &mut x);
            self.axpy(-alpha, &ap, &mut r);

            let rsnew = self.dot_product(&r, &r);
            let beta = rsnew / rsold;

            // p = r + beta * p
            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
                *pi = ri + beta * *pi;
            }

            rsold = rsnew;
            iteration += 1;
        }

        #[cfg(feature = "std")]
        let computation_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
        #[cfg(not(feature = "std"))]
        let computation_time_ms: u64 = 0;

        let final_residual_norm = rsold.sqrt();

        // Rough flop estimate: 2 flops per stored nonzero per matvec, plus five
        // length-n vector operations per iteration (two dots, two axpys, one
        // search-direction update) at ~2n flops each.
        self.stats.total_flops =
            self.stats.matvec_count * matrix.nnz() * 2 + iteration * rows * 10;
        #[cfg(feature = "std")]
        {
            if computation_time_ms > 0.0 {
                // Crude bandwidth proxy: assume one 8-byte access per flop.
                let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
                self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
                self.stats.average_gflops =
                    (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
            }
        }

        Ok(OptimizedSolverResult {
            solution: x,
            residual_norm: final_residual_norm,
            iterations: iteration,
            converged,
            computation_time_ms,
            performance_stats: self.stats.clone(),
        })
    }

    /// Dot product of `x` and `y`, using the SIMD kernel when available.
    fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
        self.stats.dot_product_count += 1;
        #[cfg(feature = "simd")]
        {
            dot_product_simd(x, y)
        }
        #[cfg(not(feature = "simd"))]
        {
            x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
        }
    }

    /// Computes `y += alpha * x`, using the SIMD kernel when available.
    fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
        self.stats.axpy_count += 1;
        #[cfg(feature = "simd")]
        {
            axpy_simd(alpha, x, y);
        }
        #[cfg(not(feature = "simd"))]
        {
            for (yi, &xi) in y.iter_mut().zip(x.iter()) {
                *yi += alpha * xi;
            }
        }
    }

    /// Euclidean norm helper; currently unused by `solve`, which tracks the
    /// squared residual norm instead.
    #[allow(dead_code)]
    fn l2_norm(&self, x: &[Precision]) -> Precision {
        x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
    }

    /// Number of iterations performed by the last solve (one matvec per iteration).
    pub fn get_last_iteration_count(&self) -> usize {
        self.stats.matvec_count
    }

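    /// Chunked progress reporting is not implemented yet: this delegates to
    /// [`Self::solve`], ignores `_chunk_size`, and invokes the callback once
    /// with the final statistics.
    ///
    /// Illustrative usage sketch:
    ///
    /// ```ignore
    /// let result = solver.solve_with_callback(&matrix, &b, 64, |stats| {
    ///     // With the current stub this fires once, after the solve completes.
    ///     let _ = stats.matvec_count;
    /// })?;
    /// ```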
    pub fn solve_with_callback<F>(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
        _chunk_size: usize,
        mut callback: F,
    ) -> Result<OptimizedSolverResult, String>
    where
        F: FnMut(&OptimizedSolverStats),
    {
        let result = self.solve(matrix, b)?;
        callback(&result.performance_stats);
        Ok(result)
    }
}

impl OptimizedSolverResult {
    pub fn data(&self) -> &[Precision] {
        &self.solution
    }
}

#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    pub track_performance: bool,
    pub track_memory: bool,
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// 2x2 symmetric positive definite test matrix [[4, 1], [1, 3]].
    fn create_test_matrix() -> OptimizedSparseMatrix {
        let triplets = vec![
            (0, 0, 4.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 3.0),
        ];
        OptimizedSparseMatrix::from_triplets(triplets, 2, 2).unwrap()
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let matrix = create_test_matrix();
        assert_eq!(matrix.dimensions(), (2, 2));
        assert_eq!(matrix.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let matrix = create_test_matrix();
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix.multiply_vector(&x, &mut y);
        assert_eq!(y, vec![6.0, 7.0]);
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let config = OptimizedSolverConfig::default();
        let mut solver = OptimizedConjugateGradientSolver::new(config);

        let result = solver.solve(&matrix, &b).unwrap();

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Verify the solution by checking that A * x reproduces b.
        let mut ax = vec![0.0; 2];
        matrix.multiply_vector(&result.solution, &mut ax);

        let error = ((ax[0] - b[0]).powi(2) + (ax[1] - b[1]).powi(2)).sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let config = OptimizedSolverConfig::default();
        let mut solver = OptimizedConjugateGradientSolver::new(config);

        let result = solver.solve(&matrix, &b).unwrap();

        assert!(result.performance_stats.matvec_count > 0);
        assert!(result.performance_stats.dot_product_count > 0);
        assert!(result.performance_stats.total_flops > 0);
    }
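
    // Added sketch: exercises the matrix-level statistics counters
    // (`get_performance_stats` / `reset_stats`).
    #[test]
    fn test_matrix_stats_reset() {
        let matrix = create_test_matrix();
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix.multiply_vector(&x, &mut y);
        let (count, bytes) = matrix.get_performance_stats();
        assert_eq!(count, 1);
        assert!(bytes > 0);

        matrix.reset_stats();
        assert_eq!(matrix.get_performance_stats(), (0, 0));
    }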
}