Skip to main content

tensorlogic_quantrs_hooks/
convergence.rs

1//! Convergence monitoring and adaptive damping for iterative inference.
2//!
3//! Tracks message residuals during belief propagation iterations,
4//! detects convergence/divergence, and provides adaptive damping schedules.
5
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9/// Errors related to convergence monitoring configuration and execution.
10#[derive(Debug, Error)]
11pub enum ConvergenceError {
12    /// Tolerance must be a positive value.
13    #[error("Invalid tolerance: {0} (must be positive)")]
14    InvalidTolerance(f64),
15    /// Damping factor must be within the range [0, 1].
16    #[error("Invalid damping factor: {0} (must be in [0, 1])")]
17    InvalidDamping(f64),
18    /// The algorithm did not converge within the allowed iterations.
19    #[error("Max iterations reached: {0}")]
20    MaxIterationsReached(usize),
21}
22
23/// Configuration for convergence monitoring.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ConvergenceConfig {
26    /// Convergence tolerance (residual below this means converged).
27    pub tolerance: f64,
28    /// Maximum iterations before declaring non-convergence.
29    pub max_iterations: usize,
30    /// Initial damping factor (0 = no damping, 1 = full damping).
31    pub damping_factor: f64,
32    /// Number of consecutive converged iterations before declaring convergence.
33    pub patience: usize,
34}
35
36impl Default for ConvergenceConfig {
37    fn default() -> Self {
38        ConvergenceConfig {
39            tolerance: 1e-6,
40            max_iterations: 100,
41            damping_factor: 0.5,
42            patience: 3,
43        }
44    }
45}
46
47impl ConvergenceConfig {
48    /// Create a new configuration with default values.
49    pub fn new() -> Self {
50        Self::default()
51    }
52
53    /// Set the convergence tolerance.
54    pub fn with_tolerance(mut self, t: f64) -> Self {
55        self.tolerance = t;
56        self
57    }
58
59    /// Set the maximum number of iterations.
60    pub fn with_max_iterations(mut self, n: usize) -> Self {
61        self.max_iterations = n;
62        self
63    }
64
65    /// Set the initial damping factor.
66    pub fn with_damping(mut self, d: f64) -> Self {
67        self.damping_factor = d;
68        self
69    }
70
71    /// Set the patience (consecutive converged iterations required).
72    pub fn with_patience(mut self, p: usize) -> Self {
73        self.patience = p;
74        self
75    }
76
77    /// Validate the configuration parameters.
78    pub fn validate(&self) -> Result<(), ConvergenceError> {
79        if self.tolerance <= 0.0 {
80            return Err(ConvergenceError::InvalidTolerance(self.tolerance));
81        }
82        if !(0.0..=1.0).contains(&self.damping_factor) {
83            return Err(ConvergenceError::InvalidDamping(self.damping_factor));
84        }
85        Ok(())
86    }
87}
88
89/// Damping schedule types for controlling how damping evolves over iterations.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub enum DampingSchedule {
92    /// Fixed damping throughout all iterations.
93    Fixed(f64),
94    /// Linear interpolation from `start` to `end` over `total_steps`.
95    Linear {
96        /// Starting damping value.
97        start: f64,
98        /// Ending damping value.
99        end: f64,
100        /// Total steps over which to interpolate.
101        total_steps: usize,
102    },
103    /// Exponential decay: `initial * decay^iteration`.
104    Exponential {
105        /// Initial damping value.
106        initial: f64,
107        /// Decay rate per iteration.
108        decay: f64,
109    },
110    /// Increase damping when residual grows, decrease when stable.
111    Adaptive {
112        /// Minimum damping floor.
113        base: f64,
114        /// Rate at which damping increases on divergence.
115        increase_rate: f64,
116        /// Rate at which damping decreases on convergence.
117        decrease_rate: f64,
118    },
119}
120
121impl DampingSchedule {
122    /// Get damping factor for a given iteration and optional residual information.
123    ///
124    /// # Arguments
125    /// * `iteration` - Current iteration number
126    /// * `prev_residual` - Residual from the previous iteration (if available)
127    /// * `curr_residual` - Residual from the current iteration (if available)
128    /// * `current_damping` - The current damping factor (used by adaptive schedule)
129    pub fn get_damping(
130        &self,
131        iteration: usize,
132        prev_residual: Option<f64>,
133        curr_residual: Option<f64>,
134        current_damping: f64,
135    ) -> f64 {
136        match self {
137            DampingSchedule::Fixed(d) => *d,
138            DampingSchedule::Linear {
139                start,
140                end,
141                total_steps,
142            } => {
143                if *total_steps == 0 {
144                    return *start;
145                }
146                let frac = (iteration as f64 / *total_steps as f64).min(1.0);
147                start + frac * (end - start)
148            }
149            DampingSchedule::Exponential { initial, decay } => {
150                initial * decay.powi(iteration as i32)
151            }
152            DampingSchedule::Adaptive {
153                base,
154                increase_rate,
155                decrease_rate,
156            } => match (prev_residual, curr_residual) {
157                (Some(prev), Some(curr)) if curr > prev => {
158                    // Diverging: increase damping
159                    (current_damping + increase_rate).min(0.99)
160                }
161                (Some(_prev), Some(_curr)) => {
162                    // Converging: decrease damping toward base
163                    (current_damping - decrease_rate).max(*base)
164                }
165                _ => current_damping,
166            },
167        }
168    }
169}
170
/// Mutable bookkeeping for a single convergence-tracked run.
#[derive(Debug, Clone)]
pub struct ConvergenceState {
    /// Number of iterations recorded so far.
    pub iteration: usize,
    /// Set once the run has been judged converged.
    pub converged: bool,
    /// Set once the run has been judged diverged.
    pub diverged: bool,
    /// Residual recorded at each iteration, oldest first.
    pub residual_history: Vec<f64>,
    /// Damping factor in effect after each iteration, oldest first.
    pub damping_history: Vec<f64>,
    /// Length of the current streak of iterations whose residual was below tolerance.
    pub consecutive_converged: usize,
}
187
188impl ConvergenceState {
189    /// Create a fresh convergence state.
190    pub fn new() -> Self {
191        ConvergenceState {
192            iteration: 0,
193            converged: false,
194            diverged: false,
195            residual_history: Vec::new(),
196            damping_history: Vec::new(),
197            consecutive_converged: 0,
198        }
199    }
200
201    /// Return the most recent residual value, if any.
202    pub fn latest_residual(&self) -> Option<f64> {
203        self.residual_history.last().copied()
204    }
205
206    /// Compute the convergence rate as the ratio of the last two residuals.
207    ///
208    /// Returns `None` if fewer than two residuals have been recorded.
209    pub fn convergence_rate(&self) -> Option<f64> {
210        if self.residual_history.len() < 2 {
211            return None;
212        }
213        let n = self.residual_history.len();
214        let r0 = self.residual_history[n - 2];
215        let r1 = self.residual_history[n - 1];
216        if r0 > 1e-15 {
217            Some(r1 / r0)
218        } else {
219            Some(0.0)
220        }
221    }
222}
223
224impl Default for ConvergenceState {
225    fn default() -> Self {
226        Self::new()
227    }
228}
229
230/// Monitors convergence of iterative algorithms such as belief propagation.
231///
232/// Tracks residuals, manages damping schedules, and detects convergence or divergence.
233pub struct ConvergenceMonitor {
234    config: ConvergenceConfig,
235    schedule: DampingSchedule,
236    state: ConvergenceState,
237    current_damping: f64,
238}
239
240impl ConvergenceMonitor {
241    /// Create a new convergence monitor with the given configuration and schedule.
242    pub fn new(
243        config: ConvergenceConfig,
244        schedule: DampingSchedule,
245    ) -> Result<Self, ConvergenceError> {
246        config.validate()?;
247        let initial_damping = config.damping_factor;
248        Ok(ConvergenceMonitor {
249            config,
250            schedule,
251            state: ConvergenceState::new(),
252            current_damping: initial_damping,
253        })
254    }
255
256    /// Create a monitor with default configuration and fixed damping.
257    pub fn with_default_config() -> Self {
258        let config = ConvergenceConfig::default();
259        let damping = config.damping_factor;
260        let schedule = DampingSchedule::Fixed(damping);
261        ConvergenceMonitor {
262            config,
263            schedule,
264            state: ConvergenceState::new(),
265            current_damping: damping,
266        }
267    }
268
269    /// Record a new iteration with its residual.
270    ///
271    /// Returns `true` if the algorithm should continue iterating,
272    /// `false` if converged, diverged, or max iterations reached.
273    pub fn record_iteration(&mut self, residual: f64) -> bool {
274        let prev_residual = self.state.latest_residual();
275        self.state.iteration += 1;
276        self.state.residual_history.push(residual);
277
278        // Update damping according to schedule
279        self.current_damping = self.schedule.get_damping(
280            self.state.iteration,
281            prev_residual,
282            Some(residual),
283            self.current_damping,
284        );
285        self.state.damping_history.push(self.current_damping);
286
287        // Check convergence: residual below tolerance
288        if residual < self.config.tolerance {
289            self.state.consecutive_converged += 1;
290            if self.state.consecutive_converged >= self.config.patience {
291                self.state.converged = true;
292                return false;
293            }
294        } else {
295            self.state.consecutive_converged = 0;
296        }
297
298        // Check divergence: residual growing for 5+ consecutive iterations
299        if self.state.residual_history.len() >= 5 {
300            let recent = &self.state.residual_history[self.state.residual_history.len() - 5..];
301            let diverging = recent.windows(2).all(|w| w[1] > w[0]);
302            if diverging {
303                self.state.diverged = true;
304                return false;
305            }
306        }
307
308        // Check max iterations
309        if self.state.iteration >= self.config.max_iterations {
310            return false;
311        }
312
313        true
314    }
315
316    /// Get the current damping factor.
317    pub fn current_damping(&self) -> f64 {
318        self.current_damping
319    }
320
321    /// Get a reference to the current convergence state.
322    pub fn state(&self) -> &ConvergenceState {
323        &self.state
324    }
325
326    /// Check if the algorithm has converged.
327    pub fn is_converged(&self) -> bool {
328        self.state.converged
329    }
330
331    /// Check if the algorithm has diverged.
332    pub fn is_diverged(&self) -> bool {
333        self.state.diverged
334    }
335
336    /// Get the current iteration count.
337    pub fn iteration(&self) -> usize {
338        self.state.iteration
339    }
340
341    /// Reset the monitor to its initial state.
342    pub fn reset(&mut self) {
343        self.state = ConvergenceState::new();
344        self.current_damping = self.config.damping_factor;
345    }
346
347    /// Get summary statistics from the inference run.
348    pub fn stats(&self) -> InferenceStats {
349        InferenceStats {
350            total_iterations: self.state.iteration,
351            final_residual: self.state.latest_residual().unwrap_or(f64::NAN),
352            converged: self.state.converged,
353            diverged: self.state.diverged,
354            convergence_rate: self.state.convergence_rate(),
355            final_damping: self.current_damping,
356        }
357    }
358}
359
360/// Summary statistics from an inference run.
361#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct InferenceStats {
363    /// Total number of iterations executed.
364    pub total_iterations: usize,
365    /// The residual value at the final iteration.
366    pub final_residual: f64,
367    /// Whether the algorithm converged.
368    pub converged: bool,
369    /// Whether the algorithm diverged.
370    pub diverged: bool,
371    /// Convergence rate (ratio of last two residuals), if available.
372    pub convergence_rate: Option<f64>,
373    /// The damping factor at the final iteration.
374    pub final_damping: f64,
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-15;

    /// Panic unless `actual` lies within [`EPS`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    /// Build a monitor over `config` with a fixed damping schedule of `damping`.
    fn fixed_monitor(config: ConvergenceConfig, damping: f64) -> ConvergenceMonitor {
        ConvergenceMonitor::new(config, DampingSchedule::Fixed(damping)).expect("valid config")
    }

    /// The adaptive schedule shared by the adaptive-damping tests.
    fn adaptive_schedule() -> DampingSchedule {
        DampingSchedule::Adaptive {
            base: 0.1,
            increase_rate: 0.1,
            decrease_rate: 0.05,
        }
    }

    #[test]
    fn test_config_default() {
        let config = ConvergenceConfig::default();
        assert_close(config.tolerance, 1e-6);
        assert_eq!(config.max_iterations, 100);
        assert_close(config.damping_factor, 0.5);
        assert_eq!(config.patience, 3);
    }

    #[test]
    fn test_config_validate_good() {
        ConvergenceConfig::default()
            .validate()
            .expect("default config must validate");
    }

    #[test]
    fn test_config_validate_bad_tolerance() {
        let outcome = ConvergenceConfig::new().with_tolerance(0.0).validate();
        assert!(matches!(outcome, Err(ConvergenceError::InvalidTolerance(_))));
    }

    #[test]
    fn test_config_validate_bad_damping() {
        let outcome = ConvergenceConfig::new().with_damping(2.0).validate();
        assert!(matches!(outcome, Err(ConvergenceError::InvalidDamping(_))));
    }

    #[test]
    fn test_config_builder() {
        // Setters are independent, so the call order does not matter.
        let config = ConvergenceConfig::new()
            .with_patience(5)
            .with_damping(0.3)
            .with_max_iterations(50)
            .with_tolerance(1e-4);
        assert_close(config.tolerance, 1e-4);
        assert_eq!(config.max_iterations, 50);
        assert_close(config.damping_factor, 0.3);
        assert_eq!(config.patience, 5);
    }

    #[test]
    fn test_damping_fixed() {
        let schedule = DampingSchedule::Fixed(0.7);
        let cases = [
            (0, None, None, 0.5),
            (10, Some(0.1), Some(0.05), 0.5),
            (100, None, None, 0.9),
        ];
        for (iteration, prev, curr, current) in cases {
            assert_close(schedule.get_damping(iteration, prev, curr, current), 0.7);
        }
    }

    #[test]
    fn test_damping_linear() {
        let schedule = DampingSchedule::Linear {
            start: 0.8,
            end: 0.2,
            total_steps: 10,
        };
        // (iteration, expected): start, midpoint, end, and clamped beyond the end.
        for (iteration, expected) in [(0, 0.8), (5, 0.5), (10, 0.2), (20, 0.2)] {
            assert_close(schedule.get_damping(iteration, None, None, 0.0), expected);
        }
    }

    #[test]
    fn test_damping_exponential() {
        let schedule = DampingSchedule::Exponential {
            initial: 1.0,
            decay: 0.5,
        };
        // 1.0 * 0.5^k for k = 0, 1, 2.
        for (iteration, expected) in [(0, 1.0), (1, 0.5), (2, 0.25)] {
            assert_close(schedule.get_damping(iteration, None, None, 0.0), expected);
        }
    }

    #[test]
    fn test_damping_adaptive_increases_on_diverge() {
        // Residual grew 0.5 -> 0.8, so damping rises by increase_rate: 0.4 + 0.1.
        let damping = adaptive_schedule().get_damping(1, Some(0.5), Some(0.8), 0.4);
        assert_close(damping, 0.5);
    }

    #[test]
    fn test_damping_adaptive_decreases_on_converge() {
        // Residual fell 0.8 -> 0.5, so damping drops by decrease_rate: 0.4 - 0.05.
        let damping = adaptive_schedule().get_damping(1, Some(0.8), Some(0.5), 0.4);
        assert_close(damping, 0.35);
    }

    #[test]
    fn test_monitor_converges() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(2);
        let mut monitor = fixed_monitor(config, 0.5);

        // The last of these is the first residual below tolerance (streak = 1).
        for residual in [1.0, 0.1, 0.0009] {
            assert!(monitor.record_iteration(residual), "stopped early at {residual}");
        }
        // A second consecutive sub-tolerance residual satisfies patience = 2.
        assert!(!monitor.record_iteration(0.0005));

        assert!(monitor.is_converged());
        assert!(!monitor.is_diverged());
    }

    #[test]
    fn test_monitor_patience() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(3);
        let mut monitor = fixed_monitor(config, 0.5);

        // Streak reaches 2, is reset by 0.01, then climbs back to 2.
        for residual in [0.0001, 0.0002, 0.01, 0.0001, 0.0002] {
            assert!(monitor.record_iteration(residual), "stopped early at {residual}");
        }
        // Third consecutive sub-tolerance residual => converged.
        assert!(!monitor.record_iteration(0.0003));

        assert!(monitor.is_converged());
    }

    #[test]
    fn test_monitor_max_iterations() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-10)
            .with_max_iterations(5);
        let mut monitor = fixed_monitor(config, 0.5);

        // Residuals 1, 1/2, 1/3, 1/4: decreasing, but never below tolerance.
        for k in 1..=4 {
            let residual = 1.0 / k as f64;
            assert!(monitor.record_iteration(residual), "iteration {}", k - 1);
        }
        // The fifth iteration exhausts the budget.
        assert!(!monitor.record_iteration(0.1));
        assert!(!monitor.is_converged());
        assert_eq!(monitor.iteration(), 5);
    }

    #[test]
    fn test_monitor_diverge_detection() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-10)
            .with_max_iterations(100);
        let mut monitor = fixed_monitor(config, 0.5);

        for residual in [1.0, 2.0, 3.0, 4.0] {
            assert!(monitor.record_iteration(residual));
        }
        // Fifth strictly increasing residual triggers divergence.
        assert!(!monitor.record_iteration(5.0));

        assert!(monitor.is_diverged());
        assert!(!monitor.is_converged());
    }

    #[test]
    fn test_monitor_reset() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(1);
        let mut monitor = fixed_monitor(config, 0.5);

        // A single sub-tolerance residual converges immediately with patience = 1.
        assert!(!monitor.record_iteration(0.0001));
        assert!(monitor.is_converged());
        assert_eq!(monitor.iteration(), 1);

        monitor.reset();
        assert!(!monitor.is_converged());
        assert!(!monitor.is_diverged());
        assert_eq!(monitor.iteration(), 0);
        assert!(monitor.state().residual_history.is_empty());
    }

    #[test]
    fn test_monitor_stats() {
        let config = ConvergenceConfig::new()
            .with_tolerance(1e-3)
            .with_patience(2);
        let mut monitor = fixed_monitor(config, 0.3);

        for residual in [0.5, 0.0001, 0.00005] {
            monitor.record_iteration(residual);
        }

        let stats = monitor.stats();
        assert_eq!(stats.total_iterations, 3);
        assert_close(stats.final_residual, 0.00005);
        assert!(stats.converged);
        assert!(!stats.diverged);
        assert_close(stats.final_damping, 0.3);
        assert!(stats.convergence_rate.is_some());
    }

    #[test]
    fn test_convergence_rate() {
        let mut state = ConvergenceState::new();
        assert!(state.convergence_rate().is_none(), "no residuals yet");

        state.residual_history.push(1.0);
        assert!(state.convergence_rate().is_none(), "only one residual");

        // 0.5 / 1.0 = 0.5
        state.residual_history.push(0.5);
        let rate = state.convergence_rate().expect("should have rate");
        assert_close(rate, 0.5);
    }

    #[test]
    fn test_state_default() {
        let state = ConvergenceState::default();
        assert_eq!(state.iteration, 0);
        assert!(!state.converged);
        assert!(!state.diverged);
        assert!(state.residual_history.is_empty());
        assert!(state.damping_history.is_empty());
        assert_eq!(state.consecutive_converged, 0);
    }
}