1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
/// Serde default for `max_retries`: no retry limit (`None` lets restarts continue indefinitely).
fn default_max_retries() -> Option<u32> {
    Option::default()
}
7
/// Serde default for `backoff_base_secs`: the first restart waits one second.
fn default_backoff_base() -> f64 {
    1.0_f64
}
11
/// Serde default for `backoff_max_secs`: delays are capped at five minutes.
fn default_backoff_max() -> f64 {
    5.0 * 60.0
}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(tag = "policy", rename_all = "snake_case")]
18pub enum RestartPolicy {
19 Always {
20 #[serde(default = "default_max_retries")]
21 max_retries: Option<u32>,
22 #[serde(default = "default_backoff_base")]
23 backoff_base_secs: f64,
24 #[serde(default = "default_backoff_max")]
25 backoff_max_secs: f64,
26 },
27 OnFailure {
28 #[serde(default = "default_max_retries")]
29 max_retries: Option<u32>,
30 #[serde(default = "default_backoff_base")]
31 backoff_base_secs: f64,
32 #[serde(default = "default_backoff_max")]
33 backoff_max_secs: f64,
34 },
35 Never,
36}
37
38impl Default for RestartPolicy {
39 fn default() -> Self {
40 RestartPolicy::Never
41 }
42}
43
44pub fn compute_backoff(attempt: u32, base_secs: f64, max_secs: f64) -> Duration {
45 let exp = base_secs * 2.0f64.powi(attempt as i32);
46 let capped = exp.min(max_secs);
47 let jitter = rand::random::<f64>() * capped * 0.1;
49 Duration::from_secs_f64(capped + jitter)
50}
51
52pub struct RestartEvaluator;
53
54impl RestartEvaluator {
55 pub fn should_restart(
56 policy: &RestartPolicy,
57 exit_code: Option<i32>,
58 restart_count: u32,
59 ) -> bool {
60 match policy {
61 RestartPolicy::Never => false,
62 RestartPolicy::Always { max_retries, .. } => {
63 max_retries.map_or(true, |max| restart_count < max)
64 }
65 RestartPolicy::OnFailure { max_retries, .. } => {
66 let failed = exit_code.map_or(true, |c| c != 0);
67 failed && max_retries.map_or(true, |max| restart_count < max)
68 }
69 }
70 }
71
72 pub fn backoff_duration(policy: &RestartPolicy, restart_count: u32) -> Duration {
73 match policy {
74 RestartPolicy::Always {
75 backoff_base_secs,
76 backoff_max_secs,
77 ..
78 }
79 | RestartPolicy::OnFailure {
80 backoff_base_secs,
81 backoff_max_secs,
82 ..
83 } => compute_backoff(restart_count, *backoff_base_secs, *backoff_max_secs),
84 RestartPolicy::Never => Duration::ZERO,
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// `Always` policy with the default backoff parameters.
    fn always(max_retries: Option<u32>) -> RestartPolicy {
        RestartPolicy::Always {
            max_retries,
            backoff_base_secs: 1.0,
            backoff_max_secs: 300.0,
        }
    }

    /// `OnFailure` policy with the default backoff parameters.
    fn on_failure(max_retries: Option<u32>) -> RestartPolicy {
        RestartPolicy::OnFailure {
            max_retries,
            backoff_base_secs: 1.0,
            backoff_max_secs: 300.0,
        }
    }

    #[test]
    fn never_policy_never_restarts() {
        let policy = RestartPolicy::Never;
        assert!(!RestartEvaluator::should_restart(&policy, Some(1), 0));
        assert!(!RestartEvaluator::should_restart(&policy, Some(0), 0));
        assert!(!RestartEvaluator::should_restart(&policy, None, 0));
    }

    #[test]
    fn always_policy_restarts_regardless_of_exit_code() {
        let policy = always(None);
        assert!(RestartEvaluator::should_restart(&policy, Some(0), 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 0));
        assert!(RestartEvaluator::should_restart(&policy, None, 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(0), 100));
    }

    #[test]
    fn always_policy_respects_max_retries() {
        let policy = always(Some(3));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 2));
        assert!(!RestartEvaluator::should_restart(&policy, Some(1), 3));
        assert!(!RestartEvaluator::should_restart(&policy, Some(1), 10));
    }

    #[test]
    fn on_failure_restarts_only_on_failure() {
        let policy = on_failure(None);
        assert!(!RestartEvaluator::should_restart(&policy, Some(0), 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(137), 0));
        // A missing exit code is treated as a failure.
        assert!(RestartEvaluator::should_restart(&policy, None, 0));
    }

    #[test]
    fn on_failure_respects_max_retries() {
        let policy = on_failure(Some(2));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 0));
        assert!(RestartEvaluator::should_restart(&policy, Some(1), 1));
        assert!(!RestartEvaluator::should_restart(&policy, Some(1), 2));
    }

    #[test]
    fn backoff_exponential_growth() {
        // Each delay lies in [base * 2^n, 1.1 * base * 2^n] because of the up-to-10% jitter.
        for (attempt, expected) in [(0u32, 1.0f64), (1, 2.0), (2, 4.0)] {
            let secs = compute_backoff(attempt, 1.0, 300.0).as_secs_f64();
            assert!(secs >= expected, "attempt {}: {} < {}", attempt, secs, expected);
            assert!(secs <= expected * 1.1, "attempt {}: {} too large", attempt, secs);
        }
    }

    #[test]
    fn backoff_caps_at_max() {
        let d = compute_backoff(20, 1.0, 300.0);
        assert!(d.as_secs_f64() >= 300.0);
        assert!(d.as_secs_f64() <= 330.0);
    }

    #[test]
    fn backoff_never_policy_returns_zero() {
        let policy = RestartPolicy::Never;
        let d = RestartEvaluator::backoff_duration(&policy, 5);
        assert_eq!(d, Duration::ZERO);
    }

    #[test]
    fn backoff_duration_uses_policy_parameters() {
        let always = RestartPolicy::Always {
            max_retries: None,
            backoff_base_secs: 2.0,
            backoff_max_secs: 300.0,
        };
        let secs = RestartEvaluator::backoff_duration(&always, 1).as_secs_f64();
        assert!((4.0..=4.4).contains(&secs), "got {}", secs);

        let on_failure = RestartPolicy::OnFailure {
            max_retries: None,
            backoff_base_secs: 1.0,
            backoff_max_secs: 3.0,
        };
        let secs = RestartEvaluator::backoff_duration(&on_failure, 5).as_secs_f64();
        assert!((3.0..=3.3).contains(&secs), "got {}", secs);
    }

    #[test]
    fn default_restart_policy_is_never() {
        assert!(
            matches!(RestartPolicy::default(), RestartPolicy::Never),
            "Default should be Never"
        );
    }
}
196}