trellis_runner/engine/policy/
target_value.rs1use super::EnginePolicy;
19
20use crate::{
21 engine::{EngineAction, EngineContext, EventBatch},
22 progress::Progress,
23 Termination,
24};
25
26use num_traits::float::FloatCore;
27
/// Engine policy that stops a run once recent measurements settle on a target value.
///
/// Each measurement is reduced to its absolute distance from `target`. The
/// most recent `window_size` distances are kept and averaged. When that mean
/// is strictly below `tolerance`, the run is considered converged.
#[derive(Debug, Clone)]
pub struct TargetValuePolicy<F> {
    /// Value the measurements are expected to approach.
    target: F,
    /// Mean distance to `target` must fall strictly below this to stop.
    tolerance: F,
    /// Most recent distances `|measure - target|`, oldest first.
    window: Vec<F>,
    /// Number of recent distances averaged; no stop decision is made until
    /// this many have been collected.
    window_size: usize,
}
34
35impl<F: FloatCore> TargetValuePolicy<F> {
36 pub fn new(target: F, tolerance: F, window_size: usize) -> Self {
37 Self {
38 target,
39 tolerance,
40 window: Vec::with_capacity(window_size),
41 window_size,
42 }
43 }
44}
45
46impl<F> EnginePolicy<F> for TargetValuePolicy<F>
47where
48 F: FloatCore,
49{
50 fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
51 for event in &batch.events {
52 let value = match event {
53 Progress::Measure(v) => *v,
54 Progress::Report { measure, .. } => *measure,
55 _ => continue,
56 };
57
58 let dist = (value - self.target).abs();
60
61 self.window.push(dist);
62
63 if self.window.len() > self.window_size {
64 self.window.remove(0);
65 }
66 }
67
68 if self.window.len() < self.window_size {
69 return EngineAction::Continue;
70 }
71
72 let mean = self.window.iter().copied().fold(F::zero(), |a, b| a + b)
74 / F::from(self.window.len()).unwrap();
75
76 if mean < self.tolerance {
78 return EngineAction::Stop(Termination::Converged);
79 }
80
81 EngineAction::Continue
82 }
83}
84
#[cfg(test)]
mod test {
    // `super::*` already brings in `EngineContext`, `EventBatch`, `Progress`,
    // `EngineAction` and `Termination` through the parent module's imports.
    use super::*;

    fn batch(v: f64) -> EventBatch<f64> {
        EventBatch::new().add(Progress::Measure(v))
    }

    #[test]
    fn target_reached_stops() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let ctx = EngineContext::default();

        // Until the window holds `window_size` samples the policy must keep going.
        for _ in 0..4 {
            assert!(matches!(p.decide(&batch(1.0), &ctx), EngineAction::Continue));
        }

        let res = p.decide(&batch(1.0), &ctx);
        assert!(matches!(res, EngineAction::Stop(Termination::Converged)));
    }

    #[test]
    fn target_not_reached_continues() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let ctx = EngineContext::default();

        // Feed well past the window size so the decision is driven by the mean
        // distance, not merely by the window still being incomplete.
        for _ in 0..10 {
            assert!(matches!(p.decide(&batch(2.0), &ctx), EngineAction::Continue));
        }
    }

    #[test]
    fn decision_uses_window_mean() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 2);
        let ctx = EngineContext::default();

        assert!(matches!(p.decide(&batch(1.015), &ctx), EngineAction::Continue));
        // Mean distance is (0.015 + 0.0) / 2 = 0.0075, inside the tolerance,
        // even though the first sample alone was outside it.
        assert!(matches!(p.decide(&batch(1.0), &ctx), EngineAction::Stop(_)));
    }

    #[test]
    fn old_samples_slide_out_of_window() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 2);
        let ctx = EngineContext::default();

        // Window: [4.0]
        assert!(matches!(p.decide(&batch(5.0), &ctx), EngineAction::Continue));
        // Window: [4.0, 0.0] -> mean 2.0
        assert!(matches!(p.decide(&batch(1.0), &ctx), EngineAction::Continue));
        // Window: [0.0, 0.0] -> the outlier has been evicted
        assert!(matches!(p.decide(&batch(1.0), &ctx), EngineAction::Stop(_)));
    }

    #[test]
    fn single_batch_can_fill_window() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 3);
        let ctx = EngineContext::default();

        // Only the last three distances [0.0, 0.0, 0.0] should be retained.
        let full = EventBatch::new()
            .add(Progress::Measure(9.0))
            .add(Progress::Measure(1.0))
            .add(Progress::Measure(1.0))
            .add(Progress::Measure(1.0));

        assert!(matches!(p.decide(&full, &ctx), EngineAction::Stop(_)));
    }
}