Skip to main content

syspulse_core/
restart.rs

1use serde::{Deserialize, Serialize};
2use std::time::Duration;
3
/// Serde fallback for an omitted `max_retries` field: `None`, meaning no retry cap is set.
fn default_max_retries() -> Option<u32> {
    Option::None
}
7
/// Serde fallback for an omitted `backoff_base_secs` field: one second.
fn default_backoff_base() -> f64 {
    const BASE_SECS: f64 = 1.0;
    BASE_SECS
}
11
/// Serde fallback for an omitted `backoff_max_secs` field: 300 seconds.
fn default_backoff_max() -> f64 {
    const CEILING_SECS: f64 = 300.0;
    CEILING_SECS
}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17#[serde(tag = "policy", rename_all = "snake_case")]
18pub enum RestartPolicy {
19    Always {
20        #[serde(default = "default_max_retries")]
21        max_retries: Option<u32>,
22        #[serde(default = "default_backoff_base")]
23        backoff_base_secs: f64,
24        #[serde(default = "default_backoff_max")]
25        backoff_max_secs: f64,
26    },
27    OnFailure {
28        #[serde(default = "default_max_retries")]
29        max_retries: Option<u32>,
30        #[serde(default = "default_backoff_base")]
31        backoff_base_secs: f64,
32        #[serde(default = "default_backoff_max")]
33        backoff_max_secs: f64,
34    },
35    Never,
36}
37
38impl Default for RestartPolicy {
39    fn default() -> Self {
40        RestartPolicy::Never
41    }
42}
43
44pub fn compute_backoff(attempt: u32, base_secs: f64, max_secs: f64) -> Duration {
45    let exp = base_secs * 2.0f64.powi(attempt as i32);
46    let capped = exp.min(max_secs);
47    // Add jitter: random between 0 and 10% of capped
48    let jitter = rand::random::<f64>() * capped * 0.1;
49    Duration::from_secs_f64(capped + jitter)
50}
51
52pub struct RestartEvaluator;
53
54impl RestartEvaluator {
55    pub fn should_restart(
56        policy: &RestartPolicy,
57        exit_code: Option<i32>,
58        restart_count: u32,
59    ) -> bool {
60        match policy {
61            RestartPolicy::Never => false,
62            RestartPolicy::Always { max_retries, .. } => {
63                max_retries.map_or(true, |max| restart_count < max)
64            }
65            RestartPolicy::OnFailure { max_retries, .. } => {
66                let failed = exit_code.map_or(true, |c| c != 0);
67                failed && max_retries.map_or(true, |max| restart_count < max)
68            }
69        }
70    }
71
72    pub fn backoff_duration(policy: &RestartPolicy, restart_count: u32) -> Duration {
73        match policy {
74            RestartPolicy::Always {
75                backoff_base_secs,
76                backoff_max_secs,
77                ..
78            }
79            | RestartPolicy::OnFailure {
80                backoff_base_secs,
81                backoff_max_secs,
82                ..
83            } => compute_backoff(restart_count, *backoff_base_secs, *backoff_max_secs),
84            RestartPolicy::Never => Duration::ZERO,
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    /// `Always` policy with the 1s base / 300s cap used throughout these tests.
    fn always(max_retries: Option<u32>) -> RestartPolicy {
        RestartPolicy::Always {
            max_retries,
            backoff_base_secs: 1.0,
            backoff_max_secs: 300.0,
        }
    }

    /// `OnFailure` policy with the 1s base / 300s cap used throughout these tests.
    fn on_failure(max_retries: Option<u32>) -> RestartPolicy {
        RestartPolicy::OnFailure {
            max_retries,
            backoff_base_secs: 1.0,
            backoff_max_secs: 300.0,
        }
    }

    /// Asserts that `delay` lies within `[low_secs, high_secs]`; the width of
    /// the window is the room left for the up-to-10% jitter.
    fn assert_delay_between(delay: Duration, low_secs: f64, high_secs: f64) {
        let secs = delay.as_secs_f64();
        assert!(secs >= low_secs);
        assert!(secs <= high_secs);
    }

    #[test]
    fn never_policy_never_restarts() {
        let policy = RestartPolicy::Never;
        for &exit_code in &[Some(1), Some(0), None] {
            assert!(!RestartEvaluator::should_restart(&policy, exit_code, 0));
        }
    }

    #[test]
    fn always_policy_restarts_regardless_of_exit_code() {
        let policy = always(None);
        let cases = [(Some(0), 0), (Some(1), 0), (None, 0), (Some(0), 100)];
        for &(exit_code, restarts) in &cases {
            assert!(RestartEvaluator::should_restart(&policy, exit_code, restarts));
        }
    }

    #[test]
    fn always_policy_respects_max_retries() {
        let policy = always(Some(3));
        // Below the cap => restart.
        for &restarts in &[0, 2] {
            assert!(RestartEvaluator::should_restart(&policy, Some(1), restarts));
        }
        // At or beyond the cap => give up.
        for &restarts in &[3, 10] {
            assert!(!RestartEvaluator::should_restart(&policy, Some(1), restarts));
        }
    }

    #[test]
    fn on_failure_restarts_only_on_failure() {
        let policy = on_failure(None);
        // exit_code 0 => success => no restart
        assert!(!RestartEvaluator::should_restart(&policy, Some(0), 0));
        // exit_code != 0 => failure => restart
        for &code in &[1, 137] {
            assert!(RestartEvaluator::should_restart(&policy, Some(code), 0));
        }
        // None exit_code (e.g., signal) => treat as failure
        assert!(RestartEvaluator::should_restart(&policy, None, 0));
    }

    #[test]
    fn on_failure_respects_max_retries() {
        let policy = on_failure(Some(2));
        for &restarts in &[0, 1] {
            assert!(RestartEvaluator::should_restart(&policy, Some(1), restarts));
        }
        assert!(!RestartEvaluator::should_restart(&policy, Some(1), 2));
    }

    #[test]
    fn backoff_exponential_growth() {
        let first = compute_backoff(0, 1.0, 300.0);
        let second = compute_backoff(1, 1.0, 300.0);
        let third = compute_backoff(2, 1.0, 300.0);

        // Base case: 1.0 * 2^0 = 1.0s (+ up to 10% jitter)
        assert_delay_between(first, 1.0, 1.1);
        // 1.0 * 2^1 = 2.0s (+ up to 10% jitter)
        assert_delay_between(second, 2.0, 2.2);
        // 1.0 * 2^2 = 4.0s (+ up to 10% jitter)
        assert_delay_between(third, 4.0, 4.4);
    }

    #[test]
    fn backoff_caps_at_max() {
        // 2^20 = 1048576 >> 300, so should be capped at 300 + up to 10% jitter
        assert_delay_between(compute_backoff(20, 1.0, 300.0), 300.0, 330.0);
    }

    #[test]
    fn backoff_never_policy_returns_zero() {
        let delay = RestartEvaluator::backoff_duration(&RestartPolicy::Never, 5);
        assert_eq!(delay, Duration::ZERO);
    }

    #[test]
    fn default_restart_policy_is_never() {
        assert!(
            matches!(RestartPolicy::default(), RestartPolicy::Never),
            "Default should be Never"
        );
    }
}