trellis_runner/engine/policy/
stagnation.rs1use super::EnginePolicy;
25
26use num_traits::float::FloatCore;
27
28use crate::{
29 engine::{EngineAction, EngineContext, EventBatch},
30 progress::Progress,
31 Termination,
32};
33
/// Engine policy that stops a run once its measured progress has stagnated.
///
/// The policy keeps the most recent `Progress::Measure` values and requests
/// `Termination::Stagnated` when they are both flat (small fitted slope) and
/// quiet (small variance).
#[derive(Debug, Clone)]
pub struct StagnationPolicy<F> {
    /// Number of most recent measurements the stagnation test is evaluated over.
    window: usize,
    /// Slope tolerance, relative to `max(|oldest retained measurement|, 1)`.
    relative_slope_tol: F,
    /// Variance tolerance, relative to the square of that same scale.
    relative_noise_floor: F,
    /// Retained measurements, oldest first.
    history: Vec<F>,
}
40
41impl<F: num_traits::FromPrimitive> StagnationPolicy<F> {
42 pub fn new(window: usize) -> Self {
43 Self {
44 window,
45 history: Vec::new(),
46 relative_slope_tol: F::from_f64(1e-4).unwrap(),
47 relative_noise_floor: F::from_f64(1e-6).unwrap(),
48 }
49 }
50}
51
52impl<F: FloatCore + num_traits::FromPrimitive + std::iter::Sum<F>> EnginePolicy<F>
53 for StagnationPolicy<F>
54{
55 fn decide(&mut self, batch: &EventBatch<F>, _ctx: &EngineContext) -> EngineAction {
56 for e in &batch.events {
57 let v = match e {
58 Progress::Measure(value) => *value,
59 _ => continue,
60 };
61
62 self.history.push(v);
63 }
64
65 if self.history.len() > self.window {
66 self.history.remove(0);
67 }
68
69 if self.history.len() < self.window {
70 return EngineAction::Continue;
71 }
72
73 let scale = self.history[0].abs().max(F::one());
75 let slope_tol = self.relative_slope_tol * scale;
76 let noise_floor = self.relative_noise_floor * scale * scale;
77
78 let n = F::from(self.history.len()).unwrap();
79
80 let mut sum_x = F::zero();
81 let mut sum_y = F::zero();
82 let mut sum_xy = F::zero();
83 let mut sum_x2 = F::zero();
84
85 for (i, y) in self.history.iter().enumerate() {
86 let x = F::from(i).unwrap();
87
88 sum_x = sum_x + x;
89 sum_y = sum_y + *y;
90 sum_xy = sum_xy + x * *y;
91 sum_x2 = sum_x2 + x * x;
92 }
93
94 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
95
96 let mean = sum_y / n;
97
98 let variance: F = self
99 .history
100 .iter()
101 .map(|y| {
102 let d = *y - mean;
103 d * d
104 })
105 .sum::<F>()
106 / n;
107
108 if slope.abs() < slope_tol && variance < noise_floor {
109 return EngineAction::Stop(Termination::Stagnated);
110 }
111
112 EngineAction::Continue
113 }
114}
115
#[cfg(test)]
mod test {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::EngineContext;
    use crate::progress::Progress;

    /// Wraps a single measurement in an `EventBatch`.
    fn batch(v: f64) -> EventBatch<f64> {
        EventBatch::new().add(Progress::Measure(v))
    }

    /// A perfectly constant signal must be reported as stagnated.
    #[test]
    fn stagnation_detects_flat_region() {
        let ctx = EngineContext::default();
        let mut policy = StagnationPolicy::<f64>::new(3);

        // Prime the window; the intermediate decisions are not under test.
        for _ in 0..3 {
            let _ = policy.decide(&batch(1.0), &ctx);
        }
        let outcome = policy.decide(&batch(1.0), &ctx);

        assert!(matches!(outcome, EngineAction::Stop(_)));
    }

    /// With fewer measurements than the window, the policy must keep going.
    #[test]
    fn stagnation_requires_window() {
        let ctx = EngineContext::default();
        let mut policy = StagnationPolicy::<f64>::new(5);

        let outcome = policy.decide(&batch(1.0), &ctx);

        assert!(matches!(outcome, EngineAction::Continue));
    }

    /// A slow drift followed by a large jump: the jump must keep the run alive.
    #[test]
    fn stagnation_resets_with_change() {
        let ctx = EngineContext::default();
        let mut stack = PolicyStack::<f64>::new().add(StagnationPolicy::new(3));

        for v in [1.0, 1.0001, 1.0002, 2.0] {
            let outcome = stack.decide(&batch(v), &ctx);

            if v == 2.0 {
                assert!(matches!(outcome, crate::engine::EngineAction::Continue));
            }
        }
    }
}