Skip to main content

trellis_runner/engine/policy/
target_value.rs

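//! # Target value policy
//!
//! Terminates once a measured metric settles within a tolerance of a predefined target value.
//!
//! ## Behaviour
//!
//! - Monitors the measure carried by `Progress::Measure` and `Progress::Report` events;
//!   all other events are ignored.
//! - Records the absolute distance `|measure - target|` of each measure in a sliding window
//!   holding the most recent `window_size` measurements.
//! - Once the window is full, termination is triggered when the mean distance across the
//!   window is strictly less than `tolerance`.
//!
//! ## Termination
//!
//! Returns [`Termination::Converged`] when the condition above is met.
//!
//! ## Notes
//!
//! The distance to the target is symmetric: approaching the target from above or from below
//! is treated identically, so no minimisation or maximisation direction is assumed.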
18use super::EnginePolicy;
19
20use crate::{
21    engine::{EngineAction, EngineContext, EventBatch},
22    progress::Progress,
23    Termination,
24};
25
26use num_traits::float::FloatCore;
27
/// Engine policy that stops once measurements settle within `tolerance` of `target`.
///
/// The policy keeps the absolute distances to `target` of the most recent `window_size`
/// measurements and compares their mean against `tolerance`.
#[derive(Debug, Clone)]
pub struct TargetValuePolicy<F> {
    /// Value the monitored metric is expected to approach.
    target: F,
    /// Mean-distance threshold below which the policy reports convergence.
    tolerance: F,
    /// Absolute distances to `target` of the most recent measurements, oldest first.
    window: Vec<F>,
    /// Number of measurements that must be collected before the mean is evaluated.
    window_size: usize,
}
34
35impl<F: FloatCore> TargetValuePolicy<F> {
36    pub fn new(target: F, tolerance: F, window_size: usize) -> Self {
37        Self {
38            target,
39            tolerance,
40            window: Vec::with_capacity(window_size),
41            window_size,
42        }
43    }
44}
45
46impl<F> EnginePolicy<F> for TargetValuePolicy<F>
47where
48    F: FloatCore,
49{
50    fn decide(&mut self, batch: &EventBatch<F>, _context: &EngineContext) -> EngineAction {
51        for event in &batch.events {
52            let value = match event {
53                Progress::Measure(v) => *v,
54                Progress::Report { measure, .. } => *measure,
55                _ => continue,
56            };
57
58            // symmetric distance to target
59            let dist = (value - self.target).abs();
60
61            self.window.push(dist);
62
63            if self.window.len() > self.window_size {
64                self.window.remove(0);
65            }
66        }
67
68        if self.window.len() < self.window_size {
69            return EngineAction::Continue;
70        }
71
72        // mean absolute distance
73        let mean = self.window.iter().copied().fold(F::zero(), |a, b| a + b)
74            / F::from(self.window.len()).unwrap();
75
76        // tolerance-based stopping condition
77        if mean < self.tolerance {
78            return EngineAction::Stop(Termination::Converged);
79        }
80
81        EngineAction::Continue
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::EngineContext;
    use crate::progress::Progress;

    /// Builds a batch containing a single `Progress::Measure` event.
    fn batch(v: f64) -> EventBatch<f64> {
        batch_of(&[v])
    }

    /// Builds a batch containing one `Progress::Measure` event per value, in order.
    fn batch_of(values: &[f64]) -> EventBatch<f64> {
        values
            .iter()
            .fold(EventBatch::new(), |b, &v| b.add(Progress::Measure(v)))
    }

    #[test]
    fn target_reached_stops() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let ctx = EngineContext::default();

        // The window is not yet full, so no decision can be made.
        for _ in 0..4 {
            let res = p.decide(&batch(1.0), &ctx);
            assert!(matches!(res, EngineAction::Continue));
        }

        // The fifth measurement fills the window with zero distances.
        let res = p.decide(&batch(1.0), &ctx);
        assert!(matches!(res, EngineAction::Stop(Termination::Converged)));
    }

    #[test]
    fn target_not_reached_continues() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let ctx = EngineContext::default();

        let res = p.decide(&batch(2.0), &ctx);

        assert!(matches!(res, EngineAction::Continue));
    }

    #[test]
    fn full_window_far_from_target_continues() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 3);
        let ctx = EngineContext::default();

        for _ in 0..6 {
            let res = p.decide(&batch(2.0), &ctx);
            assert!(matches!(res, EngineAction::Continue));
        }
    }

    #[test]
    fn window_slides_towards_target() {
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let ctx = EngineContext::default();

        for _ in 0..5 {
            let res = p.decide(&batch(2.0), &ctx);
            assert!(matches!(res, EngineAction::Continue));
        }

        // Stale far-away distances are still in the window, keeping the mean above tolerance.
        for _ in 0..4 {
            let res = p.decide(&batch(1.0), &ctx);
            assert!(matches!(res, EngineAction::Continue));
        }

        let res = p.decide(&batch(1.0), &ctx);
        assert!(matches!(res, EngineAction::Stop(Termination::Converged)));
    }

    #[test]
    fn multi_event_batch_keeps_latest_measures() {
        let ctx = EngineContext::default();

        // Only the last five measures are evaluated, and they all hit the target.
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let res = p.decide(&batch_of(&[2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0]), &ctx);
        assert!(matches!(res, EngineAction::Stop(Termination::Converged)));

        // The latest measure is off target, so the mean of the last five is 0.2.
        let mut p = TargetValuePolicy::new(1.0, 0.01, 5);
        let res = p.decide(&batch_of(&[1.0, 1.0, 1.0, 1.0, 1.0, 2.0]), &ctx);
        assert!(matches!(res, EngineAction::Continue));
    }
}