Skip to main content

trellis_runner/engine/policy/
relative_tolerance.rs

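//! # Relative tolerance policy
//!
//! Terminates when the relative error estimate stays below a threshold for a
//! full window of consecutive reports.
//!
//! ## Behaviour
//!
//! - Collects `relative_error` from each [`Progress::Report`] event; reports
//!   without a relative error estimate are ignored.
//! - Keeps only the most recent `window_size` estimates and continues until the
//!   window is full.
//! - Checks the worst case: every estimate in the window must satisfy
//!   `relative < tolerance`.
//!
//! ## Termination
//!
//! Returns [`Termination::Converged`](crate::Termination::Converged) (via
//! `EngineAction::Stop`) when the condition is met.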
use std::collections::VecDeque;

use num_traits::float::FloatCore;

use super::EnginePolicy;

use crate::{
    engine::{EngineAction, EngineContext, EventBatch},
    progress::Progress,
};
20
21pub struct RelativeTolerancePolicy<F> {
22    tolerance: F,
23    window: Vec<F>,
24    window_size: usize,
25}
26
27impl<F> RelativeTolerancePolicy<F> {
28    pub fn new(tolerance: F, window_size: usize) -> Self {
29        Self {
30            tolerance,
31            window_size,
32            window: Vec::with_capacity(window_size),
33        }
34    }
35}
36
37impl<F> EnginePolicy<F> for RelativeTolerancePolicy<F>
38where
39    F: FloatCore,
40{
41    fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
42        for event in &batch.events {
43            if let Progress::Report { diagnostics, .. } = event {
44                if let Some(rel) = diagnostics.relative_error {
45                    self.window.push(rel);
46
47                    if self.window.len() > self.window_size {
48                        self.window.remove(0);
49                    }
50                }
51            }
52        }
53
54        if self.window.len() < self.window_size {
55            return EngineAction::Continue;
56        }
57
58        // use worst-case (robust against noise)
59        let worst = self.window.iter().copied().fold(F::zero(), |a, b| a.max(b));
60
61        if worst < self.tolerance {
62            return EngineAction::Stop(crate::Termination::Converged);
63        }
64
65        EngineAction::Continue
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::{EngineAction, EngineContext, EventBatch};
    use crate::progress::{Progress, ProgressDiagnostics};

    #[test]
    fn relative_tolerance_stops_on_small_relative_error() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 5));

        let batch = EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: Some(0.05),
                ..Default::default()
            },
        });

        let ctx = EngineContext::default();

        // The first four reports only fill the window, so the policy must keep going.
        for _ in 0..4 {
            assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
        }

        // The fifth report completes the window with every error below tolerance.
        assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Stop(_)));
    }

    #[test]
    fn relative_tolerance_continues_when_large() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 5));

        let batch = EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: Some(0.5),
                ..Default::default()
            },
        });

        let ctx = EngineContext::default();

        assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
    }

    #[test]
    fn relative_tolerance_waits_for_large_error_to_leave_window() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 5));

        let small = EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: Some(0.05),
                ..Default::default()
            },
        });
        let large = EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: Some(0.5),
                ..Default::default()
            },
        });

        let ctx = EngineContext::default();

        for _ in 0..4 {
            assert!(matches!(stack.decide(&small, &ctx), EngineAction::Continue));
        }

        // Window is now full but contains one large error.
        assert!(matches!(stack.decide(&large, &ctx), EngineAction::Continue));

        // Four more small reports still leave the large error inside the window.
        for _ in 0..4 {
            assert!(matches!(stack.decide(&small, &ctx), EngineAction::Continue));
        }

        // The fifth small report evicts the large error.
        assert!(matches!(stack.decide(&small, &ctx), EngineAction::Stop(_)));
    }

    #[test]
    fn relative_tolerance_ignores_reports_without_relative_error() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 2));

        let batch = EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: None,
                ..Default::default()
            },
        });

        let ctx = EngineContext::default();

        // No estimates ever enter the window, so it never fills.
        for _ in 0..5 {
            assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
        }
    }
}