trellis_runner/engine/policy/
relative_tolerance.rs1use super::EnginePolicy;
13
14use crate::{
15 engine::{EngineAction, EngineContext, EventBatch},
16 progress::Progress,
17};
18
19use num_traits::float::FloatCore;
20
/// Convergence policy that stops the engine once recently reported relative errors are small.
///
/// The policy keeps a sliding window of the relative errors reported through
/// `Progress::Report` diagnostics and compares them against `tolerance`.
#[derive(Debug, Clone)]
pub struct RelativeTolerancePolicy<F> {
    /// Threshold that every windowed relative error must fall below.
    tolerance: F,
    /// Most recent relative errors, oldest first.
    window: Vec<F>,
    /// Number of recent relative errors that must be observed before convergence is judged.
    window_size: usize,
}
26
27impl<F> RelativeTolerancePolicy<F> {
28 pub fn new(tolerance: F, window_size: usize) -> Self {
29 Self {
30 tolerance,
31 window_size,
32 window: Vec::with_capacity(window_size),
33 }
34 }
35}
36
37impl<F> EnginePolicy<F> for RelativeTolerancePolicy<F>
38where
39 F: FloatCore,
40{
41 fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
42 for event in &batch.events {
43 if let Progress::Report { diagnostics, .. } = event {
44 if let Some(rel) = diagnostics.relative_error {
45 self.window.push(rel);
46
47 if self.window.len() > self.window_size {
48 self.window.remove(0);
49 }
50 }
51 }
52 }
53
54 if self.window.len() < self.window_size {
55 return EngineAction::Continue;
56 }
57
58 let worst = self.window.iter().copied().fold(F::zero(), |a, b| a.max(b));
60
61 if worst < self.tolerance {
62 return EngineAction::Stop(crate::Termination::Converged);
63 }
64
65 EngineAction::Continue
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::policy::PolicyStack;
    use crate::engine::{EngineContext, EventBatch};
    use crate::progress::{Progress, ProgressDiagnostics};

    /// Builds a batch holding one report with the given relative error.
    fn report(relative_error: f64) -> EventBatch<f64> {
        EventBatch::new().add(Progress::Report {
            measure: 1.0,
            diagnostics: ProgressDiagnostics {
                relative_error: Some(relative_error),
                ..Default::default()
            },
        })
    }

    #[test]
    fn relative_tolerance_stops_on_small_relative_error() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 5));
        let batch = report(0.05);
        let ctx = EngineContext::default();

        // The first four reports only fill the window, so none of them may stop the engine.
        for _ in 0..4 {
            assert!(matches!(
                stack.decide(&batch, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }

        assert!(matches!(
            stack.decide(&batch, &ctx),
            crate::engine::EngineAction::Stop(_)
        ));
    }

    #[test]
    fn relative_tolerance_continues_when_large() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 5));
        let batch = report(0.5);
        let ctx = EngineContext::default();

        // Fill the whole window so the tolerance comparison is actually exercised, rather
        // than continuing merely because too few reports have arrived.
        for _ in 0..5 {
            assert!(matches!(
                stack.decide(&batch, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }
    }

    #[test]
    fn relative_tolerance_waits_for_large_error_to_leave_window() {
        let mut stack = PolicyStack::<f64>::new().add(RelativeTolerancePolicy::new(0.1, 3));
        let small = report(0.05);
        let ctx = EngineContext::default();

        assert!(matches!(
            stack.decide(&report(0.5), &ctx),
            crate::engine::EngineAction::Continue
        ));

        // Window reaches [0.5, 0.05, 0.05]: full, but its worst value exceeds tolerance.
        for _ in 0..2 {
            assert!(matches!(
                stack.decide(&small, &ctx),
                crate::engine::EngineAction::Continue
            ));
        }

        // The large value is evicted, leaving only small errors in the window.
        assert!(matches!(
            stack.decide(&small, &ctx),
            crate::engine::EngineAction::Stop(_)
        ));
    }
}