Skip to main content

trellis_runner/engine/policy/
stagnation.rs

//! # Stagnation policy (window-based)
//!
//! Detects a lack of meaningful change over a sliding window of recent values.
//!
//! ## Behaviour
//!
//! - Maintains a fixed-size history of recent values (`window` samples).
//! - Values are taken from [`Progress::Measure`] events; all other events are
//!   ignored.
//! - Once `window` samples have been collected, a least-squares line is fitted
//!   to the values against their sample index, and stagnation is detected when
//!   both:
//!   - the absolute slope is below `relative_slope_tol * scale`, and
//!   - the population variance is below `relative_noise_floor * scale * scale`,
//!
//!   where `scale` is the larger of `1` and the magnitude of the oldest value in
//!   the window.
//!
//! ## Termination
//!
//! Returns [`Termination::Stagnated`] when stagnation is detected.
//!
//! ## Notes
//!
//! This policy is stricter than [`NoProgressPolicy`] because it requires
//! *persistent flat behaviour over a window*, not just repeated small steps.
24use super::EnginePolicy;
25
26use num_traits::float::FloatCore;
27
28use crate::{
29    engine::{EngineAction, EngineContext, EventBatch},
30    progress::Progress,
31    Termination,
32};
33
/// Engine policy that stops a run once recent `Progress::Measure` values have
/// stopped changing meaningfully.
///
/// Stagnation is declared when, over the last `window` measurements, both the
/// least-squares slope and the variance fall below tolerances scaled by the
/// magnitude of the oldest sample in the window.
#[derive(Debug, Clone)]
pub struct StagnationPolicy<F> {
    /// Number of most recent samples the decision is based on.
    window: usize,
    /// Maximum absolute slope (per sample), relative to the window's scale.
    relative_slope_tol: F,
    /// Maximum variance, relative to the square of the window's scale.
    relative_noise_floor: F,
    /// Recent measurements, oldest first.
    history: Vec<F>,
}
40
41impl<F: num_traits::FromPrimitive> StagnationPolicy<F> {
42    pub fn new(window: usize) -> Self {
43        Self {
44            window,
45            history: Vec::new(),
46            relative_slope_tol: F::from_f64(1e-4).unwrap(),
47            relative_noise_floor: F::from_f64(1e-6).unwrap(),
48        }
49    }
50}
51
52impl<F: FloatCore + num_traits::FromPrimitive + std::iter::Sum<F>> EnginePolicy<F>
53    for StagnationPolicy<F>
54{
55    fn decide(&mut self, batch: &EventBatch<F>, _ctx: &EngineContext) -> EngineAction {
56        for e in &batch.events {
57            let v = match e {
58                Progress::Measure(value) => *value,
59                _ => continue,
60            };
61
62            self.history.push(v);
63        }
64
65        if self.history.len() > self.window {
66            self.history.remove(0);
67        }
68
69        if self.history.len() < self.window {
70            return EngineAction::Continue;
71        }
72
73        // slope-based stagnation
74        let scale = self.history[0].abs().max(F::one());
75        let slope_tol = self.relative_slope_tol * scale;
76        let noise_floor = self.relative_noise_floor * scale * scale;
77
78        let n = F::from(self.history.len()).unwrap();
79
80        let mut sum_x = F::zero();
81        let mut sum_y = F::zero();
82        let mut sum_xy = F::zero();
83        let mut sum_x2 = F::zero();
84
85        for (i, y) in self.history.iter().enumerate() {
86            let x = F::from(i).unwrap();
87
88            sum_x = sum_x + x;
89            sum_y = sum_y + *y;
90            sum_xy = sum_xy + x * *y;
91            sum_x2 = sum_x2 + x * x;
92        }
93
94        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
95
96        let mean = sum_y / n;
97
98        let variance: F = self
99            .history
100            .iter()
101            .map(|y| {
102                let d = *y - mean;
103                d * d
104            })
105            .sum::<F>()
106            / n;
107
108        if slope.abs() < slope_tol && variance < noise_floor {
109            return EngineAction::Stop(Termination::Stagnated);
110        }
111
112        EngineAction::Continue
113    }
114}
115
#[cfg(test)]
mod test {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::EngineContext;
    use crate::progress::Progress;

    /// Builds a batch carrying a single `Progress::Measure(v)` event.
    fn batch(v: f64) -> EventBatch<f64> {
        EventBatch::new().add(Progress::Measure(v))
    }

    #[test]
    fn stagnation_detects_flat_region() {
        let mut p = StagnationPolicy::new(3);
        let ctx = EngineContext::default();

        // Fill the window with identical values; the window is then perfectly
        // flat, so the next decision must report stagnation.
        for _ in 0..3 {
            let _ = p.decide(&batch(1.0), &ctx);
        }
        let res = p.decide(&batch(1.0), &ctx);

        assert!(matches!(res, EngineAction::Stop(Termination::Stagnated)));
    }

    #[test]
    fn stagnation_requires_window() {
        let mut p = StagnationPolicy::new(5);
        let ctx = EngineContext::default();

        // A single sample can never fill a window of five.
        let res = p.decide(&batch(1.0), &ctx);

        assert!(matches!(res, EngineAction::Continue));
    }

    #[test]
    fn stagnation_resets_with_change() {
        let mut stack = PolicyStack::<f64>::new().add(StagnationPolicy::new(3));
        let ctx = EngineContext::default();

        // A near-flat run followed by a large jump: once the jump is inside the
        // window, the engine must be allowed to continue.
        let mut last = None;
        for v in [1.0, 1.0001, 1.0002, 2.0] {
            last = Some(stack.decide(&batch(v), &ctx));
        }

        assert!(matches!(last, Some(EngineAction::Continue)));
    }
}
170}