Skip to main content

solti_model/domain/policy/
restart.rs

1//! # Restart policy.
2//!
3//! [`RestartPolicy`] controls when a task is restarted after completion or failure.
4
5use serde::{Deserialize, Serialize};
6use std::str::FromStr;
7
8use crate::error::{ModelError, ModelResult};
9
/// Determines whether a task should be automatically restarted after completion or failure.
///
/// | Variant     | Behaviour                                                    |
/// |-------------|--------------------------------------------------------------|
/// | `Never`     | Do not restart under any circumstances                       |
/// | `OnFailure` | Restart only when the task ends with an error                |
/// | `Always`    | Restart unconditionally (immediate or periodic via interval) |
///
/// `Always { interval_ms: None }` restarts immediately;
/// `Always { interval_ms: Some(N) }` waits N ms between runs (periodic task).
///
/// Cancellation (via controller or shutdown) is **not** treated as failure and will not trigger a restart.
///
/// The default value is [`RestartPolicy::Never`].
///
/// ## Serialized form
///
/// The enum is internally tagged: a `"type"` field holds the camelCase variant
/// name, and the `Always` payload field is camelCased as well (`intervalMs`).
/// `intervalMs` is omitted from the output when it is `None`:
///
/// ```json
/// {"type": "never"}
/// {"type": "onFailure"}
/// {"type": "always"}
/// {"type": "always", "intervalMs": 1000}
/// ```
///
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a
/// wildcard arm.
///
/// ## Also
///
/// - [`BackoffPolicy`](super::BackoffPolicy) delay between restart attempts.
/// - [`TaskSpec`](crate::TaskSpec) carries `restart` as a field.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
#[non_exhaustive]
pub enum RestartPolicy {
    /// Never restart the task. This is the [`Default`] variant.
    #[default]
    Never,
    /// Restart the task only if it failed (non-zero exit, error, panic, etc.).
    OnFailure,
    /// Always restart after completion.
    ///
    /// - `interval_ms: None` — restart immediately after the previous
    ///   run completes (tight loop tempered only by
    ///   [`BackoffPolicy`](super::BackoffPolicy) on failure).
    /// - `interval_ms: Some(n)` — wait `n` milliseconds between runs
    ///   (periodic task). `Some(0)` is treated as immediate and is
    ///   semantically identical to `None`; prefer `None` for clarity.
    // Variant-level rename_all camelCases the field name (`intervalMs`).
    #[serde(rename_all = "camelCase")]
    Always {
        // Omitted from serialized output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        interval_ms: Option<u64>,
    },
}
50
51impl RestartPolicy {
52    /// Create an Always policy without interval (immediate restart).
53    pub const fn always() -> Self {
54        RestartPolicy::Always { interval_ms: None }
55    }
56
57    /// Create an Always policy with periodic interval.
58    pub const fn periodic(interval_ms: u64) -> Self {
59        RestartPolicy::Always {
60            interval_ms: Some(interval_ms),
61        }
62    }
63}
64
65impl FromStr for RestartPolicy {
66    type Err = ModelError;
67
68    fn from_str(s: &str) -> ModelResult<Self> {
69        let original = s.trim();
70        if original.is_empty() {
71            return Ok(RestartPolicy::Never);
72        }
73
74        let (head, rest) = match original.find(':') {
75            Some(pos) => (&original[..pos], Some(original[pos + 1..].trim())),
76            None => (original, None),
77        };
78
79        if head.eq_ignore_ascii_case("never") {
80            Ok(RestartPolicy::Never)
81        } else if head.eq_ignore_ascii_case("on-failure") || head.eq_ignore_ascii_case("failure") {
82            Ok(RestartPolicy::OnFailure)
83        } else if head.eq_ignore_ascii_case("always") {
84            let interval_ms = match rest {
85                None | Some("") => None,
86                Some(v) => {
87                    let v = v.parse::<u64>().map_err(|_| {
88                        ModelError::UnknownRestart(format!(
89                            "invalid interval in '{}': must be u64",
90                            original
91                        ))
92                    })?;
93                    Some(v)
94                }
95            };
96            Ok(RestartPolicy::Always { interval_ms })
97        } else {
98            Err(ModelError::UnknownRestart(original.to_string()))
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::RestartPolicy;
    use crate::error::ModelError;
    use std::str::FromStr;

    /// Asserts that `input` parses successfully into `expected`.
    fn assert_parses(input: &str, expected: RestartPolicy) {
        let parsed = RestartPolicy::from_str(input).unwrap();
        assert_eq!(parsed, expected, "input: {:?}", input);
    }

    /// Asserts that `input` is rejected with `ModelError::UnknownRestart`.
    fn assert_unknown(input: &str) {
        let err = RestartPolicy::from_str(input).unwrap_err();
        assert!(
            matches!(err, ModelError::UnknownRestart(_)),
            "input: {:?}",
            input
        );
    }

    #[test]
    fn parse_never_and_empty() {
        for input in ["", "never", "  NeVeR  "] {
            assert_parses(input, RestartPolicy::Never);
        }
    }

    #[test]
    fn parse_on_failure() {
        for input in ["on-failure", "failure", "  Failure "] {
            assert_parses(input, RestartPolicy::OnFailure);
        }
    }

    #[test]
    fn parse_always_immediate() {
        let immediate = RestartPolicy::Always { interval_ms: None };
        for input in ["always", "  ALWAYS  ", "always:", "always:   "] {
            assert_parses(input, immediate);
        }
    }

    #[test]
    fn parse_always_with_interval() {
        let cases = [("always:1000", 1000), (" Always:  60000 ", 60000)];
        for (input, millis) in cases {
            assert_parses(
                input,
                RestartPolicy::Always {
                    interval_ms: Some(millis),
                },
            );
        }
    }

    #[test]
    fn parse_always_invalid_interval() {
        assert_unknown("always:not-a-number");
    }

    #[test]
    fn parse_unknown_head_fails() {
        assert_unknown("random");
    }
}