1use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
/// Errors reported by the convergence-monitoring types in this module.
#[derive(Debug, Error)]
pub enum ConvergenceError {
    /// The configured tolerance failed validation (for example it was zero
    /// or negative); carries the offending value.
    #[error("Invalid tolerance: {0} (must be positive)")]
    InvalidTolerance(f64),
    /// The configured damping factor lies outside `[0, 1]`; carries the
    /// offending value.
    #[error("Invalid damping factor: {0} (must be in [0, 1])")]
    InvalidDamping(f64),
    /// The iteration budget was exhausted; carries an iteration count.
    // NOTE(review): nothing visible in this file constructs this variant —
    // confirm where (or whether) callers raise it.
    #[error("Max iterations reached: {0}")]
    MaxIterationsReached(usize),
}
22
/// Tunable parameters for convergence monitoring.
///
/// Typically built with `ConvergenceConfig::new()` followed by the `with_*`
/// setters, and checked with `validate()` before use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceConfig {
    /// Residual threshold: an iteration counts towards convergence when its
    /// residual is strictly below this value. Must be positive to pass
    /// validation.
    pub tolerance: f64,
    /// Iteration budget: the monitor asks the caller to stop once this many
    /// iterations have been recorded.
    pub max_iterations: usize,
    /// Damping factor a monitor starts from (and returns to on reset).
    /// Must lie in `[0, 1]` to pass validation.
    pub damping_factor: f64,
    /// Number of consecutive below-tolerance iterations required before
    /// convergence is declared.
    pub patience: usize,
}
35
36impl Default for ConvergenceConfig {
37 fn default() -> Self {
38 ConvergenceConfig {
39 tolerance: 1e-6,
40 max_iterations: 100,
41 damping_factor: 0.5,
42 patience: 3,
43 }
44 }
45}
46
47impl ConvergenceConfig {
48 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn with_tolerance(mut self, t: f64) -> Self {
55 self.tolerance = t;
56 self
57 }
58
59 pub fn with_max_iterations(mut self, n: usize) -> Self {
61 self.max_iterations = n;
62 self
63 }
64
65 pub fn with_damping(mut self, d: f64) -> Self {
67 self.damping_factor = d;
68 self
69 }
70
71 pub fn with_patience(mut self, p: usize) -> Self {
73 self.patience = p;
74 self
75 }
76
77 pub fn validate(&self) -> Result<(), ConvergenceError> {
79 if self.tolerance <= 0.0 {
80 return Err(ConvergenceError::InvalidTolerance(self.tolerance));
81 }
82 if !(0.0..=1.0).contains(&self.damping_factor) {
83 return Err(ConvergenceError::InvalidDamping(self.damping_factor));
84 }
85 Ok(())
86 }
87}
88
/// Strategy used to pick the damping factor for each iteration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DampingSchedule {
    /// Always use the wrapped damping value, regardless of iteration or
    /// residuals.
    Fixed(f64),
    /// Interpolate linearly from `start` to `end`.
    Linear {
        /// Damping at iteration 0 (also returned when `total_steps` is 0).
        start: f64,
        /// Damping reached at `total_steps` and held afterwards.
        end: f64,
        /// Number of iterations over which the interpolation runs.
        total_steps: usize,
    },
    /// Geometric schedule: `initial * decay^iteration`.
    Exponential {
        /// Damping at iteration 0.
        initial: f64,
        /// Multiplicative factor applied once per iteration.
        decay: f64,
    },
    /// Adjust the current damping based on whether the residual grew.
    Adaptive {
        /// Lower bound the damping is not reduced below.
        base: f64,
        /// Amount added to the damping when the residual increased
        /// (the result is capped at 0.99).
        increase_rate: f64,
        /// Amount subtracted from the damping when the residual did not
        /// increase.
        decrease_rate: f64,
    },
}
120
121impl DampingSchedule {
122 pub fn get_damping(
130 &self,
131 iteration: usize,
132 prev_residual: Option<f64>,
133 curr_residual: Option<f64>,
134 current_damping: f64,
135 ) -> f64 {
136 match self {
137 DampingSchedule::Fixed(d) => *d,
138 DampingSchedule::Linear {
139 start,
140 end,
141 total_steps,
142 } => {
143 if *total_steps == 0 {
144 return *start;
145 }
146 let frac = (iteration as f64 / *total_steps as f64).min(1.0);
147 start + frac * (end - start)
148 }
149 DampingSchedule::Exponential { initial, decay } => {
150 initial * decay.powi(iteration as i32)
151 }
152 DampingSchedule::Adaptive {
153 base,
154 increase_rate,
155 decrease_rate,
156 } => match (prev_residual, curr_residual) {
157 (Some(prev), Some(curr)) if curr > prev => {
158 (current_damping + increase_rate).min(0.99)
160 }
161 (Some(_prev), Some(_curr)) => {
162 (current_damping - decrease_rate).max(*base)
164 }
165 _ => current_damping,
166 },
167 }
168 }
169}
170
/// Bookkeeping for an iterative solve: counters, status flags, and the
/// per-iteration residual and damping histories.
///
/// The default value is the "nothing recorded yet" state: zero counters,
/// cleared flags, and empty histories.
#[derive(Debug, Clone, Default)]
pub struct ConvergenceState {
    /// Number of iterations recorded so far.
    pub iteration: usize,
    /// Set once convergence has been declared.
    pub converged: bool,
    /// Set once divergence has been detected.
    pub diverged: bool,
    /// Residuals in the order they were recorded.
    pub residual_history: Vec<f64>,
    /// Damping factors in the order they were recorded.
    pub damping_history: Vec<f64>,
    /// Length of the current run of converged iterations.
    pub consecutive_converged: usize,
}

impl ConvergenceState {
    /// Creates an empty state (identical to `ConvergenceState::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the most recently recorded residual, or `None` when the
    /// history is empty.
    pub fn latest_residual(&self) -> Option<f64> {
        let history = &self.residual_history;
        history.last().copied()
    }

    /// Returns the ratio of the last residual to the one before it.
    ///
    /// Yields `None` while fewer than two residuals are recorded, and
    /// `Some(0.0)` when the earlier residual is not greater than `1e-15`
    /// (which avoids dividing by a vanishing value).
    pub fn convergence_rate(&self) -> Option<f64> {
        match self.residual_history.as_slice() {
            [.., previous, latest] => {
                let ratio = if *previous > 1e-15 {
                    *latest / *previous
                } else {
                    0.0
                };
                Some(ratio)
            }
            _ => None,
        }
    }
}
229
/// Tracks an iterative solve: records residuals, updates the damping factor
/// according to a schedule, and decides when the caller should stop.
pub struct ConvergenceMonitor {
    /// Tolerance, patience, iteration budget, and initial damping.
    config: ConvergenceConfig,
    /// Rule used to derive the damping factor after each recorded iteration.
    schedule: DampingSchedule,
    /// Counters, flags, and histories accumulated so far.
    state: ConvergenceState,
    /// Damping factor currently in effect.
    current_damping: f64,
}
239
240impl ConvergenceMonitor {
241 pub fn new(
243 config: ConvergenceConfig,
244 schedule: DampingSchedule,
245 ) -> Result<Self, ConvergenceError> {
246 config.validate()?;
247 let initial_damping = config.damping_factor;
248 Ok(ConvergenceMonitor {
249 config,
250 schedule,
251 state: ConvergenceState::new(),
252 current_damping: initial_damping,
253 })
254 }
255
256 pub fn with_default_config() -> Self {
258 let config = ConvergenceConfig::default();
259 let damping = config.damping_factor;
260 let schedule = DampingSchedule::Fixed(damping);
261 ConvergenceMonitor {
262 config,
263 schedule,
264 state: ConvergenceState::new(),
265 current_damping: damping,
266 }
267 }
268
269 pub fn record_iteration(&mut self, residual: f64) -> bool {
274 let prev_residual = self.state.latest_residual();
275 self.state.iteration += 1;
276 self.state.residual_history.push(residual);
277
278 self.current_damping = self.schedule.get_damping(
280 self.state.iteration,
281 prev_residual,
282 Some(residual),
283 self.current_damping,
284 );
285 self.state.damping_history.push(self.current_damping);
286
287 if residual < self.config.tolerance {
289 self.state.consecutive_converged += 1;
290 if self.state.consecutive_converged >= self.config.patience {
291 self.state.converged = true;
292 return false;
293 }
294 } else {
295 self.state.consecutive_converged = 0;
296 }
297
298 if self.state.residual_history.len() >= 5 {
300 let recent = &self.state.residual_history[self.state.residual_history.len() - 5..];
301 let diverging = recent.windows(2).all(|w| w[1] > w[0]);
302 if diverging {
303 self.state.diverged = true;
304 return false;
305 }
306 }
307
308 if self.state.iteration >= self.config.max_iterations {
310 return false;
311 }
312
313 true
314 }
315
316 pub fn current_damping(&self) -> f64 {
318 self.current_damping
319 }
320
321 pub fn state(&self) -> &ConvergenceState {
323 &self.state
324 }
325
326 pub fn is_converged(&self) -> bool {
328 self.state.converged
329 }
330
331 pub fn is_diverged(&self) -> bool {
333 self.state.diverged
334 }
335
336 pub fn iteration(&self) -> usize {
338 self.state.iteration
339 }
340
341 pub fn reset(&mut self) {
343 self.state = ConvergenceState::new();
344 self.current_damping = self.config.damping_factor;
345 }
346
347 pub fn stats(&self) -> InferenceStats {
349 InferenceStats {
350 total_iterations: self.state.iteration,
351 final_residual: self.state.latest_residual().unwrap_or(f64::NAN),
352 converged: self.state.converged,
353 diverged: self.state.diverged,
354 convergence_rate: self.state.convergence_rate(),
355 final_damping: self.current_damping,
356 }
357 }
358}
359
/// Serializable summary of a monitored run, as produced by
/// `ConvergenceMonitor::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceStats {
    /// Number of iterations recorded.
    pub total_iterations: usize,
    /// Last recorded residual; NaN when no iteration was recorded.
    pub final_residual: f64,
    /// Whether convergence was declared.
    pub converged: bool,
    /// Whether divergence was detected.
    pub diverged: bool,
    /// Ratio of the last two residuals, when at least two were recorded.
    pub convergence_rate: Option<f64>,
    /// Damping factor in effect when the summary was taken.
    pub final_damping: f64,
}
376
// Unit tests for the configuration, damping schedules, state bookkeeping,
// and the monitor's stop conditions.
#[cfg(test)]
mod tests {
    use super::*;

    // The default configuration carries the documented stock values.
    #[test]
    fn test_config_default() {
        let config = ConvergenceConfig::default();
        assert!((config.tolerance - 1e-6).abs() < 1e-15);
        assert_eq!(config.max_iterations, 100);
        assert!((config.damping_factor - 0.5).abs() < 1e-15);
        assert_eq!(config.patience, 3);
    }

    // The default configuration passes validation.
    #[test]
    fn test_config_validate_good() {
        let config = ConvergenceConfig::default();
        assert!(config.validate().is_ok());
    }

    // A zero tolerance is rejected with `InvalidTolerance`.
    #[test]
    fn test_config_validate_bad_tolerance() {
        let config = ConvergenceConfig::new().with_tolerance(0.0);
        let err = config.validate().unwrap_err();
        assert!(matches!(err, ConvergenceError::InvalidTolerance(_)));
    }

    // A damping factor above 1 is rejected with `InvalidDamping`.
    #[test]
    fn test_config_validate_bad_damping() {
        let config = ConvergenceConfig::new().with_damping(2.0);
        let err = config.validate().unwrap_err();
        assert!(matches!(err, ConvergenceError::InvalidDamping(_)));
    }

    // Each builder setter overwrites exactly its own field.
    #[test]
    fn test_config_builder() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-4)
            .with_max_iterations(50)
            .with_damping(0.3)
            .with_patience(5);
        assert!((config.tolerance - 1e-4).abs() < 1e-15);
        assert_eq!(config.max_iterations, 50);
        assert!((config.damping_factor - 0.3).abs() < 1e-15);
        assert_eq!(config.patience, 5);
    }

    // A fixed schedule ignores iteration, residuals, and current damping.
    #[test]
    fn test_damping_fixed() {
        let schedule = DampingSchedule::Fixed(0.7);
        assert!((schedule.get_damping(0, None, None, 0.5) - 0.7).abs() < 1e-15);
        assert!((schedule.get_damping(10, Some(0.1), Some(0.05), 0.5) - 0.7).abs() < 1e-15);
        assert!((schedule.get_damping(100, None, None, 0.9) - 0.7).abs() < 1e-15);
    }

    // A linear schedule interpolates from start to end and then holds.
    #[test]
    fn test_damping_linear() {
        let schedule = DampingSchedule::Linear {
            start: 0.8,
            end: 0.2,
            total_steps: 10,
        };
        // Iteration 0: the start value.
        assert!((schedule.get_damping(0, None, None, 0.0) - 0.8).abs() < 1e-15);
        // Halfway: the midpoint.
        assert!((schedule.get_damping(5, None, None, 0.0) - 0.5).abs() < 1e-15);
        // At total_steps: the end value.
        assert!((schedule.get_damping(10, None, None, 0.0) - 0.2).abs() < 1e-15);
        // Past total_steps: clamped at the end value.
        assert!((schedule.get_damping(20, None, None, 0.0) - 0.2).abs() < 1e-15);
    }

    // An exponential schedule multiplies by `decay` once per iteration.
    #[test]
    fn test_damping_exponential() {
        let schedule = DampingSchedule::Exponential {
            initial: 1.0,
            decay: 0.5,
        };
        // 1.0 * 0.5^0
        assert!((schedule.get_damping(0, None, None, 0.0) - 1.0).abs() < 1e-15);
        // 1.0 * 0.5^1
        assert!((schedule.get_damping(1, None, None, 0.0) - 0.5).abs() < 1e-15);
        // 1.0 * 0.5^2
        assert!((schedule.get_damping(2, None, None, 0.0) - 0.25).abs() < 1e-15);
    }

    // When the residual grows, the adaptive schedule adds `increase_rate`.
    #[test]
    fn test_damping_adaptive_increases_on_diverge() {
        let schedule = DampingSchedule::Adaptive {
            base: 0.1,
            increase_rate: 0.1,
            decrease_rate: 0.05,
        };
        let result = schedule.get_damping(1, Some(0.5), Some(0.8), 0.4);
        // 0.4 + 0.1
        assert!((result - 0.5).abs() < 1e-15);
    }

    // When the residual shrinks, the adaptive schedule subtracts
    // `decrease_rate`.
    #[test]
    fn test_damping_adaptive_decreases_on_converge() {
        let schedule = DampingSchedule::Adaptive {
            base: 0.1,
            increase_rate: 0.1,
            decrease_rate: 0.05,
        };
        let result = schedule.get_damping(1, Some(0.8), Some(0.5), 0.4);
        // 0.4 - 0.05
        assert!((result - 0.35).abs() < 1e-15);
    }

    // With patience 2, the second consecutive below-tolerance residual
    // stops the run and flags convergence.
    #[test]
    fn test_monitor_converges() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(2);
        let monitor_result = ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.5));
        assert!(monitor_result.is_ok());
        let mut monitor = monitor_result.expect("valid config");

        // Above tolerance: keep going.
        assert!(monitor.record_iteration(1.0));
        assert!(monitor.record_iteration(0.1));
        // First below-tolerance residual: patience not yet met.
        assert!(monitor.record_iteration(0.0009));
        // Second in a row: converged, stop.
        assert!(!monitor.record_iteration(0.0005));

        assert!(monitor.is_converged());
        assert!(!monitor.is_diverged());
    }

    // An above-tolerance residual resets the streak, so three fresh
    // below-tolerance residuals in a row are needed afterwards.
    #[test]
    fn test_monitor_patience() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(3);
        let mut monitor =
            ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.5)).expect("valid config");

        // Streak of 1, then 2.
        assert!(monitor.record_iteration(0.0001));
        assert!(monitor.record_iteration(0.0002));
        // Above tolerance: streak resets to 0.
        assert!(monitor.record_iteration(0.01));
        // Streak of 1, then 2.
        assert!(monitor.record_iteration(0.0001));
        assert!(monitor.record_iteration(0.0002));
        // Streak of 3: converged, stop.
        assert!(!monitor.record_iteration(0.0003));
        assert!(monitor.is_converged());
    }

    // Reaching the iteration budget stops the run without flagging
    // convergence.
    #[test]
    fn test_monitor_max_iterations() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-10)
            .with_max_iterations(5);
        let mut monitor =
            ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.5)).expect("valid config");

        // Four decreasing residuals, all above tolerance.
        for i in 0..4 {
            let residual = 1.0 / (i as f64 + 1.0);
            assert!(monitor.record_iteration(residual), "iteration {i}");
        }
        // Fifth iteration exhausts the budget.
        assert!(!monitor.record_iteration(0.1));
        assert!(!monitor.is_converged());
        assert_eq!(monitor.iteration(), 5);
    }

    // Five strictly increasing residuals in a row flag divergence.
    #[test]
    fn test_monitor_diverge_detection() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-10)
            .with_max_iterations(100);
        let mut monitor =
            ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.5)).expect("valid config");

        // Fewer than five residuals recorded: no divergence verdict yet.
        assert!(monitor.record_iteration(1.0));
        assert!(monitor.record_iteration(2.0));
        assert!(monitor.record_iteration(3.0));
        assert!(monitor.record_iteration(4.0));
        // Fifth strictly increasing residual: diverged, stop.
        assert!(!monitor.record_iteration(5.0));

        assert!(monitor.is_diverged());
        assert!(!monitor.is_converged());
    }

    // `reset` clears flags, counters, and histories.
    #[test]
    fn test_monitor_reset() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(1);
        let mut monitor =
            ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.5)).expect("valid config");

        // Patience 1: a single below-tolerance residual converges.
        assert!(!monitor.record_iteration(0.0001));
        assert!(monitor.is_converged());
        assert_eq!(monitor.iteration(), 1);

        monitor.reset();
        // Everything is back to the initial state.
        assert!(!monitor.is_converged());
        assert!(!monitor.is_diverged());
        assert_eq!(monitor.iteration(), 0);
        assert!(monitor.state().residual_history.is_empty());
    }

    // `stats` reflects the recorded run.
    #[test]
    fn test_monitor_stats() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(2);
        let mut monitor =
            ConvergenceMonitor::new(config, DampingSchedule::Fixed(0.3)).expect("valid config");

        monitor.record_iteration(0.5);
        monitor.record_iteration(0.0001);
        monitor.record_iteration(0.00005);

        let stats = monitor.stats();
        assert_eq!(stats.total_iterations, 3);
        assert!((stats.final_residual - 0.00005).abs() < 1e-15);
        assert!(stats.converged);
        assert!(!stats.diverged);
        // The fixed schedule's value, not the config's default damping.
        assert!((stats.final_damping - 0.3).abs() < 1e-15);
        assert!(stats.convergence_rate.is_some());
    }

    // The rate needs two residuals and is their ratio.
    #[test]
    fn test_convergence_rate() {
        let mut state = ConvergenceState::new();
        // No residuals yet.
        assert!(state.convergence_rate().is_none());

        state.residual_history.push(1.0);
        // A single residual is still not enough.
        assert!(state.convergence_rate().is_none());

        state.residual_history.push(0.5);
        // 0.5 / 1.0
        let rate = state.convergence_rate().expect("should have rate");
        assert!((rate - 0.5).abs() < 1e-15);
    }

    // The default state is empty with cleared flags.
    #[test]
    fn test_state_default() {
        let state = ConvergenceState::default();
        assert_eq!(state.iteration, 0);
        assert!(!state.converged);
        assert!(!state.diverged);
        assert!(state.residual_history.is_empty());
        assert!(state.damping_history.is_empty());
        assert_eq!(state.consecutive_converged, 0);
    }
}