Skip to main content

sublinear_solver/
optimized_solver.rs

//! High-performance optimized solver implementations.
//!
//! This module provides a CSR sparse-matrix wrapper and a conjugate gradient
//! solver, with optional SIMD kernels (enabled by the `simd` feature) and
//! lightweight performance counters.
5
6use crate::types::Precision;
7use crate::matrix::sparse::{CSRStorage, COOStorage};
8#[cfg(feature = "simd")]
9use crate::simd_ops::{matrix_vector_multiply_simd, dot_product_simd, axpy_simd};
10use alloc::vec::Vec;
11use core::sync::atomic::{AtomicUsize, Ordering};
12
13#[cfg(feature = "std")]
14use std::time::Instant;
15
16/// High-performance sparse matrix optimized for sublinear-time algorithms.
17pub struct OptimizedSparseMatrix {
18    storage: CSRStorage,
19    dimensions: (usize, usize),
20    performance_stats: PerformanceStats,
21}
22
/// Running counters for sparse matrix-vector products.
///
/// Both counters are atomics accessed with `Ordering::Relaxed`: they are plain
/// statistics and do not order any other memory operations.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Number of matrix-vector products performed.
    pub matvec_count: AtomicUsize,
    /// Estimated bytes touched by those products.
    pub bytes_processed: AtomicUsize,
}

impl Clone for PerformanceStats {
    /// Snapshot the current counter values into a new, independent instance.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicUsize| AtomicUsize::new(counter.load(Ordering::Relaxed));
        Self {
            matvec_count: snapshot(&self.matvec_count),
            bytes_processed: snapshot(&self.bytes_processed),
        }
    }
}
38
39impl OptimizedSparseMatrix {
40    /// Create optimized sparse matrix from triplets.
41    pub fn from_triplets(
42        triplets: Vec<(usize, usize, Precision)>,
43        rows: usize,
44        cols: usize,
45    ) -> Result<Self, String> {
46        let coo = COOStorage::from_triplets(triplets)
47            .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
48        let storage = CSRStorage::from_coo(&coo, rows, cols)
49            .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;
50
51        Ok(Self {
52            storage,
53            dimensions: (rows, cols),
54            performance_stats: PerformanceStats::default(),
55        })
56    }
57
58    /// Get matrix dimensions.
59    pub fn dimensions(&self) -> (usize, usize) {
60        self.dimensions
61    }
62
63    /// Get number of non-zero elements.
64    pub fn nnz(&self) -> usize {
65        self.storage.nnz()
66    }
67
68    /// SIMD-accelerated matrix-vector multiplication.
69    pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
70        assert_eq!(x.len(), self.dimensions.1);
71        assert_eq!(y.len(), self.dimensions.0);
72
73        self.performance_stats.matvec_count.fetch_add(1, Ordering::Relaxed);
74        let bytes = (self.storage.values.len() * 8) + (x.len() * 8) + (y.len() * 8);
75        self.performance_stats.bytes_processed.fetch_add(bytes, Ordering::Relaxed);
76
77#[cfg(feature = "simd")]
78        {
79            matrix_vector_multiply_simd(
80                &self.storage.values,
81                &self.storage.col_indices,
82                &self.storage.row_ptr,
83                x,
84                y,
85            );
86        }
87        #[cfg(not(feature = "simd"))]
88        {
89            self.storage.multiply_vector(x, y);
90        }
91    }
92
93    /// Get performance statistics.
94    pub fn get_performance_stats(&self) -> (usize, usize) {
95        (
96            self.performance_stats.matvec_count.load(Ordering::Relaxed),
97            self.performance_stats.bytes_processed.load(Ordering::Relaxed),
98        )
99    }
100
101    /// Reset performance counters.
102    pub fn reset_stats(&self) {
103        self.performance_stats.matvec_count.store(0, Ordering::Relaxed);
104        self.performance_stats.bytes_processed.store(0, Ordering::Relaxed);
105    }
106}
107
108/// Configuration for the optimized conjugate gradient solver.
109#[derive(Debug, Clone)]
110pub struct OptimizedSolverConfig {
111    /// Maximum number of iterations
112    pub max_iterations: usize,
113    /// Convergence tolerance
114    pub tolerance: Precision,
115    /// Enable performance profiling
116    pub enable_profiling: bool,
117}
118
119impl Default for OptimizedSolverConfig {
120    fn default() -> Self {
121        Self {
122            max_iterations: 1000,
123            tolerance: 1e-6,
124            enable_profiling: false,
125        }
126    }
127}
128
129/// Result of optimized solver computation.
130#[derive(Debug, Clone)]
131pub struct OptimizedSolverResult {
132    /// Solution vector
133    pub solution: Vec<Precision>,
134    /// Final residual norm
135    pub residual_norm: Precision,
136    /// Number of iterations performed
137    pub iterations: usize,
138    /// Whether the solver converged
139    pub converged: bool,
140    /// Total computation time in milliseconds
141    #[cfg(feature = "std")]
142    pub computation_time_ms: f64,
143    #[cfg(not(feature = "std"))]
144    pub computation_time_ms: u64,
145    /// Performance statistics
146    pub performance_stats: OptimizedSolverStats,
147}
148
149/// Performance statistics for optimized solver.
150#[derive(Debug, Clone, Default)]
151pub struct OptimizedSolverStats {
152    /// Number of matrix-vector multiplications
153    pub matvec_count: usize,
154    /// Number of dot products computed
155    pub dot_product_count: usize,
156    /// Number of AXPY operations
157    pub axpy_count: usize,
158    /// Total floating-point operations
159    pub total_flops: usize,
160    /// Average bandwidth achieved (GB/s)
161    pub average_bandwidth_gbs: f64,
162    /// Average GFLOPS achieved
163    pub average_gflops: f64,
164}
165
166/// High-performance conjugate gradient solver with SIMD optimizations.
167pub struct OptimizedConjugateGradientSolver {
168    config: OptimizedSolverConfig,
169    stats: OptimizedSolverStats,
170}
171
172impl OptimizedConjugateGradientSolver {
173    /// Create a new optimized solver.
174    pub fn new(config: OptimizedSolverConfig) -> Self {
175        Self {
176            config,
177            stats: OptimizedSolverStats::default(),
178        }
179    }
180
181    /// Solve the linear system Ax = b using optimized conjugate gradient.
182    pub fn solve(
183        &mut self,
184        matrix: &OptimizedSparseMatrix,
185        b: &[Precision],
186    ) -> Result<OptimizedSolverResult, String> {
187        let (rows, cols) = matrix.dimensions();
188        if rows != cols {
189            return Err("Matrix must be square".to_string());
190        }
191        if b.len() != rows {
192            return Err("Right-hand side vector length must match matrix size".to_string());
193        }
194
195        #[cfg(feature = "std")]
196        let start_time = Instant::now();
197
198        // Reset statistics
199        self.stats = OptimizedSolverStats::default();
200
201        // Initialize solution and workspace vectors
202        let mut x = vec![0.0; rows];
203        let mut r = vec![0.0; rows];
204        let mut p = vec![0.0; rows];
205        let mut ap = vec![0.0; rows];
206
207        // r = b - A*x (initially r = b since x = 0)
208        r.copy_from_slice(b);
209
210        let mut iteration = 0;
211        let tolerance_sq = self.config.tolerance * self.config.tolerance;
212        let mut converged = false;
213
214        // Conjugate gradient iteration — every dot product and AXPY goes
215        // through the instrumented helpers so `performance_stats` is correct.
216        let mut rsold = self.dot_product(&r, &r);
217        p.copy_from_slice(&r);
218
219        while iteration < self.config.max_iterations {
220            if rsold <= tolerance_sq {
221                converged = true;
222                break;
223            }
224
225            // ap = A * p
226            matrix.multiply_vector(&p, &mut ap);
227            self.stats.matvec_count += 1;
228
229            // alpha = rsold / (p^T * ap)
230            let pap = self.dot_product(&p, &ap);
231
232            if pap.abs() < 1e-16 {
233                break; // Avoid division by zero
234            }
235
236            let alpha = rsold / pap;
237
238            // x = x + alpha * p
239            self.axpy(alpha, &p, &mut x);
240
241            // r = r - alpha * ap
242            self.axpy(-alpha, &ap, &mut r);
243
244            let rsnew = self.dot_product(&r, &r);
245
246            let beta = rsnew / rsold;
247
248            // p = r + beta * p  (scale p in place, then add r)
249            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
250                *pi = ri + beta * *pi;
251            }
252
253            rsold = rsnew;
254            iteration += 1;
255        }
256
257        #[cfg(feature = "std")]
258        let computation_time_ms = start_time.elapsed().as_millis() as f64;
259        #[cfg(not(feature = "std"))]
260        let computation_time_ms = 0.0;
261
262        // Calculate final residual
263        let final_residual_norm = rsold.sqrt();
264
265        // Update performance statistics
266        self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 +
267                                 iteration * rows * 6; // vector operations per iteration
268
269        if computation_time_ms > 0.0 {
270            let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
271            self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
272            self.stats.average_gflops = (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
273        }
274
275        Ok(OptimizedSolverResult {
276            solution: x,
277            residual_norm: final_residual_norm,
278            iterations: iteration,
279            converged,
280            computation_time_ms,
281            performance_stats: self.stats.clone(),
282        })
283    }
284
285    /// Compute dot product with SIMD optimization.
286    fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
287        self.stats.dot_product_count += 1;
288        #[cfg(feature = "simd")]
289        {
290            dot_product_simd(x, y)
291        }
292        #[cfg(not(feature = "simd"))]
293        {
294            x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
295        }
296    }
297
298    /// Compute AXPY operation (y = alpha * x + y) with SIMD optimization.
299    fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
300        self.stats.axpy_count += 1;
301        #[cfg(feature = "simd")]
302        {
303            axpy_simd(alpha, x, y);
304        }
305        #[cfg(not(feature = "simd"))]
306        {
307            for (yi, &xi) in y.iter_mut().zip(x.iter()) {
308                *yi += alpha * xi;
309            }
310        }
311    }
312
313    /// Compute L2 norm of a vector.
314    fn l2_norm(&self, x: &[Precision]) -> Precision {
315        x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
316    }
317
318    /// Get the last iteration count.
319    pub fn get_last_iteration_count(&self) -> usize {
320        self.stats.matvec_count
321    }
322
323    /// Solve with callback for streaming results.
324    pub fn solve_with_callback<F>(
325        &mut self,
326        matrix: &OptimizedSparseMatrix,
327        b: &[Precision],
328        _chunk_size: usize,
329        mut _callback: F,
330    ) -> Result<OptimizedSolverResult, String>
331    where
332        F: FnMut(&OptimizedSolverStats),
333    {
334        // For now, just call the regular solve method
335        // In a full implementation, this would call the callback periodically
336        self.solve(matrix, b)
337    }
338}
339
340impl OptimizedSolverResult {
341    /// Get the solution data.
342    pub fn data(&self) -> &[Precision] {
343        &self.solution
344    }
345}
346
/// Extra switches for the optimized solver; both default to `false`.
///
/// Neither flag is read by the solver code in this module.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Request detailed performance tracking.
    pub track_performance: bool,
    /// Request memory-usage tracking.
    pub track_memory: bool,
}
355
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// The 2x2 symmetric positive definite matrix [[4, 1], [1, 3]].
    fn create_test_matrix() -> OptimizedSparseMatrix {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        OptimizedSparseMatrix::from_triplets(entries, 2, 2).unwrap()
    }

    /// Solver built from the default configuration.
    fn default_solver() -> OptimizedConjugateGradientSolver {
        OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default())
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let m = create_test_matrix();
        assert_eq!(m.dimensions(), (2, 2));
        assert_eq!(m.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let m = create_test_matrix();
        let mut out = vec![0.0; 2];
        m.multiply_vector(&[1.0, 2.0], &mut out);
        // [4*1 + 1*2, 1*1 + 3*2]
        assert_eq!(out, vec![6.0, 7.0]);
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let m = create_test_matrix();
        let rhs = vec![1.0, 2.0];
        let result = default_solver().solve(&m, &rhs).unwrap();

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Substitute the solution back and check that A x reproduces b.
        let mut ax = vec![0.0; 2];
        m.multiply_vector(&result.solution, &mut ax);
        let error = ax
            .iter()
            .zip(&rhs)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<Precision>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let m = create_test_matrix();
        let stats = default_solver()
            .solve(&m, &[1.0, 2.0])
            .unwrap()
            .performance_stats;

        assert!(stats.matvec_count > 0);
        assert!(stats.dot_product_count > 0);
        assert!(stats.total_flops > 0);
    }
}