Skip to main content

sublinear_solver/
types.rs

//! Common types and type aliases used throughout the solver.
//!
//! This module defines the fundamental types for numerical computations,
//! graph operations, solver configuration, and execution statistics.
5
6use alloc::{string::String, vec::Vec};
7use core::fmt;
8
/// Identifier of a node (vertex) in graph-based algorithms.
pub type NodeId = u32;

/// Identifier of an edge in graph operations.
pub type EdgeId = u32;

/// Scalar type used for every floating-point quantity in the solver.
///
/// Pinned to `f64` to favour numerical stability; a later version may make
/// this configurable so callers can trade precision for memory.
pub type Precision = f64;

/// Integer type used for element indices and counts.
pub type IndexType = u32;

/// Integer type used to store matrix and vector dimensions.
pub type DimensionType = usize;
26
/// Criterion an iterative solver uses to decide that it has converged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ConvergenceMode {
    /// Absolute residual test: `||Ax - b|| < tolerance`.
    ResidualNorm,
    /// Residual scaled by the right-hand side: `||Ax - b|| / ||b|| < tolerance`.
    RelativeResidual,
    /// Absolute step between iterates: `||x_new - x_old|| < tolerance`.
    SolutionChange,
    /// Step scaled by the previous iterate: `||x_new - x_old|| / ||x_old|| < tolerance`.
    RelativeSolutionChange,
    /// Several criteria applied together (the most conservative choice).
    Combined,
}
42
/// Vector norm used when measuring errors and residuals.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum NormType {
    /// Sum of absolute values (L1).
    L1,
    /// Euclidean length (L2).
    L2,
    /// Largest absolute value (L∞).
    LInfinity,
    /// Norm with caller-supplied per-component weights.
    Weighted,
}
56
57/// Error bounds for approximate solutions.
58#[derive(Debug, Clone, PartialEq)]
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
60pub struct ErrorBounds {
61    /// Lower bound on the true error
62    pub lower_bound: Precision,
63    /// Upper bound on the true error
64    pub upper_bound: Precision,
65    /// Confidence level (0.0 to 1.0) for probabilistic bounds
66    pub confidence: Option<Precision>,
67    /// Method used to compute the bounds
68    pub method: ErrorBoundMethod,
69}
70
/// Technique used to derive an [`ErrorBounds`] interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ErrorBoundMethod {
    /// Bounds derived deterministically from matrix properties.
    Deterministic,
    /// Bounds estimated by random sampling.
    Probabilistic,
    /// Bounds that are tightened as iteration progresses.
    Adaptive,
    /// Bounds obtained from analysing truncation of a Neumann series.
    NeumannTruncation,
}
84
/// Execution statistics collected while a solver runs.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SolverStats {
    /// Wall-clock time of the whole solve, in milliseconds.
    pub total_time_ms: f64,
    /// Portion of the total time spent in matrix operations, in milliseconds.
    pub matrix_ops_time_ms: f64,
    /// Portion of the total time spent checking convergence, in milliseconds.
    pub convergence_check_time_ms: f64,
    /// How many matrix-vector products were computed.
    pub matvec_count: usize,
    /// How many vector operations (add, scale, ...) were performed.
    pub vector_ops_count: usize,
    /// Highest memory footprint observed, in bytes.
    pub peak_memory_bytes: usize,
    /// Cache-miss count, when the platform can report it.
    pub cache_misses: Option<usize>,
    /// Achieved floating-point operations per second, when measured.
    pub flops: Option<f64>,
    /// `true` if SIMD code paths were taken.
    pub simd_used: bool,
    /// Number of worker threads that participated.
    pub thread_count: usize,
}
110
111/// Matrix sparsity pattern information.
112#[derive(Debug, Clone, PartialEq)]
113#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
114pub struct SparsityInfo {
115    /// Total number of non-zero elements
116    pub nnz: usize,
117    /// Matrix dimensions (rows, cols)
118    pub dimensions: (DimensionType, DimensionType),
119    /// Sparsity ratio (nnz / (rows * cols))
120    pub sparsity_ratio: Precision,
121    /// Average number of non-zeros per row
122    pub avg_nnz_per_row: Precision,
123    /// Maximum number of non-zeros in any row
124    pub max_nnz_per_row: usize,
125    /// Bandwidth of the matrix
126    pub bandwidth: Option<usize>,
127    /// Whether the matrix has a banded structure
128    pub is_banded: bool,
129}
130
131/// Graph connectivity information for push algorithms.
132#[derive(Debug, Clone, PartialEq)]
133#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
134pub struct GraphInfo {
135    /// Number of nodes in the graph
136    pub node_count: usize,
137    /// Number of edges in the graph
138    pub edge_count: usize,
139    /// Average degree (edges per node)
140    pub avg_degree: Precision,
141    /// Maximum degree in the graph
142    pub max_degree: usize,
143    /// Graph diameter (longest shortest path)
144    pub diameter: Option<usize>,
145    /// Whether the graph is strongly connected
146    pub is_strongly_connected: bool,
147    /// Number of strongly connected components
148    pub scc_count: usize,
149}
150
151/// Matrix conditioning information.
152#[derive(Debug, Clone, PartialEq)]
153#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
154pub struct ConditioningInfo {
155    /// Estimated condition number
156    pub condition_number: Option<Precision>,
157    /// Whether matrix is diagonally dominant
158    pub is_diagonally_dominant: bool,
159    /// Diagonal dominance factor (minimum ratio)
160    pub diagonal_dominance_factor: Option<Precision>,
161    /// Spectral radius estimate
162    pub spectral_radius: Option<Precision>,
163    /// Whether matrix is positive definite
164    pub is_positive_definite: Option<bool>,
165}
166
167/// Algorithm selection hints based on problem characteristics.
168#[derive(Debug, Clone, PartialEq)]
169#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
170pub struct AlgorithmHints {
171    /// Recommended primary algorithm
172    pub primary_algorithm: String,
173    /// Alternative algorithms in order of preference
174    pub alternative_algorithms: Vec<String>,
175    /// Confidence in the recommendation (0.0 to 1.0)
176    pub confidence: Precision,
177    /// Reasoning for the recommendation
178    pub reasoning: Vec<String>,
179}
180
181/// Update operation for incremental solving.
182#[derive(Debug, Clone, PartialEq)]
183#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
184pub struct DeltaUpdate {
185    /// Indices of updated elements
186    pub indices: Vec<IndexType>,
187    /// New values for the updated elements
188    pub values: Vec<Precision>,
189    /// Timestamp of the update
190    pub timestamp: u64,
191    /// Update sequence number for ordering
192    pub sequence_number: u64,
193}
194
195/// Streaming solution chunk for real-time applications.
196#[derive(Debug, Clone, PartialEq)]
197#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
198pub struct SolutionChunk {
199    /// Iteration number when this chunk was produced
200    pub iteration: usize,
201    /// Partial solution values (sparse representation)
202    pub values: Vec<(IndexType, Precision)>,
203    /// Current residual norm
204    pub residual_norm: Precision,
205    /// Whether the solution has converged
206    pub converged: bool,
207    /// Estimated remaining iterations
208    pub estimated_remaining_iterations: Option<usize>,
209    /// Timestamp when chunk was generated
210    pub timestamp: u64,
211}
212
/// Snapshot of memory consumption and allocator activity.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MemoryInfo {
    /// Bytes in use right now.
    pub current_usage_bytes: usize,
    /// Highest byte count observed.
    pub peak_usage_bytes: usize,
    /// Bytes devoted to matrix storage.
    pub matrix_memory_bytes: usize,
    /// Bytes devoted to vectors.
    pub vector_memory_bytes: usize,
    /// Bytes devoted to temporary workspace.
    pub workspace_memory_bytes: usize,
    /// How many allocations were made.
    pub allocation_count: usize,
    /// How many deallocations were made.
    pub deallocation_count: usize,
}
232
/// Timing profile for one operation.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ProfileData {
    /// Name of the function or a description of the operation.
    pub operation: String,
    /// How many times the operation ran.
    pub call_count: usize,
    /// Accumulated time across all calls, in microseconds.
    pub total_time_us: u64,
    /// Mean time per call, in microseconds.
    pub avg_time_us: f64,
    /// Fastest single call, in microseconds.
    pub min_time_us: u64,
    /// Slowest single call, in microseconds.
    pub max_time_us: u64,
    /// Share of total execution time, as a percentage.
    pub time_percentage: f64,
}
252
253impl ErrorBounds {
254    /// Create error bounds with only an upper bound.
255    pub fn upper_bound_only(upper: Precision, method: ErrorBoundMethod) -> Self {
256        Self {
257            lower_bound: 0.0,
258            upper_bound: upper,
259            confidence: None,
260            method,
261        }
262    }
263
264    /// Create deterministic error bounds.
265    pub fn deterministic(lower: Precision, upper: Precision) -> Self {
266        Self {
267            lower_bound: lower,
268            upper_bound: upper,
269            confidence: None,
270            method: ErrorBoundMethod::Deterministic,
271        }
272    }
273
274    /// Create probabilistic error bounds with confidence level.
275    pub fn probabilistic(lower: Precision, upper: Precision, confidence: Precision) -> Self {
276        Self {
277            lower_bound: lower,
278            upper_bound: upper,
279            confidence: Some(confidence.clamp(0.0, 1.0)),
280            method: ErrorBoundMethod::Probabilistic,
281        }
282    }
283
284    /// Check if the bounds are valid (lower <= upper).
285    pub fn is_valid(&self) -> bool {
286        self.lower_bound <= self.upper_bound && self.lower_bound >= 0.0 && self.upper_bound >= 0.0
287    }
288
289    /// Get the width of the error bounds.
290    pub fn width(&self) -> Precision {
291        self.upper_bound - self.lower_bound
292    }
293
294    /// Get the midpoint of the error bounds.
295    pub fn midpoint(&self) -> Precision {
296        (self.lower_bound + self.upper_bound) / 2.0
297    }
298}
299
300impl SolverStats {
301    /// Create a new empty statistics object.
302    pub fn new() -> Self {
303        Self {
304            total_time_ms: 0.0,
305            matrix_ops_time_ms: 0.0,
306            convergence_check_time_ms: 0.0,
307            matvec_count: 0,
308            vector_ops_count: 0,
309            peak_memory_bytes: 0,
310            cache_misses: None,
311            flops: None,
312            simd_used: false,
313            thread_count: 1,
314        }
315    }
316
317    /// Calculate matrix operations percentage of total time.
318    pub fn matrix_ops_percentage(&self) -> f64 {
319        if self.total_time_ms > 0.0 {
320            (self.matrix_ops_time_ms / self.total_time_ms) * 100.0
321        } else {
322            0.0
323        }
324    }
325
326    /// Calculate convergence checking percentage of total time.
327    pub fn convergence_percentage(&self) -> f64 {
328        if self.total_time_ms > 0.0 {
329            (self.convergence_check_time_ms / self.total_time_ms) * 100.0
330        } else {
331            0.0
332        }
333    }
334}
335
336impl Default for SolverStats {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
342impl SparsityInfo {
343    /// Create sparsity information from basic matrix data.
344    pub fn new(nnz: usize, rows: DimensionType, cols: DimensionType) -> Self {
345        let total_elements = rows * cols;
346        let sparsity_ratio = if total_elements > 0 {
347            nnz as Precision / total_elements as Precision
348        } else {
349            0.0
350        };
351
352        let avg_nnz_per_row = if rows > 0 {
353            nnz as Precision / rows as Precision
354        } else {
355            0.0
356        };
357
358        Self {
359            nnz,
360            dimensions: (rows, cols),
361            sparsity_ratio,
362            avg_nnz_per_row,
363            max_nnz_per_row: 0, // To be computed separately
364            bandwidth: None,
365            is_banded: false,
366        }
367    }
368
369    /// Check if the matrix is considered sparse (< 10% non-zero).
370    pub fn is_sparse(&self) -> bool {
371        self.sparsity_ratio < 0.1
372    }
373
374    /// Check if the matrix is very sparse (< 1% non-zero).
375    pub fn is_very_sparse(&self) -> bool {
376        self.sparsity_ratio < 0.01
377    }
378}
379
380impl fmt::Display for ConvergenceMode {
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        match self {
383            ConvergenceMode::ResidualNorm => write!(f, "residual_norm"),
384            ConvergenceMode::RelativeResidual => write!(f, "relative_residual"),
385            ConvergenceMode::SolutionChange => write!(f, "solution_change"),
386            ConvergenceMode::RelativeSolutionChange => write!(f, "relative_solution_change"),
387            ConvergenceMode::Combined => write!(f, "combined"),
388        }
389    }
390}
391
392impl fmt::Display for NormType {
393    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394        match self {
395            NormType::L1 => write!(f, "L1"),
396            NormType::L2 => write!(f, "L2"),
397            NormType::LInfinity => write!(f, "L∞"),
398            NormType::Weighted => write!(f, "weighted"),
399        }
400    }
401}
402
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    // Explicit import so `.to_string()` is available whether or not the
    // crate root pulls in the std prelude.
    use alloc::string::ToString;

    #[test]
    fn test_error_bounds_validity() {
        let valid_bounds = ErrorBounds::deterministic(1.0, 2.0);
        assert!(valid_bounds.is_valid());
        assert_eq!(valid_bounds.width(), 1.0);
        assert_eq!(valid_bounds.midpoint(), 1.5);

        let invalid_bounds = ErrorBounds {
            lower_bound: 2.0,
            upper_bound: 1.0,
            confidence: None,
            method: ErrorBoundMethod::Deterministic,
        };
        assert!(!invalid_bounds.is_valid());
    }

    #[test]
    fn test_error_bounds_constructors() {
        let upper_only = ErrorBounds::upper_bound_only(3.0, ErrorBoundMethod::NeumannTruncation);
        assert_eq!(upper_only.lower_bound, 0.0);
        assert_eq!(upper_only.upper_bound, 3.0);
        assert_eq!(upper_only.confidence, None);
        assert_eq!(upper_only.method, ErrorBoundMethod::NeumannTruncation);
        assert!(upper_only.is_valid());

        // Confidence levels outside [0, 1] are clamped into range.
        let over = ErrorBounds::probabilistic(0.5, 1.0, 1.5);
        assert_eq!(over.confidence, Some(1.0));
        assert_eq!(over.method, ErrorBoundMethod::Probabilistic);
        let under = ErrorBounds::probabilistic(0.5, 1.0, -0.25);
        assert_eq!(under.confidence, Some(0.0));
    }

    #[test]
    fn test_error_bounds_rejects_negative_and_nan() {
        assert!(!ErrorBounds::deterministic(-1.0, 1.0).is_valid());
        assert!(!ErrorBounds::deterministic(Precision::NAN, 1.0).is_valid());
        assert!(!ErrorBounds::deterministic(0.0, Precision::NAN).is_valid());
    }

    #[test]
    fn test_sparsity_info() {
        let info = SparsityInfo::new(100, 1000, 1000);
        assert_eq!(info.sparsity_ratio, 0.0001);
        assert!(info.is_very_sparse());
        assert!(info.is_sparse());
        assert_eq!(info.avg_nnz_per_row, 0.1);
    }

    #[test]
    fn test_sparsity_info_degenerate_dimensions() {
        let empty = SparsityInfo::new(0, 0, 0);
        assert_eq!(empty.sparsity_ratio, 0.0);
        assert_eq!(empty.avg_nnz_per_row, 0.0);
        assert_eq!(empty.dimensions, (0, 0));

        let no_rows = SparsityInfo::new(5, 0, 10);
        assert_eq!(no_rows.sparsity_ratio, 0.0);
        assert_eq!(no_rows.avg_nnz_per_row, 0.0);
    }

    #[test]
    fn test_sparsity_thresholds() {
        // 5 / 100 = 5%: sparse, but not very sparse.
        let moderately_sparse = SparsityInfo::new(5, 10, 10);
        assert!(moderately_sparse.is_sparse());
        assert!(!moderately_sparse.is_very_sparse());

        // 50 / 100 = 50%: neither.
        let dense = SparsityInfo::new(50, 10, 10);
        assert!(!dense.is_sparse());
        assert!(!dense.is_very_sparse());
    }

    #[test]
    fn test_solver_stats_percentages() {
        let mut stats = SolverStats::new();
        stats.total_time_ms = 100.0;
        stats.matrix_ops_time_ms = 60.0;
        stats.convergence_check_time_ms = 10.0;

        assert_eq!(stats.matrix_ops_percentage(), 60.0);
        assert_eq!(stats.convergence_percentage(), 10.0);
    }

    #[test]
    fn test_solver_stats_defaults() {
        let stats = SolverStats::default();
        assert_eq!(stats, SolverStats::new());
        assert_eq!(stats.thread_count, 1);
        // With no recorded time the percentages fall back to zero instead of dividing by zero.
        assert_eq!(stats.matrix_ops_percentage(), 0.0);
        assert_eq!(stats.convergence_percentage(), 0.0);
    }

    #[test]
    fn test_display_impls() {
        assert_eq!(ConvergenceMode::ResidualNorm.to_string(), "residual_norm");
        assert_eq!(ConvergenceMode::RelativeResidual.to_string(), "relative_residual");
        assert_eq!(ConvergenceMode::SolutionChange.to_string(), "solution_change");
        assert_eq!(
            ConvergenceMode::RelativeSolutionChange.to_string(),
            "relative_solution_change"
        );
        assert_eq!(ConvergenceMode::Combined.to_string(), "combined");

        assert_eq!(NormType::L1.to_string(), "L1");
        assert_eq!(NormType::L2.to_string(), "L2");
        assert_eq!(NormType::LInfinity.to_string(), "L∞");
        assert_eq!(NormType::Weighted.to_string(), "weighted");
    }
}