trellis_runner/engine/policy/
checkpoint.rs1use super::EnginePolicy;
23
24use crate::engine::{event::CheckpointReason, EngineAction, EngineContext, EventBatch};
25
/// Engine policy that requests a scheduled checkpoint at a fixed iteration interval.
///
/// A checkpoint is requested on every iteration `n > 0` that is a multiple of
/// the interval. An interval of `0` never matches a positive iteration, so it
/// effectively disables scheduled checkpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CheckpointPolicy {
    /// Number of iterations between scheduled checkpoints.
    every: usize,
}

impl CheckpointPolicy {
    /// Creates a policy that checkpoints every `every` iterations.
    ///
    /// Passing `0` yields a policy that never requests a scheduled checkpoint.
    #[must_use]
    pub fn every(every: usize) -> Self {
        Self { every }
    }
}
35
36impl<F> EnginePolicy<F> for CheckpointPolicy {
37 fn decide(&mut self, _batch: &EventBatch<F>, context: &EngineContext) -> EngineAction {
38 if context.iter.is_multiple_of(self.every) & (context.iter > 0) {
39 return EngineAction::EmitCheckpoint(CheckpointReason::Scheduled);
40 }
41
42 EngineAction::Continue
43 }
44}
45
#[cfg(test)]
mod test {
    use super::*;

    use crate::engine::policy::PolicyStack;
    use crate::progress::Progress;

    #[test]
    fn checkpoint_policy_requests_checkpoint_on_schedule() {
        let mut stack = PolicyStack::<f64>::new().add(CheckpointPolicy::every(10));

        let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
        let ctx = EngineContext {
            iter: 10,
            ..Default::default()
        };

        assert!(matches!(
            stack.decide(&batch, &ctx),
            EngineAction::EmitCheckpoint(CheckpointReason::Scheduled)
        ));
    }

    #[test]
    fn checkpoint_policy_does_not_request_checkpoint_when_not_on_schedule() {
        let mut stack = PolicyStack::<f64>::new().add(CheckpointPolicy::every(10));

        let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
        let ctx = EngineContext {
            iter: 11,
            ..Default::default()
        };

        assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
    }

    /// Iteration 0 is a multiple of every interval, but the policy must not
    /// checkpoint before any work has been done.
    #[test]
    fn checkpoint_policy_does_not_request_checkpoint_at_iteration_zero() {
        let mut stack = PolicyStack::<f64>::new().add(CheckpointPolicy::every(10));

        let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
        let ctx = EngineContext {
            iter: 0,
            ..Default::default()
        };

        assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
    }

    /// A zero interval disables scheduled checkpoints instead of panicking on
    /// a division by zero.
    #[test]
    fn checkpoint_policy_with_zero_interval_never_requests_checkpoint() {
        let mut stack = PolicyStack::<f64>::new().add(CheckpointPolicy::every(0));

        let batch: EventBatch<f64> = EventBatch::new().add(Progress::Complete);
        let ctx = EngineContext {
            iter: 10,
            ..Default::default()
        };

        assert!(matches!(stack.decide(&batch, &ctx), EngineAction::Continue));
    }
}