1use crate::types::Precision;
7use crate::matrix::sparse::{CSRStorage, COOStorage};
8#[cfg(feature = "simd")]
9use crate::simd_ops::{matrix_vector_multiply_simd, dot_product_simd, axpy_simd};
10use alloc::vec::Vec;
11use core::sync::atomic::{AtomicUsize, Ordering};
12
13#[cfg(feature = "std")]
14use std::time::Instant;
15
16pub struct OptimizedSparseMatrix {
18 storage: CSRStorage,
19 dimensions: (usize, usize),
20 performance_stats: PerformanceStats,
21}
22
/// Atomic usage counters for a sparse matrix.
///
/// Both counters are atomics so they can be updated through a shared
/// reference.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Number of matrix-vector products performed.
    pub matvec_count: AtomicUsize,
    /// Accumulated byte estimate for the data touched by those products.
    pub bytes_processed: AtomicUsize,
}
29
30impl Clone for PerformanceStats {
31 fn clone(&self) -> Self {
32 Self {
33 matvec_count: AtomicUsize::new(self.matvec_count.load(Ordering::Relaxed)),
34 bytes_processed: AtomicUsize::new(self.bytes_processed.load(Ordering::Relaxed)),
35 }
36 }
37}
38
39impl OptimizedSparseMatrix {
40 pub fn from_triplets(
42 triplets: Vec<(usize, usize, Precision)>,
43 rows: usize,
44 cols: usize,
45 ) -> Result<Self, String> {
46 let coo = COOStorage::from_triplets(triplets)
47 .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
48 let storage = CSRStorage::from_coo(&coo, rows, cols)
49 .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;
50
51 Ok(Self {
52 storage,
53 dimensions: (rows, cols),
54 performance_stats: PerformanceStats::default(),
55 })
56 }
57
58 pub fn dimensions(&self) -> (usize, usize) {
60 self.dimensions
61 }
62
63 pub fn nnz(&self) -> usize {
65 self.storage.nnz()
66 }
67
68 pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
70 assert_eq!(x.len(), self.dimensions.1);
71 assert_eq!(y.len(), self.dimensions.0);
72
73 self.performance_stats.matvec_count.fetch_add(1, Ordering::Relaxed);
74 let bytes = (self.storage.values.len() * 8) + (x.len() * 8) + (y.len() * 8);
75 self.performance_stats.bytes_processed.fetch_add(bytes, Ordering::Relaxed);
76
77#[cfg(feature = "simd")]
78 {
79 matrix_vector_multiply_simd(
80 &self.storage.values,
81 &self.storage.col_indices,
82 &self.storage.row_ptr,
83 x,
84 y,
85 );
86 }
87 #[cfg(not(feature = "simd"))]
88 {
89 self.storage.multiply_vector(x, y);
90 }
91 }
92
93 pub fn get_performance_stats(&self) -> (usize, usize) {
95 (
96 self.performance_stats.matvec_count.load(Ordering::Relaxed),
97 self.performance_stats.bytes_processed.load(Ordering::Relaxed),
98 )
99 }
100
101 pub fn reset_stats(&self) {
103 self.performance_stats.matvec_count.store(0, Ordering::Relaxed);
104 self.performance_stats.bytes_processed.store(0, Ordering::Relaxed);
105 }
106}
107
108#[derive(Debug, Clone)]
110pub struct OptimizedSolverConfig {
111 pub max_iterations: usize,
113 pub tolerance: Precision,
115 pub enable_profiling: bool,
117}
118
119impl Default for OptimizedSolverConfig {
120 fn default() -> Self {
121 Self {
122 max_iterations: 1000,
123 tolerance: 1e-6,
124 enable_profiling: false,
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct OptimizedSolverResult {
132 pub solution: Vec<Precision>,
134 pub residual_norm: Precision,
136 pub iterations: usize,
138 pub converged: bool,
140 #[cfg(feature = "std")]
142 pub computation_time_ms: f64,
143 #[cfg(not(feature = "std"))]
144 pub computation_time_ms: u64,
145 pub performance_stats: OptimizedSolverStats,
147}
148
/// Operation counters and derived throughput figures for one solve.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverStats {
    /// Number of matrix-vector products.
    pub matvec_count: usize,
    /// Number of dot products.
    pub dot_product_count: usize,
    /// Number of AXPY (`y += alpha * x`) updates.
    pub axpy_count: usize,
    /// Estimated floating-point operation total.
    pub total_flops: usize,
    /// Estimated average memory bandwidth in GB/s.
    pub average_bandwidth_gbs: f64,
    /// Estimated average throughput in GFLOP/s.
    pub average_gflops: f64,
}
165
166pub struct OptimizedConjugateGradientSolver {
168 config: OptimizedSolverConfig,
169 stats: OptimizedSolverStats,
170}
171
172impl OptimizedConjugateGradientSolver {
173 pub fn new(config: OptimizedSolverConfig) -> Self {
175 Self {
176 config,
177 stats: OptimizedSolverStats::default(),
178 }
179 }
180
181 pub fn solve(
183 &mut self,
184 matrix: &OptimizedSparseMatrix,
185 b: &[Precision],
186 ) -> Result<OptimizedSolverResult, String> {
187 let (rows, cols) = matrix.dimensions();
188 if rows != cols {
189 return Err("Matrix must be square".to_string());
190 }
191 if b.len() != rows {
192 return Err("Right-hand side vector length must match matrix size".to_string());
193 }
194
195 #[cfg(feature = "std")]
196 let start_time = Instant::now();
197
198 self.stats = OptimizedSolverStats::default();
200
201 let mut x = vec![0.0; rows];
203 let mut r = vec![0.0; rows];
204 let mut p = vec![0.0; rows];
205 let mut ap = vec![0.0; rows];
206
207 r.copy_from_slice(b);
209
210 let mut iteration = 0;
211 let tolerance_sq = self.config.tolerance * self.config.tolerance;
212 let mut converged = false;
213
214 let mut rsold = self.dot_product(&r, &r);
217 p.copy_from_slice(&r);
218
219 while iteration < self.config.max_iterations {
220 if rsold <= tolerance_sq {
221 converged = true;
222 break;
223 }
224
225 matrix.multiply_vector(&p, &mut ap);
227 self.stats.matvec_count += 1;
228
229 let pap = self.dot_product(&p, &ap);
231
232 if pap.abs() < 1e-16 {
233 break; }
235
236 let alpha = rsold / pap;
237
238 self.axpy(alpha, &p, &mut x);
240
241 self.axpy(-alpha, &ap, &mut r);
243
244 let rsnew = self.dot_product(&r, &r);
245
246 let beta = rsnew / rsold;
247
248 for (pi, &ri) in p.iter_mut().zip(r.iter()) {
250 *pi = ri + beta * *pi;
251 }
252
253 rsold = rsnew;
254 iteration += 1;
255 }
256
257 #[cfg(feature = "std")]
258 let computation_time_ms = start_time.elapsed().as_millis() as f64;
259 #[cfg(not(feature = "std"))]
260 let computation_time_ms = 0.0;
261
262 let final_residual_norm = rsold.sqrt();
264
265 self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 +
267 iteration * rows * 6; if computation_time_ms > 0.0 {
270 let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
271 self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
272 self.stats.average_gflops = (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
273 }
274
275 Ok(OptimizedSolverResult {
276 solution: x,
277 residual_norm: final_residual_norm,
278 iterations: iteration,
279 converged,
280 computation_time_ms,
281 performance_stats: self.stats.clone(),
282 })
283 }
284
285 fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
287 self.stats.dot_product_count += 1;
288 #[cfg(feature = "simd")]
289 {
290 dot_product_simd(x, y)
291 }
292 #[cfg(not(feature = "simd"))]
293 {
294 x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
295 }
296 }
297
298 fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
300 self.stats.axpy_count += 1;
301 #[cfg(feature = "simd")]
302 {
303 axpy_simd(alpha, x, y);
304 }
305 #[cfg(not(feature = "simd"))]
306 {
307 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
308 *yi += alpha * xi;
309 }
310 }
311 }
312
313 fn l2_norm(&self, x: &[Precision]) -> Precision {
315 x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
316 }
317
318 pub fn get_last_iteration_count(&self) -> usize {
320 self.stats.matvec_count
321 }
322
323 pub fn solve_with_callback<F>(
325 &mut self,
326 matrix: &OptimizedSparseMatrix,
327 b: &[Precision],
328 _chunk_size: usize,
329 mut _callback: F,
330 ) -> Result<OptimizedSolverResult, String>
331 where
332 F: FnMut(&OptimizedSolverStats),
333 {
334 self.solve(matrix, b)
337 }
338}
339
340impl OptimizedSolverResult {
341 pub fn data(&self) -> &[Precision] {
343 &self.solution
344 }
345}
346
/// Optional tracking switches; both default to `false`.
///
/// NOTE(review): neither flag is read by the code in this file - confirm
/// where they are consumed.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Request performance tracking.
    pub track_performance: bool,
    /// Request memory tracking.
    pub track_memory: bool,
}
355
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Builds the 2x2 matrix [[4, 1], [1, 3]] shared by all tests.
    fn create_test_matrix() -> OptimizedSparseMatrix {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        OptimizedSparseMatrix::from_triplets(entries, 2, 2).unwrap()
    }

    /// Runs a default-configured solve of the test system against `rhs`.
    fn solve_with_defaults(
        matrix: &OptimizedSparseMatrix,
        rhs: &[Precision],
    ) -> OptimizedSolverResult {
        let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
        solver.solve(matrix, rhs).unwrap()
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let matrix = create_test_matrix();
        assert_eq!(matrix.dimensions(), (2, 2));
        assert_eq!(matrix.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let matrix = create_test_matrix();
        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];

        matrix.multiply_vector(&input, &mut output);

        // [[4, 1], [1, 3]] * [1, 2] = [6, 7]
        assert_eq!(output, vec![6.0, 7.0]);
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let matrix = create_test_matrix();
        let rhs = vec![1.0, 2.0];

        let result = solve_with_defaults(&matrix, &rhs);

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Verify the solution independently by recomputing A * x.
        let mut product = vec![0.0; 2];
        matrix.multiply_vector(&result.solution, &mut product);

        let d0 = product[0] - rhs[0];
        let d1 = product[1] - rhs[1];
        let error = (d0.powi(2) + d1.powi(2)).sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let matrix = create_test_matrix();
        let rhs = vec![1.0, 2.0];

        let result = solve_with_defaults(&matrix, &rhs);
        let stats = &result.performance_stats;

        assert!(stats.matvec_count > 0);
        assert!(stats.dot_product_count > 0);
        assert!(stats.total_flops > 0);
    }
}