sublinear_solver/
optimized_solver.rs

//! High-performance optimized solver implementations.
//!
//! This module provides optimized versions of linear system solvers with
//! optional SIMD acceleration and built-in performance instrumentation.
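//!
//! A minimal end-to-end sketch (assuming `Precision = f64`; it mirrors the
//! tests at the bottom of this file):
//!
//! ```ignore
//! let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
//! let matrix = OptimizedSparseMatrix::from_triplets(triplets, 2, 2)?;
//! let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
//! let result = solver.solve(&matrix, &[1.0, 2.0])?;
//! assert!(result.converged);
//! ```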

use crate::types::Precision;
use crate::matrix::sparse::{CSRStorage, COOStorage};
#[cfg(feature = "simd")]
use crate::simd_ops::{matrix_vector_multiply_simd, dot_product_simd, axpy_simd};
use alloc::format;
use alloc::string::{String, ToString};
use alloc::{vec, vec::Vec};
use core::sync::atomic::{AtomicUsize, Ordering};

#[cfg(feature = "std")]
use std::time::Instant;
/// High-performance sparse matrix optimized for sublinear-time algorithms.
pub struct OptimizedSparseMatrix {
    storage: CSRStorage,
    dimensions: (usize, usize),
    performance_stats: PerformanceStats,
}

/// Performance statistics for matrix operations.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    pub matvec_count: AtomicUsize,
    pub bytes_processed: AtomicUsize,
}

// `AtomicUsize` does not implement `Clone`, so snapshot the counter values manually.
impl Clone for PerformanceStats {
    fn clone(&self) -> Self {
        Self {
            matvec_count: AtomicUsize::new(self.matvec_count.load(Ordering::Relaxed)),
            bytes_processed: AtomicUsize::new(self.bytes_processed.load(Ordering::Relaxed)),
        }
    }
}
impl OptimizedSparseMatrix {
    /// Create optimized sparse matrix from triplets.
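    ///
    /// A minimal usage sketch (assuming `Precision = f64`):
    ///
    /// ```ignore
    /// let triplets = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
    /// let matrix = OptimizedSparseMatrix::from_triplets(triplets, 2, 2)?;
    /// assert_eq!(matrix.nnz(), 4);
    /// ```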
    pub fn from_triplets(
        triplets: Vec<(usize, usize, Precision)>,
        rows: usize,
        cols: usize,
    ) -> Result<Self, String> {
        let coo = COOStorage::from_triplets(triplets)
            .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
        let storage = CSRStorage::from_coo(&coo, rows, cols)
            .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;

        Ok(Self {
            storage,
            dimensions: (rows, cols),
            performance_stats: PerformanceStats::default(),
        })
    }

    /// Get matrix dimensions.
    pub fn dimensions(&self) -> (usize, usize) {
        self.dimensions
    }

    /// Get number of non-zero elements.
    pub fn nnz(&self) -> usize {
        self.storage.nnz()
    }

    /// SIMD-accelerated matrix-vector multiplication (falls back to scalar code without the `simd` feature).
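    ///
    /// A minimal sketch, reusing the 2x2 matrix from the tests below:
    ///
    /// ```ignore
    /// let x = vec![1.0, 2.0];
    /// let mut y = vec![0.0; 2];
    /// matrix.multiply_vector(&x, &mut y); // y = A * x = [6.0, 7.0]
    /// ```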
    pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
        assert_eq!(x.len(), self.dimensions.1);
        assert_eq!(y.len(), self.dimensions.0);

        self.performance_stats.matvec_count.fetch_add(1, Ordering::Relaxed);
        // Estimate bytes touched; `size_of` keeps this correct if `Precision` is not f64.
        let elem_size = core::mem::size_of::<Precision>();
        let bytes = (self.storage.values.len() + x.len() + y.len()) * elem_size;
        self.performance_stats.bytes_processed.fetch_add(bytes, Ordering::Relaxed);

        #[cfg(feature = "simd")]
        {
            matrix_vector_multiply_simd(
                &self.storage.values,
                &self.storage.col_indices,
                &self.storage.row_ptr,
                x,
                y,
            );
        }
        #[cfg(not(feature = "simd"))]
        {
            self.storage.multiply_vector(x, y);
        }
    }

    /// Get performance statistics.
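    ///
    /// Returned as `(matvec_count, bytes_processed)`; a usage sketch on a fresh matrix:
    ///
    /// ```ignore
    /// matrix.multiply_vector(&x, &mut y);
    /// let (matvecs, bytes) = matrix.get_performance_stats();
    /// assert_eq!(matvecs, 1);
    /// ```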
    pub fn get_performance_stats(&self) -> (usize, usize) {
        (
            self.performance_stats.matvec_count.load(Ordering::Relaxed),
            self.performance_stats.bytes_processed.load(Ordering::Relaxed),
        )
    }

    /// Reset performance counters.
    pub fn reset_stats(&self) {
        self.performance_stats.matvec_count.store(0, Ordering::Relaxed);
        self.performance_stats.bytes_processed.store(0, Ordering::Relaxed);
    }
}

/// Configuration for the optimized conjugate gradient solver.
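///
/// A configuration sketch with a tighter tolerance than the default:
///
/// ```ignore
/// let config = OptimizedSolverConfig {
///     tolerance: 1e-10,
///     ..Default::default()
/// };
/// ```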
#[derive(Debug, Clone)]
pub struct OptimizedSolverConfig {
    /// Maximum number of iterations
    pub max_iterations: usize,
    /// Convergence tolerance
    pub tolerance: Precision,
    /// Enable performance profiling
    pub enable_profiling: bool,
}

impl Default for OptimizedSolverConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            enable_profiling: false,
        }
    }
}

/// Result of optimized solver computation.
#[derive(Debug, Clone)]
pub struct OptimizedSolverResult {
    /// Solution vector
    pub solution: Vec<Precision>,
    /// Final residual norm
    pub residual_norm: Precision,
    /// Number of iterations performed
    pub iterations: usize,
    /// Whether the solver converged
    pub converged: bool,
    /// Total computation time in milliseconds
    #[cfg(feature = "std")]
    pub computation_time_ms: f64,
    /// Total computation time in milliseconds (always zero without `std`, which provides the clock)
    #[cfg(not(feature = "std"))]
    pub computation_time_ms: u64,
    /// Performance statistics
    pub performance_stats: OptimizedSolverStats,
}

/// Performance statistics for optimized solver.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverStats {
    /// Number of matrix-vector multiplications
    pub matvec_count: usize,
    /// Number of dot products computed
    pub dot_product_count: usize,
    /// Number of AXPY operations
    pub axpy_count: usize,
    /// Total floating-point operations
    pub total_flops: usize,
    /// Average bandwidth achieved (GB/s)
    pub average_bandwidth_gbs: f64,
    /// Average GFLOPS achieved
    pub average_gflops: f64,
}

/// High-performance conjugate gradient solver with SIMD optimizations.
pub struct OptimizedConjugateGradientSolver {
    config: OptimizedSolverConfig,
    stats: OptimizedSolverStats,
}

impl OptimizedConjugateGradientSolver {
    /// Create a new optimized solver.
    pub fn new(config: OptimizedSolverConfig) -> Self {
        Self {
            config,
            stats: OptimizedSolverStats::default(),
        }
    }

    /// Solve the linear system Ax = b using optimized conjugate gradient.
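    ///
    /// A minimal sketch, mirroring the tests at the bottom of this file:
    ///
    /// ```ignore
    /// let b = vec![1.0, 2.0];
    /// let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());
    /// let result = solver.solve(&matrix, &b)?;
    /// assert!(result.converged && result.residual_norm <= 1e-6);
    /// ```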
    pub fn solve(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
    ) -> Result<OptimizedSolverResult, String> {
        let (rows, cols) = matrix.dimensions();
        if rows != cols {
            return Err("Matrix must be square".to_string());
        }
        if b.len() != rows {
            return Err("Right-hand side vector length must match matrix size".to_string());
        }

        #[cfg(feature = "std")]
        let start_time = Instant::now();

        // Reset statistics
        self.stats = OptimizedSolverStats::default();

        // Initialize solution and workspace vectors
        let mut x = vec![0.0; rows];
        let mut r = vec![0.0; rows];
        let mut p = vec![0.0; rows];
        let mut ap = vec![0.0; rows];

        // r = b - A*x (initially r = b since x = 0)
        r.copy_from_slice(b);

        let mut iteration = 0;
        let tolerance_sq = self.config.tolerance * self.config.tolerance;
        let mut converged = false;

        // rsold = r^T * r, via the instrumented (and SIMD-capable) helper
        let mut rsold = self.dot_product(&r, &r);
        p.copy_from_slice(&r);
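
        // Standard CG recurrences:
        //   alpha = (r·r) / (p·A p)
        //   x <- x + alpha p,  r <- r - alpha A p
        //   beta = (r_new·r_new) / (r_old·r_old),  p <- r + beta p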
        while iteration < self.config.max_iterations {
            if rsold <= tolerance_sq {
                converged = true;
                break;
            }

            // ap = A * p
            matrix.multiply_vector(&p, &mut ap);
            self.stats.matvec_count += 1;

            // alpha = rsold / (p^T * ap)
            let pap = self.dot_product(&p, &ap);
            if pap.abs() < 1e-16 {
                break; // Breakdown guard: avoid division by (near-)zero
            }
            let alpha = rsold / pap;

            // x = x + alpha * p
            self.axpy(alpha, &p, &mut x);

            // r = r - alpha * ap
            self.axpy(-alpha, &ap, &mut r);

            let rsnew = self.dot_product(&r, &r);
            let beta = rsnew / rsold;

            // p = r + beta * p
            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
                *pi = ri + beta * *pi;
            }

            rsold = rsnew;
            iteration += 1;
        }

        #[cfg(feature = "std")]
        let computation_time_ms = start_time.elapsed().as_secs_f64() * 1000.0;
        // No monotonic clock is available without `std`, so report zero.
        #[cfg(not(feature = "std"))]
        let computation_time_ms: u64 = 0;

        // Calculate final residual norm directly from r
        let final_residual_norm = self.l2_norm(&r);

        // Update performance statistics: ~2 flops per stored non-zero per matvec,
        // plus roughly 6 flops per vector element per iteration for the vector updates.
        self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 +
                                 iteration * rows * 6;

        #[cfg(feature = "std")]
        {
            if computation_time_ms > 0.0 {
                // Rough traffic estimate: one 8-byte access per flop.
                let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
                self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
                self.stats.average_gflops =
                    (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
            }
        }

        Ok(OptimizedSolverResult {
            solution: x,
            residual_norm: final_residual_norm,
            iterations: iteration,
            converged,
            computation_time_ms,
            performance_stats: self.stats.clone(),
        })
    }

    /// Compute dot product with SIMD optimization.
    fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
        self.stats.dot_product_count += 1;
        #[cfg(feature = "simd")]
        {
            dot_product_simd(x, y)
        }
        #[cfg(not(feature = "simd"))]
        {
            x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
        }
    }

    /// Compute AXPY operation (y = alpha * x + y) with SIMD optimization.
    fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
        self.stats.axpy_count += 1;
        #[cfg(feature = "simd")]
        {
            axpy_simd(alpha, x, y);
        }
        #[cfg(not(feature = "simd"))]
        {
            for (yi, &xi) in y.iter_mut().zip(x.iter()) {
                *yi += alpha * xi;
            }
        }
    }

    /// Compute L2 norm of a vector.
    fn l2_norm(&self, x: &[Precision]) -> Precision {
        x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
    }

    /// Get the iteration count of the last solve (CG performs one
    /// matrix-vector product per iteration, so the matvec count is reported).
    pub fn get_last_iteration_count(&self) -> usize {
        self.stats.matvec_count
    }

    /// Solve with callback for streaming results.
    pub fn solve_with_callback<F>(
        &mut self,
        matrix: &OptimizedSparseMatrix,
        b: &[Precision],
        _chunk_size: usize,
        _callback: F,
    ) -> Result<OptimizedSolverResult, String>
    where
        F: FnMut(&OptimizedSolverStats),
    {
        // For now, just delegate to the regular solve method.
        // A full implementation would invoke the callback every `_chunk_size` iterations.
        self.solve(matrix, b)
    }
}

impl OptimizedSolverResult {
    /// Get the solution data.
    pub fn data(&self) -> &[Precision] {
        &self.solution
    }
}

/// Additional configuration options for the optimized solver.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Enable detailed performance tracking
    pub track_performance: bool,
    /// Enable memory usage tracking
    pub track_memory: bool,
}

#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    fn create_test_matrix() -> OptimizedSparseMatrix {
        // Create a simple 2x2 symmetric positive definite matrix
        let triplets = vec![
            (0, 0, 4.0), (0, 1, 1.0),
            (1, 0, 1.0), (1, 1, 3.0),
        ];
        OptimizedSparseMatrix::from_triplets(triplets, 2, 2).unwrap()
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let matrix = create_test_matrix();
        assert_eq!(matrix.dimensions(), (2, 2));
        assert_eq!(matrix.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let matrix = create_test_matrix();
        let x = vec![1.0, 2.0];
        let mut y = vec![0.0; 2];

        matrix.multiply_vector(&x, &mut y);
        assert_eq!(y, vec![6.0, 7.0]); // [4*1 + 1*2, 1*1 + 3*2]
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let config = OptimizedSolverConfig::default();
        let mut solver = OptimizedConjugateGradientSolver::new(config);

        let result = solver.solve(&matrix, &b).unwrap();

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Verify solution by substituting back
        let mut ax = vec![0.0; 2];
        matrix.multiply_vector(&result.solution, &mut ax);

        let error = ((ax[0] - b[0]).powi(2) + (ax[1] - b[1]).powi(2)).sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let matrix = create_test_matrix();
        let b = vec![1.0, 2.0];

        let config = OptimizedSolverConfig::default();
        let mut solver = OptimizedConjugateGradientSolver::new(config);

        let result = solver.solve(&matrix, &b).unwrap();

        assert!(result.performance_stats.matvec_count > 0);
        assert!(result.performance_stats.dot_product_count > 0);
        assert!(result.performance_stats.total_flops > 0);
    }
}