Skip to main content

trellis_runner/engine/policy/
absolute_tolerance.rs

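//! # Absolute tolerance policy
//!
//! Terminates when the absolute error estimate stays below a threshold.
//!
//! ## Behaviour
//!
//! - Consumes `Progress::Report` events and reads their `diagnostics.absolute_error`.
//! - Keeps a sliding window of the most recent `window_size` error estimates.
//! - Once the window is full, checks `absolute < tolerance` for every estimate in it,
//!   i.e. the worst-case (largest) recent estimate must be below the threshold.
//!
//! ## Termination
//!
//! Returns [`Termination::Converged`](crate::Termination::Converged) when the condition is met.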
13use super::EnginePolicy;
14
15use crate::{
16    engine::{EngineAction, EngineContext, EventBatch},
17    progress::Progress,
18};
19
20use num_traits::float::FloatCore;
21
/// Engine policy that stops once the absolute error estimate has stayed below a
/// tolerance for a number of consecutive observations.
///
/// The policy keeps a sliding window of the most recent absolute-error
/// estimates and signals convergence only when the window is full and every
/// estimate in it is below `tolerance`.
#[derive(Debug, Clone)]
pub struct AbsoluteTolerancePolicy<F> {
    /// Strict upper bound an error estimate must stay under.
    tolerance: F,
    /// Most recent error estimates, oldest first.
    window: Vec<F>,
    /// Number of consecutive estimates that must be below `tolerance`;
    /// always at least 1 when constructed through [`AbsoluteTolerancePolicy::new`].
    window_size: usize,
}

impl<F> AbsoluteTolerancePolicy<F> {
    /// Creates a policy that converges once `window_size` consecutive absolute
    /// error estimates are all strictly below `tolerance`.
    ///
    /// A `window_size` of zero is treated as one. An empty window would
    /// otherwise count as "full", letting the policy declare convergence before
    /// any error estimate had been observed.
    pub fn new(tolerance: F, window_size: usize) -> Self {
        let window_size = window_size.max(1);
        Self {
            tolerance,
            window: Vec::with_capacity(window_size),
            window_size,
        }
    }
}
37
38impl<F> EnginePolicy<F> for AbsoluteTolerancePolicy<F>
39where
40    F: FloatCore,
41{
42    fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
43        for event in &batch.events {
44            if let Progress::Report { diagnostics, .. } = event {
45                if let Some(rel) = diagnostics.absolute_error {
46                    self.window.push(rel);
47
48                    if self.window.len() > self.window_size {
49                        self.window.remove(0);
50                    }
51                }
52            }
53        }
54
55        if self.window.len() < self.window_size {
56            return EngineAction::Continue;
57        }
58
59        // use worst-case (robust against noise)
60        let worst = self.window.iter().copied().fold(F::zero(), |a, b| a.max(b));
61
62        if worst < self.tolerance {
63            return EngineAction::Stop(crate::Termination::Converged);
64        }
65
66        EngineAction::Continue
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::{EngineContext, EventBatch};
    use crate::progress::{Progress, ProgressDiagnostics};

    /// Builds a batch containing a single report with the given absolute error.
    fn report_with_error(absolute_error: f64) -> EventBatch<f64> {
        EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                absolute_error: Some(absolute_error),
                ..Default::default()
            },
        })
    }

    #[test]
    fn absolute_tolerance_stops_on_error_below_threshold() {
        let mut stack = PolicyStack::<f64>::new().add(AbsoluteTolerancePolicy::new(0.1, 5));
        let batch = report_with_error(0.05);
        let ctx = EngineContext::default();

        // The window is not full yet, so the policy must not stop early.
        for _ in 0..4 {
            assert!(matches!(
                stack.decide(&batch, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }

        assert!(matches!(
            stack.decide(&batch, &ctx),
            crate::engine::EngineAction::Stop(crate::Termination::Converged)
        ));
    }

    #[test]
    fn absolute_tolerance_continues_above_threshold() {
        let mut stack = PolicyStack::<f64>::new().add(AbsoluteTolerancePolicy::new(0.1, 5));
        let batch = report_with_error(0.5);
        let ctx = EngineContext::default();

        // Keep going past the point where the window becomes full.
        for _ in 0..6 {
            assert!(matches!(
                stack.decide(&batch, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }
    }

    #[test]
    fn absolute_tolerance_single_outlier_blocks_until_it_leaves_window() {
        let mut stack = PolicyStack::<f64>::new().add(AbsoluteTolerancePolicy::new(0.1, 5));
        let low = report_with_error(0.05);
        let high = report_with_error(0.5);
        let ctx = EngineContext::default();

        for _ in 0..4 {
            stack.decide(&low, &ctx);
        }

        // The window is now full but contains one value above tolerance.
        assert!(matches!(
            stack.decide(&high, &ctx),
            crate::engine::EngineAction::Continue
        ));

        // The outlier stays in the window for the next four observations.
        for _ in 0..4 {
            assert!(matches!(
                stack.decide(&low, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }

        // The fifth low observation pushes the outlier out of the window.
        assert!(matches!(
            stack.decide(&low, &ctx),
            crate::engine::EngineAction::Stop(crate::Termination::Converged)
        ));
    }
}