Skip to main content

trellis_runner/engine/policy/
checkpoint.rs

1//! Scheduled checkpoint emission policy.
2//!
3//! This policy requests periodic checkpoint creation based on the current
4//! iteration counter in the [`EngineContext`].
5//!
6//! It does not terminate execution.
7//!
8//! # Behaviour
9//!
//! - Every `N` iterations (excluding iteration 0), emits
//!   [`EngineAction::EmitCheckpoint`] with [`CheckpointReason::Scheduled`].
//!   When `N` is 0, no checkpoint is ever requested.
12//! - Otherwise returns [`EngineAction::Continue`].
13//!
14//! # Design notes
15//!
16//! This policy is purely temporal and does not inspect convergence or
17//! progress information.
18//!
//! It is typically used alongside termination policies such as
//! `MaxIterationPolicy` or `TimeoutPolicy`.
21//!
22use super::EnginePolicy;
23
24use crate::engine::{event::CheckpointReason, EngineAction, EngineContext, EventBatch};
25
/// Policy that periodically requests a scheduled checkpoint.
///
/// Build it with [`CheckpointPolicy::every`]. The policy never stops the
/// engine: it only ever answers [`EngineAction::EmitCheckpoint`] or
/// [`EngineAction::Continue`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CheckpointPolicy {
    /// Checkpoint interval in iterations; `0` means "never checkpoint".
    every: usize,
}
29
30impl CheckpointPolicy {
31    pub fn every(every: usize) -> Self {
32        Self { every }
33    }
34}
35
36impl<F> EnginePolicy<F> for CheckpointPolicy {
37    fn decide(&mut self, _batch: &EventBatch<F>, context: &EngineContext) -> EngineAction {
38        if context.iter.is_multiple_of(self.every) & (context.iter > 0) {
39            return EngineAction::EmitCheckpoint(CheckpointReason::Scheduled);
40        }
41
42        EngineAction::Continue
43    }
44}
45
#[cfg(test)]
mod test {
    use super::*;

    use crate::engine::policy::PolicyStack;
    use crate::progress::Progress;

    /// Runs a stack holding only `CheckpointPolicy::every(every)` at
    /// iteration `iter` and returns its decision.
    fn decide_at(every: usize, iter: usize) -> EngineAction {
        let mut stack = PolicyStack::<f64>::new().add(CheckpointPolicy::every(every));

        let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
        let ctx = EngineContext {
            iter,
            ..Default::default()
        };

        stack.decide(&batch, &ctx)
    }

    #[test]
    fn checkpoint_policy_requests_checkpoint_on_schedule() {
        for iter in [10, 20, 100] {
            assert!(
                matches!(
                    decide_at(10, iter),
                    EngineAction::EmitCheckpoint(CheckpointReason::Scheduled)
                ),
                "expected a scheduled checkpoint at iteration {iter}"
            );
        }
    }

    #[test]
    fn checkpoint_policy_does_not_request_checkpoint_when_not_on_schedule() {
        for iter in [1, 9, 11, 19] {
            assert!(
                matches!(decide_at(10, iter), EngineAction::Continue),
                "expected no checkpoint at iteration {iter}"
            );
        }
    }

    #[test]
    fn checkpoint_policy_does_not_request_checkpoint_at_iteration_zero() {
        // 0 is a multiple of every interval, but iteration 0 is excluded.
        assert!(matches!(decide_at(10, 0), EngineAction::Continue));
        assert!(matches!(decide_at(1, 0), EngineAction::Continue));
    }

    #[test]
    fn checkpoint_policy_with_zero_interval_never_requests_checkpoint() {
        for iter in [0, 1, 10] {
            assert!(
                matches!(decide_at(0, iter), EngineAction::Continue),
                "expected no checkpoint at iteration {iter} with interval 0"
            );
        }
    }
}