//! Types and configuration for learned smoothers.
//!
//! Defines configuration structures, smoother weight storage, training
//! parameters, and convergence metrics used across all smoother variants.

use crate::error::SparseResult;

// ---------------------------------------------------------------------------
// Smoother type selection
// ---------------------------------------------------------------------------

/// Selects which learned smoother variant is used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum SmootherType {
    /// Parametric linear smoother: x_{k+1} = x_k + W · r_k
    #[default]
    Linear,
    /// Per-node 2-layer MLP with shared weights (GNN-style).
    MLP,
    /// Chebyshev polynomial smoother with learned coefficients.
    Chebyshev,
}

// ---------------------------------------------------------------------------
// Main configuration
// ---------------------------------------------------------------------------

30/// Configuration for a learned smoother.
31#[derive(Debug, Clone)]
32pub struct LearnedSmootherConfig {
33 /// Which smoother variant to use.
34 pub smoother_type: SmootherType,
35 /// Learning rate for weight updates.
36 pub learning_rate: f64,
37 /// Maximum number of training steps.
38 pub max_training_steps: usize,
39 /// Convergence tolerance for training.
40 pub convergence_tol: f64,
41 /// Relaxation parameter (omega) for initialisation.
42 pub omega: f64,
43 /// Number of pre-smoothing sweeps in a V-cycle.
44 pub pre_sweeps: usize,
45 /// Number of post-smoothing sweeps in a V-cycle.
46 pub post_sweeps: usize,
47}
49impl Default for LearnedSmootherConfig {
50 fn default() -> Self {
51 Self {
52 smoother_type: SmootherType::Linear,
53 learning_rate: 0.01,
54 max_training_steps: 1000,
55 convergence_tol: 1e-6,
56 omega: 2.0 / 3.0,
57 pre_sweeps: 2,
58 post_sweeps: 2,
59 }
60 }
61}

// ---------------------------------------------------------------------------
// Training configuration
// ---------------------------------------------------------------------------

/// Training hyper-parameters for gradient-descent-based weight learning.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Learning rate for SGD / Adam.
    pub learning_rate: f64,
    /// Maximum number of training epochs.
    pub max_epochs: usize,
    /// Mini-batch size, i.e. the number of right-hand-side vectors per step.
    pub batch_size: usize,
    /// Early-stopping tolerance on the decrease of the loss.
    pub convergence_tol: f64,
    /// Momentum coefficient; zero means vanilla SGD.
    pub momentum: f64,
}
82impl Default for TrainingConfig {
83 fn default() -> Self {
84 Self {
85 learning_rate: 0.01,
86 max_epochs: 100,
87 batch_size: 16,
88 convergence_tol: 1e-8,
89 momentum: 0.0,
90 }
91 }
92}

// ---------------------------------------------------------------------------
// Smoother weights
// ---------------------------------------------------------------------------

/// Stores the trainable parameters of a learned smoother.
///
/// For a linear smoother the weights are a diagonal (or dense) matrix W;
/// for an MLP smoother the weights are the per-layer parameters.
#[derive(Debug, Clone)]
pub struct SmootherWeights {
    /// Layer-wise weight matrices, each stored row-major as a flat vector.
    pub matrices: Vec<Vec<f64>>,
    /// Layer-wise bias vectors, parallel to `matrices`.
    pub biases: Vec<Vec<f64>>,
    /// `(input_dim, output_dim)` of each layer, parallel to `matrices`.
    pub layer_dims: Vec<(usize, usize)>,
}
112impl SmootherWeights {
113 /// Create an empty weight container with no layers.
114 pub fn empty() -> Self {
115 Self {
116 matrices: Vec::new(),
117 biases: Vec::new(),
118 layer_dims: Vec::new(),
119 }
120 }
121
122 /// Create weights for a single diagonal layer of size `n`.
123 pub fn diagonal(diag: Vec<f64>) -> Self {
124 let n = diag.len();
125 Self {
126 matrices: vec![diag],
127 biases: vec![vec![0.0; n]],
128 layer_dims: vec![(n, n)],
129 }
130 }
131
132 /// Total number of trainable scalar parameters.
133 pub fn num_parameters(&self) -> usize {
134 let mat_params: usize = self.matrices.iter().map(|m| m.len()).sum();
135 let bias_params: usize = self.biases.iter().map(|b| b.len()).sum();
136 mat_params + bias_params
137 }
138}

// ---------------------------------------------------------------------------
// Convergence metrics
// ---------------------------------------------------------------------------

/// Metrics produced after a smoothing / solve run.
#[derive(Debug, Clone)]
pub struct SmootherMetrics {
    /// Estimated spectral radius of the error-propagation operator (I - WA).
    pub spectral_radius_reduction: f64,
    /// Average per-iteration convergence factor ‖e_{k+1}‖ / ‖e_k‖.
    pub convergence_factor: f64,
    /// Residual-norm ratio ‖r_final‖ / ‖r_0‖.
    pub residual_reduction: f64,
    /// Training loss recorded per epoch (or per step).
    pub training_loss_history: Vec<f64>,
}
157impl SmootherMetrics {
158 /// Create a default/empty metrics struct.
159 pub fn new() -> Self {
160 Self {
161 spectral_radius_reduction: 1.0,
162 convergence_factor: 1.0,
163 residual_reduction: 1.0,
164 training_loss_history: Vec::new(),
165 }
166 }
167}
169impl Default for SmootherMetrics {
170 fn default() -> Self {
171 Self::new()
172 }
173}

// ---------------------------------------------------------------------------
// Smoother trait
// ---------------------------------------------------------------------------

/// Trait for smoothers that can be plugged into a multigrid cycle.
pub trait Smoother {
    /// Apply `n_sweeps` smoothing iterations, updating `x` in place.
    ///
    /// The system matrix A is supplied in raw CSR form: `a_values` holds the
    /// nonzero entries, `a_row_ptr` the per-row offsets into `a_values` /
    /// `a_col_idx`, and `a_col_idx` the column index of each nonzero.
    /// `b` is the right-hand side of A x = b.
    ///
    /// # Errors
    /// Returns a `SparseResult` error on failure; the concrete conditions are
    /// implementation-defined (not visible from this trait declaration).
    fn smooth(
        &self,
        a_values: &[f64],
        a_row_ptr: &[usize],
        a_col_idx: &[usize],
        x: &mut [f64],
        b: &[f64],
        n_sweeps: usize,
    ) -> SparseResult<()>;

    /// Perform one training step and return the loss.
    ///
    /// Takes the same CSR matrix and right-hand side as [`Smoother::smooth`];
    /// `x` is the current iterate and may be modified in place. By its name,
    /// `x_exact` is the reference solution the loss is measured against, and
    /// `lr` the learning rate for this step — confirm against implementors.
    fn train_step(
        &mut self,
        a_values: &[f64],
        a_row_ptr: &[usize],
        a_col_idx: &[usize],
        x: &mut [f64],
        b: &[f64],
        x_exact: &[f64],
        lr: f64,
    ) -> SparseResult<f64>;
}
205}