Skip to main content

solti_model/domain/policy/
restart.rs

1//! # Restart policy.
2//!
3//! [`RestartPolicy`] controls when a task is restarted after completion or failure.
4
5use serde::{Deserialize, Serialize};
6use std::str::FromStr;
7
8use crate::error::{ModelError, ModelResult};
9
10/// Determines whether a task should be automatically restarted after completion or failure.
11///
12/// | Variant     | Behaviour                                                    |
13/// |-------------|--------------------------------------------------------------|
14/// | `Never`     | Do not restart under any circumstances                       |
15/// | `OnFailure` | Restart only when the task ends with an error                |
16/// | `Always`    | Restart unconditionally (immediate or periodic via interval) |
17///
18/// `Always { interval_ms: None }` restarts immediately;
19/// `Always { interval_ms: Some(N) }` waits N ms between runs (periodic task).
20///
21/// Cancellation (via controller or shutdown) is **not** treated as failure and will not trigger a restart.
22///
23/// ## Also
24///
25/// - [`BackoffPolicy`](super::BackoffPolicy) delay between restart attempts.
26/// - [`TaskSpec`](crate::TaskSpec) carries `restart` as a field.
27#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
28#[serde(tag = "type", rename_all = "camelCase")]
29#[non_exhaustive]
30pub enum RestartPolicy {
31    /// Never restart the task.
32    #[default]
33    Never,
34    /// Restart the task only if it failed (non-zero exit, error, panic, etc.).
35    OnFailure,
36    /// Always restart after completion.
37    ///
38    /// - `interval_ms: None` - restart immediately after the previous run completes (tight loop tempered only by [`BackoffPolicy`] on failure).
39    /// - `interval_ms: Some(n)` - wait `n` milliseconds between runs (periodic task).
40    /// - `Some(0)` is treated as immediate and is semantically identical to `None`; prefer `None` for clarity.
41    #[serde(rename_all = "camelCase")]
42    Always {
43        #[serde(skip_serializing_if = "Option::is_none")]
44        interval_ms: Option<u64>,
45    },
46}
47
48impl RestartPolicy {
49    /// Create an Always policy without interval (immediate restart).
50    pub const fn always() -> Self {
51        RestartPolicy::Always { interval_ms: None }
52    }
53
54    /// Create an Always policy with periodic interval.
55    pub const fn periodic(interval_ms: u64) -> Self {
56        RestartPolicy::Always {
57            interval_ms: Some(interval_ms),
58        }
59    }
60}
61
62impl FromStr for RestartPolicy {
63    type Err = ModelError;
64
65    fn from_str(s: &str) -> ModelResult<Self> {
66        let original = s.trim();
67        if original.is_empty() {
68            return Ok(RestartPolicy::Never);
69        }
70
71        let (head, rest) = match original.find(':') {
72            Some(pos) => (&original[..pos], Some(original[pos + 1..].trim())),
73            None => (original, None),
74        };
75
76        if head.eq_ignore_ascii_case("never") {
77            Ok(RestartPolicy::Never)
78        } else if head.eq_ignore_ascii_case("on-failure") || head.eq_ignore_ascii_case("failure") {
79            Ok(RestartPolicy::OnFailure)
80        } else if head.eq_ignore_ascii_case("always") {
81            let interval_ms = match rest {
82                None | Some("") => None,
83                Some(v) => {
84                    let v = v.parse::<u64>().map_err(|_| {
85                        ModelError::UnknownRestart(format!(
86                            "invalid interval in '{}': must be u64",
87                            original
88                        ))
89                    })?;
90                    Some(v)
91                }
92            };
93            Ok(RestartPolicy::Always { interval_ms })
94        } else {
95            Err(ModelError::UnknownRestart(original.to_string()))
96        }
97    }
98}
99
#[cfg(test)]
mod tests {
    use super::RestartPolicy;
    use crate::error::ModelError;
    use std::str::FromStr;

    /// Asserts that every `(input, expected)` pair parses to `expected`.
    fn assert_parses(cases: &[(&str, RestartPolicy)]) {
        for (input, expected) in cases {
            assert_eq!(
                RestartPolicy::from_str(input).unwrap(),
                *expected,
                "input {:?}",
                input
            );
        }
    }

    /// Asserts that every input is rejected with `ModelError::UnknownRestart`.
    fn assert_rejected(inputs: &[&str]) {
        for input in inputs {
            let err = RestartPolicy::from_str(input).unwrap_err();
            assert!(
                matches!(err, ModelError::UnknownRestart(_)),
                "input {:?} produced unexpected error {:?}",
                input,
                err
            );
        }
    }

    #[test]
    fn default_is_never() {
        assert_eq!(RestartPolicy::default(), RestartPolicy::Never);
    }

    #[test]
    fn constructors_match_variants() {
        assert_eq!(
            RestartPolicy::always(),
            RestartPolicy::Always { interval_ms: None }
        );
        assert_eq!(
            RestartPolicy::periodic(250),
            RestartPolicy::Always {
                interval_ms: Some(250)
            }
        );
    }

    #[test]
    fn parse_never_and_empty() {
        assert_parses(&[
            ("", RestartPolicy::Never),
            ("   ", RestartPolicy::Never),
            ("never", RestartPolicy::Never),
            ("  NeVeR  ", RestartPolicy::Never),
        ]);
    }

    #[test]
    fn parse_on_failure() {
        assert_parses(&[
            ("on-failure", RestartPolicy::OnFailure),
            ("ON-FAILURE", RestartPolicy::OnFailure),
            ("failure", RestartPolicy::OnFailure),
            ("  Failure ", RestartPolicy::OnFailure),
        ]);
    }

    #[test]
    fn parse_always_immediate() {
        let immediate = RestartPolicy::Always { interval_ms: None };
        assert_parses(&[
            ("always", immediate),
            ("  ALWAYS  ", immediate),
            ("always:", immediate),
            ("always:   ", immediate),
        ]);
    }

    #[test]
    fn parse_always_with_interval() {
        assert_parses(&[
            (
                "always:1000",
                RestartPolicy::Always {
                    interval_ms: Some(1000),
                },
            ),
            (
                " Always:  60000 ",
                RestartPolicy::Always {
                    interval_ms: Some(60000),
                },
            ),
            // Zero is kept as-is; it is not normalized to `None`.
            (
                "always:0",
                RestartPolicy::Always {
                    interval_ms: Some(0),
                },
            ),
            (
                "always:18446744073709551615",
                RestartPolicy::Always {
                    interval_ms: Some(u64::MAX),
                },
            ),
        ]);
    }

    #[test]
    fn parse_always_invalid_interval() {
        assert_rejected(&[
            "always:not-a-number",
            "always:-5",
            "always:1.5",
            // One past u64::MAX must overflow, not wrap.
            "always:18446744073709551616",
        ]);
    }

    #[test]
    fn parse_unknown_head_fails() {
        assert_rejected(&["random", "sometimes", ":1000"]);
    }
}