trellis_runner/engine/policy/
absolute_tolerance.rs1use super::EnginePolicy;
14
15use crate::{
16 engine::{EngineAction, EngineContext, EventBatch},
17 progress::Progress,
18};
19
20use num_traits::float::FloatCore;
21
22pub struct AbsoluteTolerancePolicy<F> {
23 tolerance: F,
24 window: Vec<F>,
25 window_size: usize,
26}
27
28impl<F> AbsoluteTolerancePolicy<F> {
29 pub fn new(tolerance: F, window_size: usize) -> Self {
30 Self {
31 tolerance,
32 window_size,
33 window: Vec::with_capacity(window_size),
34 }
35 }
36}
37
38impl<F> EnginePolicy<F> for AbsoluteTolerancePolicy<F>
39where
40 F: FloatCore,
41{
42 fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
43 for event in &batch.events {
44 if let Progress::Report { diagnostics, .. } = event {
45 if let Some(rel) = diagnostics.absolute_error {
46 self.window.push(rel);
47
48 if self.window.len() > self.window_size {
49 self.window.remove(0);
50 }
51 }
52 }
53 }
54
55 if self.window.len() < self.window_size {
56 return EngineAction::Continue;
57 }
58
59 let worst = self.window.iter().copied().fold(F::zero(), |a, b| a.max(b));
61
62 if worst < self.tolerance {
63 return EngineAction::Stop(crate::Termination::Converged);
64 }
65
66 EngineAction::Continue
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::{EngineContext, EventBatch};
    use crate::progress::{Progress, ProgressDiagnostics};

    /// Builds a batch holding one progress report with the given absolute error.
    fn report_batch(absolute_error: f64) -> EventBatch<f64> {
        let diagnostics = ProgressDiagnostics {
            absolute_error: Some(absolute_error),
            ..Default::default()
        };
        EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics,
        })
    }

    #[test]
    fn absolute_tolerance_stops_on_error_below_threshold() {
        let ctx = EngineContext::default();
        let batch = report_batch(0.05);
        let mut stack = PolicyStack::<f64>::new().add(AbsoluteTolerancePolicy::new(0.1, 5));

        // The first four reports only fill the five-slot window.
        let mut remaining = 4;
        while remaining > 0 {
            stack.decide(&batch, &ctx);
            remaining -= 1;
        }

        // The fifth report completes the window, and every entry is under tolerance.
        let action = stack.decide(&batch, &ctx);
        assert!(matches!(action, crate::engine::EngineAction::Stop(_)));
    }

    #[test]
    fn absolute_tolerance_continues_above_threshold() {
        let ctx = EngineContext::default();
        let batch = report_batch(0.5);
        let mut stack = PolicyStack::<f64>::new().add(AbsoluteTolerancePolicy::new(0.1, 5));

        let action = stack.decide(&batch, &ctx);
        assert!(matches!(action, crate::engine::EngineAction::Continue));
    }
}