Skip to main content

sublinear_solver/
optimized_solver.rs

//! High-performance optimized solver implementations.
//!
//! This module provides optimized versions of linear system solvers: a CSR
//! sparse matrix with built-in operation counters and a conjugate gradient
//! solver, both of which use SIMD kernels when the `simd` feature is enabled.
5
6use crate::matrix::sparse::{COOStorage, CSRStorage};
7#[cfg(feature = "simd")]
8use crate::simd_ops::{axpy_simd, dot_product_simd, matrix_vector_multiply_simd};
9use crate::types::Precision;
10use alloc::vec::Vec;
11use core::sync::atomic::{AtomicUsize, Ordering};
12
13#[cfg(feature = "std")]
14use std::time::Instant;
15
16/// High-performance sparse matrix optimized for sublinear-time algorithms.
17pub struct OptimizedSparseMatrix {
18    storage: CSRStorage,
19    dimensions: (usize, usize),
20    performance_stats: PerformanceStats,
21}
22
/// Operation counters kept alongside a sparse matrix.
///
/// The counters are atomics so they can be bumped through a shared reference;
/// every access in this module uses `Ordering::Relaxed`.
#[derive(Debug, Default)]
pub struct PerformanceStats {
    /// Number of matrix-vector products performed.
    pub matvec_count: AtomicUsize,
    /// Running byte-count estimate accumulated by the matrix-vector products.
    pub bytes_processed: AtomicUsize,
}

impl Clone for PerformanceStats {
    /// Snapshot the current counter values into a new, independent set of atomics.
    fn clone(&self) -> Self {
        let snapshot = |counter: &AtomicUsize| AtomicUsize::from(counter.load(Ordering::Relaxed));
        Self {
            matvec_count: snapshot(&self.matvec_count),
            bytes_processed: snapshot(&self.bytes_processed),
        }
    }
}
38
39impl OptimizedSparseMatrix {
40    /// Create optimized sparse matrix from triplets.
41    pub fn from_triplets(
42        triplets: Vec<(usize, usize, Precision)>,
43        rows: usize,
44        cols: usize,
45    ) -> Result<Self, String> {
46        let coo = COOStorage::from_triplets(triplets)
47            .map_err(|e| format!("Failed to create COO storage: {:?}", e))?;
48        let storage = CSRStorage::from_coo(&coo, rows, cols)
49            .map_err(|e| format!("Failed to create CSR storage: {:?}", e))?;
50
51        Ok(Self {
52            storage,
53            dimensions: (rows, cols),
54            performance_stats: PerformanceStats::default(),
55        })
56    }
57
58    /// Get matrix dimensions.
59    pub fn dimensions(&self) -> (usize, usize) {
60        self.dimensions
61    }
62
63    /// Get number of non-zero elements.
64    pub fn nnz(&self) -> usize {
65        self.storage.nnz()
66    }
67
68    /// SIMD-accelerated matrix-vector multiplication.
69    pub fn multiply_vector(&self, x: &[Precision], y: &mut [Precision]) {
70        assert_eq!(x.len(), self.dimensions.1);
71        assert_eq!(y.len(), self.dimensions.0);
72
73        self.performance_stats
74            .matvec_count
75            .fetch_add(1, Ordering::Relaxed);
76        let bytes = (self.storage.values.len() * 8) + (x.len() * 8) + (y.len() * 8);
77        self.performance_stats
78            .bytes_processed
79            .fetch_add(bytes, Ordering::Relaxed);
80
81        #[cfg(feature = "simd")]
82        {
83            matrix_vector_multiply_simd(
84                &self.storage.values,
85                &self.storage.col_indices,
86                &self.storage.row_ptr,
87                x,
88                y,
89            );
90        }
91        #[cfg(not(feature = "simd"))]
92        {
93            self.storage.multiply_vector(x, y);
94        }
95    }
96
97    /// Get performance statistics.
98    pub fn get_performance_stats(&self) -> (usize, usize) {
99        (
100            self.performance_stats.matvec_count.load(Ordering::Relaxed),
101            self.performance_stats
102                .bytes_processed
103                .load(Ordering::Relaxed),
104        )
105    }
106
107    /// Reset performance counters.
108    pub fn reset_stats(&self) {
109        self.performance_stats
110            .matvec_count
111            .store(0, Ordering::Relaxed);
112        self.performance_stats
113            .bytes_processed
114            .store(0, Ordering::Relaxed);
115    }
116}
117
118/// Configuration for the optimized conjugate gradient solver.
119#[derive(Debug, Clone)]
120pub struct OptimizedSolverConfig {
121    /// Maximum number of iterations
122    pub max_iterations: usize,
123    /// Convergence tolerance
124    pub tolerance: Precision,
125    /// Enable performance profiling
126    pub enable_profiling: bool,
127}
128
129impl Default for OptimizedSolverConfig {
130    fn default() -> Self {
131        Self {
132            max_iterations: 1000,
133            tolerance: 1e-6,
134            enable_profiling: false,
135        }
136    }
137}
138
139/// Result of optimized solver computation.
140#[derive(Debug, Clone)]
141pub struct OptimizedSolverResult {
142    /// Solution vector
143    pub solution: Vec<Precision>,
144    /// Final residual norm
145    pub residual_norm: Precision,
146    /// Number of iterations performed
147    pub iterations: usize,
148    /// Whether the solver converged
149    pub converged: bool,
150    /// Total computation time in milliseconds
151    #[cfg(feature = "std")]
152    pub computation_time_ms: f64,
153    #[cfg(not(feature = "std"))]
154    pub computation_time_ms: u64,
155    /// Performance statistics
156    pub performance_stats: OptimizedSolverStats,
157}
158
159/// Performance statistics for optimized solver.
160#[derive(Debug, Clone, Default)]
161pub struct OptimizedSolverStats {
162    /// Number of matrix-vector multiplications
163    pub matvec_count: usize,
164    /// Number of dot products computed
165    pub dot_product_count: usize,
166    /// Number of AXPY operations
167    pub axpy_count: usize,
168    /// Total floating-point operations
169    pub total_flops: usize,
170    /// Average bandwidth achieved (GB/s)
171    pub average_bandwidth_gbs: f64,
172    /// Average GFLOPS achieved
173    pub average_gflops: f64,
174}
175
176/// High-performance conjugate gradient solver with SIMD optimizations.
177pub struct OptimizedConjugateGradientSolver {
178    config: OptimizedSolverConfig,
179    stats: OptimizedSolverStats,
180}
181
182impl OptimizedConjugateGradientSolver {
183    /// Create a new optimized solver.
184    pub fn new(config: OptimizedSolverConfig) -> Self {
185        Self {
186            config,
187            stats: OptimizedSolverStats::default(),
188        }
189    }
190
191    /// Solve the linear system Ax = b using optimized conjugate gradient.
192    pub fn solve(
193        &mut self,
194        matrix: &OptimizedSparseMatrix,
195        b: &[Precision],
196    ) -> Result<OptimizedSolverResult, String> {
197        let (rows, cols) = matrix.dimensions();
198        if rows != cols {
199            return Err("Matrix must be square".to_string());
200        }
201        if b.len() != rows {
202            return Err("Right-hand side vector length must match matrix size".to_string());
203        }
204
205        #[cfg(feature = "std")]
206        let start_time = Instant::now();
207
208        // Reset statistics
209        self.stats = OptimizedSolverStats::default();
210
211        // Initialize solution and workspace vectors
212        let mut x = vec![0.0; rows];
213        let mut r = vec![0.0; rows];
214        let mut p = vec![0.0; rows];
215        let mut ap = vec![0.0; rows];
216
217        // r = b - A*x (initially r = b since x = 0)
218        r.copy_from_slice(b);
219
220        let mut iteration = 0;
221        let tolerance_sq = self.config.tolerance * self.config.tolerance;
222        let mut converged = false;
223
224        // Conjugate gradient iteration — every dot product and AXPY goes
225        // through the instrumented helpers so `performance_stats` is correct.
226        let mut rsold = self.dot_product(&r, &r);
227        p.copy_from_slice(&r);
228
229        while iteration < self.config.max_iterations {
230            if rsold <= tolerance_sq {
231                converged = true;
232                break;
233            }
234
235            // ap = A * p
236            matrix.multiply_vector(&p, &mut ap);
237            self.stats.matvec_count += 1;
238
239            // alpha = rsold / (p^T * ap)
240            let pap = self.dot_product(&p, &ap);
241
242            if pap.abs() < 1e-16 {
243                break; // Avoid division by zero
244            }
245
246            let alpha = rsold / pap;
247
248            // x = x + alpha * p
249            self.axpy(alpha, &p, &mut x);
250
251            // r = r - alpha * ap
252            self.axpy(-alpha, &ap, &mut r);
253
254            let rsnew = self.dot_product(&r, &r);
255
256            let beta = rsnew / rsold;
257
258            // p = r + beta * p  (scale p in place, then add r)
259            for (pi, &ri) in p.iter_mut().zip(r.iter()) {
260                *pi = ri + beta * *pi;
261            }
262
263            rsold = rsnew;
264            iteration += 1;
265        }
266
267        #[cfg(feature = "std")]
268        let computation_time_ms = start_time.elapsed().as_millis() as f64;
269        #[cfg(not(feature = "std"))]
270        let computation_time_ms = 0.0;
271
272        // Calculate final residual
273        let final_residual_norm = rsold.sqrt();
274
275        // Update performance statistics
276        self.stats.total_flops = self.stats.matvec_count * matrix.nnz() * 2 + iteration * rows * 6; // vector operations per iteration
277
278        if computation_time_ms > 0.0 {
279            let total_gb = (self.stats.total_flops * 8) as f64 / 1e9;
280            self.stats.average_bandwidth_gbs = total_gb / (computation_time_ms / 1000.0);
281            self.stats.average_gflops =
282                (self.stats.total_flops as f64) / (computation_time_ms * 1e6);
283        }
284
285        Ok(OptimizedSolverResult {
286            solution: x,
287            residual_norm: final_residual_norm,
288            iterations: iteration,
289            converged,
290            computation_time_ms,
291            performance_stats: self.stats.clone(),
292        })
293    }
294
295    /// Compute dot product with SIMD optimization.
296    fn dot_product(&mut self, x: &[Precision], y: &[Precision]) -> Precision {
297        self.stats.dot_product_count += 1;
298        #[cfg(feature = "simd")]
299        {
300            dot_product_simd(x, y)
301        }
302        #[cfg(not(feature = "simd"))]
303        {
304            x.iter().zip(y.iter()).map(|(&a, &b)| a * b).sum()
305        }
306    }
307
308    /// Compute AXPY operation (y = alpha * x + y) with SIMD optimization.
309    fn axpy(&mut self, alpha: Precision, x: &[Precision], y: &mut [Precision]) {
310        self.stats.axpy_count += 1;
311        #[cfg(feature = "simd")]
312        {
313            axpy_simd(alpha, x, y);
314        }
315        #[cfg(not(feature = "simd"))]
316        {
317            for (yi, &xi) in y.iter_mut().zip(x.iter()) {
318                *yi += alpha * xi;
319            }
320        }
321    }
322
323    /// Compute L2 norm of a vector.
324    fn l2_norm(&self, x: &[Precision]) -> Precision {
325        x.iter().map(|&xi| xi * xi).sum::<Precision>().sqrt()
326    }
327
328    /// Get the last iteration count.
329    pub fn get_last_iteration_count(&self) -> usize {
330        self.stats.matvec_count
331    }
332
333    /// Solve with callback for streaming results.
334    pub fn solve_with_callback<F>(
335        &mut self,
336        matrix: &OptimizedSparseMatrix,
337        b: &[Precision],
338        _chunk_size: usize,
339        mut _callback: F,
340    ) -> Result<OptimizedSolverResult, String>
341    where
342        F: FnMut(&OptimizedSolverStats),
343    {
344        // For now, just call the regular solve method
345        // In a full implementation, this would call the callback periodically
346        self.solve(matrix, b)
347    }
348}
349
350impl OptimizedSolverResult {
351    /// Get the solution data.
352    pub fn data(&self) -> &[Precision] {
353        &self.solution
354    }
355}
356
/// Optional tracking switches for the optimized solver.
///
/// NOTE(review): neither flag is read by any code in this module.
#[derive(Debug, Clone, Default)]
pub struct OptimizedSolverOptions {
    /// Turn on detailed performance tracking.
    pub track_performance: bool,
    /// Turn on memory-usage tracking.
    pub track_memory: bool,
}
365
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// The 2x2 symmetric positive definite matrix [[4, 1], [1, 3]].
    fn spd_two_by_two() -> OptimizedSparseMatrix {
        let entries = vec![(0, 0, 4.0), (0, 1, 1.0), (1, 0, 1.0), (1, 1, 3.0)];
        OptimizedSparseMatrix::from_triplets(entries, 2, 2).unwrap()
    }

    #[test]
    fn test_optimized_matrix_creation() {
        let matrix = spd_two_by_two();
        assert_eq!(matrix.dimensions(), (2, 2));
        assert_eq!(matrix.nnz(), 4);
    }

    #[test]
    fn test_optimized_matrix_vector_multiply() {
        let matrix = spd_two_by_two();
        let input = [1.0, 2.0];
        let mut output = [0.0; 2];

        matrix.multiply_vector(&input, &mut output);
        // [4*1 + 1*2, 1*1 + 3*2]
        assert_eq!(output, [6.0, 7.0]);
    }

    #[test]
    fn test_optimized_conjugate_gradient() {
        let matrix = spd_two_by_two();
        let rhs = [1.0, 2.0];
        let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());

        let result = solver.solve(&matrix, &rhs).unwrap();

        assert!(result.converged);
        assert!(result.residual_norm < 1e-6);
        assert!(result.iterations > 0);

        // Plug the solution back in and compare A*x against b.
        let mut product = [0.0; 2];
        matrix.multiply_vector(&result.solution, &mut product);
        let error = product
            .iter()
            .zip(rhs.iter())
            .map(|(lhs, want)| (lhs - want).powi(2))
            .sum::<Precision>()
            .sqrt();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_solver_performance_stats() {
        let matrix = spd_two_by_two();
        let rhs = [1.0, 2.0];
        let mut solver = OptimizedConjugateGradientSolver::new(OptimizedSolverConfig::default());

        let stats = solver.solve(&matrix, &rhs).unwrap().performance_stats;

        assert!(stats.matvec_count > 0);
        assert!(stats.dot_product_count > 0);
        assert!(stats.total_flops > 0);
    }
}