1use std::time::Duration;
7
8use tokio::sync::watch;
9use tokio::task::JoinHandle;
10
11use crate::config::DaemonConfig;
12
/// Lifecycle state the supervisor records for a component.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComponentStatus {
    /// Initial state assigned by `ComponentHandle::new`.
    Running,
    /// Set by the supervisor's health check once the component's task has finished;
    /// the string carries a human-readable reason.
    Failed(String),
    /// Not assigned anywhere in the visible code.
    /// NOTE(review): presumably marks a deliberate stop — confirm against callers.
    Stopped,
}
19
/// Error type carried in the `Result` returned by component tasks
/// (see the `JoinHandle<Result<(), DaemonError>>` held by `ComponentHandle`).
#[derive(Debug, thiserror::Error)]
pub enum DaemonError {
    /// Displayed as `task error: <message>`.
    #[error("task error: {0}")]
    Task(String),
    /// Displayed as `shutdown error: <message>`.
    #[error("shutdown error: {0}")]
    Shutdown(String),
}
28
/// A named background task together with the bookkeeping the supervisor keeps for it.
pub struct ComponentHandle {
    /// Human-readable component name, used in log output and status listings.
    pub name: String,
    // Join handle of the spawned task; only queried for completion in this file.
    handle: JoinHandle<Result<(), DaemonError>>,
    /// Last status recorded for the component.
    pub status: ComponentStatus,
    /// Incremented by the supervisor's health check when the task is found finished.
    /// NOTE(review): no restart is actually performed in the visible code — confirm intent.
    pub restart_count: u32,
}
35
36impl ComponentHandle {
37 #[must_use]
38 pub fn new(name: impl Into<String>, handle: JoinHandle<Result<(), DaemonError>>) -> Self {
39 Self {
40 name: name.into(),
41 handle,
42 status: ComponentStatus::Running,
43 restart_count: 0,
44 }
45 }
46
47 #[must_use]
48 pub fn is_finished(&self) -> bool {
49 self.handle.is_finished()
50 }
51}
52
/// Tracks registered components and periodically checks whether their tasks
/// are still running, until a shutdown signal arrives.
pub struct DaemonSupervisor {
    // Components registered through `add_component`, in insertion order.
    components: Vec<ComponentHandle>,
    // Period between health checks performed by `run`.
    health_interval: Duration,
    // Taken from the configuration but not read anywhere in the visible code.
    _max_backoff: Duration,
    // Watch channel observed by `run`; a value of `true` requests shutdown.
    shutdown_rx: watch::Receiver<bool>,
}
59
60impl DaemonSupervisor {
61 #[must_use]
62 pub fn new(config: &DaemonConfig, shutdown_rx: watch::Receiver<bool>) -> Self {
63 Self {
64 components: Vec::new(),
65 health_interval: Duration::from_secs(config.health_interval_secs),
66 _max_backoff: Duration::from_secs(config.max_restart_backoff_secs),
67 shutdown_rx,
68 }
69 }
70
71 pub fn add_component(&mut self, handle: ComponentHandle) {
72 self.components.push(handle);
73 }
74
75 #[must_use]
76 pub fn component_count(&self) -> usize {
77 self.components.len()
78 }
79
80 pub async fn run(&mut self) {
82 let mut interval = tokio::time::interval(self.health_interval);
83 loop {
84 tokio::select! {
85 _ = interval.tick() => {
86 self.check_health();
87 }
88 _ = self.shutdown_rx.changed() => {
89 if *self.shutdown_rx.borrow() {
90 tracing::info!("daemon supervisor shutting down");
91 break;
92 }
93 }
94 }
95 }
96 }
97
98 fn check_health(&mut self) {
99 for component in &mut self.components {
100 if component.status == ComponentStatus::Running && component.is_finished() {
101 component.status = ComponentStatus::Failed("task exited".into());
102 component.restart_count += 1;
103 tracing::warn!(
104 component = %component.name,
105 restarts = component.restart_count,
106 "component exited unexpectedly"
107 );
108 }
109 }
110 }
111
112 #[must_use]
113 pub fn component_statuses(&self) -> Vec<(&str, &ComponentStatus)> {
114 self.components
115 .iter()
116 .map(|c| (c.name.as_str(), &c.status))
117 .collect()
118 }
119}
120
/// Reports whether a process with the given `pid` appears to exist.
///
/// * Unix: runs `kill -0 <pid>` and returns whether it exited successfully.
///   PIDs that do not fit a positive `i32` are rejected up front and never
///   reach `kill`.
/// * Windows: runs `tasklist` filtered on the PID and looks for the quoted
///   PID in its CSV output.
/// * Other platforms: always `false`.
///
/// Failure to spawn the helper command is reported as `false`.
#[must_use]
pub fn is_process_alive(pid: u32) -> bool {
    #[cfg(unix)]
    {
        match i32::try_from(pid) {
            Ok(target) if target > 0 => {
                let probe = std::process::Command::new("kill")
                    .arg("-0")
                    .arg(target.to_string())
                    .output();
                matches!(&probe, Ok(out) if out.status.success())
            }
            // Zero, or a value that would be negative as an `i32`.
            _ => false,
        }
    }
    #[cfg(windows)]
    {
        let filter = format!("PID eq {pid}");
        let quoted_pid = format!("\"{pid}\"");
        let listing = std::process::Command::new("tasklist")
            .args(["/FI", filter.as_str(), "/NH", "/FO", "CSV"])
            .output();
        match listing {
            Ok(out) => String::from_utf8_lossy(&out.stdout).contains(&quoted_pid),
            Err(_) => false,
        }
    }
    #[cfg(not(any(unix, windows)))]
    {
        let _ = pid;
        false
    }
}
163
164pub fn write_pid_file(path: &str) -> std::io::Result<()> {
171 use std::io::Write as _;
172 let expanded = expand_tilde(path);
173 let path = std::path::Path::new(&expanded);
174 if let Some(parent) = path.parent() {
175 std::fs::create_dir_all(parent)?;
176 }
177 let mut file = std::fs::OpenOptions::new()
178 .write(true)
179 .create_new(true)
180 .open(path)?;
181 file.write_all(std::process::id().to_string().as_bytes())
182}
183
184pub fn read_pid_file(path: &str) -> std::io::Result<u32> {
190 let expanded = expand_tilde(path);
191 let content = std::fs::read_to_string(&expanded)?;
192 content
193 .trim()
194 .parse::<u32>()
195 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
196}
197
198pub fn remove_pid_file(path: &str) -> std::io::Result<()> {
204 let expanded = expand_tilde(path);
205 match std::fs::remove_file(&expanded) {
206 Ok(()) => Ok(()),
207 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
208 Err(e) => Err(e),
209 }
210}
211
/// Expands a leading `~` in `path` to the user's home directory.
///
/// The home directory is taken from `HOME`, falling back to `USERPROFILE`.
/// Both `"~/rest"` and a bare `"~"` are expanded. Any other input, including
/// `"~user/..."`, is returned unchanged, as is every input when neither
/// variable is set.
///
/// A home directory that is not valid UTF-8 is converted lossily.
fn expand_tilde(path: &str) -> String {
    // `None` means the path is exactly "~"; `Some(rest)` is the part after "~/".
    let rest = if path == "~" {
        None
    } else if let Some(rest) = path.strip_prefix("~/") {
        Some(rest)
    } else {
        return path.to_owned();
    };

    let Some(home) = std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE")) else {
        return path.to_owned();
    };
    let home = home.to_string_lossy();

    match rest {
        Some(rest) => format!("{home}/{rest}"),
        None => home.into_owned(),
    }
}
220
// Unit tests for the tilde expansion, PID-file helpers, supervisor and
// process-liveness check defined in this file.
#[cfg(test)]
mod tests {
    #![allow(clippy::field_reassign_with_default)]

    use super::*;

    // A `~/` prefix must not survive expansion.
    // NOTE(review): relies on HOME or USERPROFILE being set in the test environment.
    #[test]
    fn expand_tilde_with_home() {
        let result = expand_tilde("~/test/file.pid");
        assert!(!result.starts_with("~/"));
    }

    // Paths without a `~/` prefix are returned as-is.
    #[test]
    fn expand_tilde_absolute_unchanged() {
        assert_eq!(expand_tilde("/tmp/zeph.pid"), "/tmp/zeph.pid");
    }

    // Write, read back and remove a PID file inside a temporary directory.
    #[test]
    fn pid_file_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.pid");
        let path_str = path.to_string_lossy().to_string();

        write_pid_file(&path_str).unwrap();
        let pid = read_pid_file(&path_str).unwrap();
        assert_eq!(pid, std::process::id());
        remove_pid_file(&path_str).unwrap();
        assert!(!path.exists());
    }

    // Removing a file that does not exist is not an error.
    #[test]
    fn remove_nonexistent_pid_file_ok() {
        assert!(remove_pid_file("/tmp/nonexistent_zeph_test.pid").is_ok());
    }

    // Non-numeric file contents are rejected.
    #[test]
    fn read_invalid_pid_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.pid");
        std::fs::write(&path, "not_a_number").unwrap();
        assert!(read_pid_file(&path.to_string_lossy()).is_err());
    }

    // Adding a component is reflected in the component count.
    #[tokio::test]
    async fn supervisor_tracks_components() {
        let config = DaemonConfig::default();
        // The sender is kept alive for the duration of the test.
        let (_tx, rx) = watch::channel(false);
        let mut supervisor = DaemonSupervisor::new(&config, rx);

        let handle = tokio::spawn(async { Ok::<(), DaemonError>(()) });
        supervisor.add_component(ComponentHandle::new("test", handle));
        assert_eq!(supervisor.component_count(), 1);
    }

    // A task that has already completed is flagged as failed by the health check.
    #[tokio::test]
    async fn supervisor_detects_finished_component() {
        let config = DaemonConfig::default();
        let (_tx, rx) = watch::channel(false);
        let mut supervisor = DaemonSupervisor::new(&config, rx);

        let handle = tokio::spawn(async { Ok::<(), DaemonError>(()) });
        // Give the spawned task time to run to completion before checking.
        tokio::time::sleep(Duration::from_millis(10)).await;
        supervisor.add_component(ComponentHandle::new("finished", handle));
        supervisor.check_health();

        let statuses = supervisor.component_statuses();
        assert_eq!(statuses.len(), 1);
        assert!(matches!(statuses[0].1, ComponentStatus::Failed(_)));
    }

    // Sending `true` on the shutdown channel makes `run` return.
    #[tokio::test]
    async fn supervisor_shutdown() {
        let config = DaemonConfig {
            health_interval_secs: 1,
            ..DaemonConfig::default()
        };
        let (tx, rx) = watch::channel(false);
        let mut supervisor = DaemonSupervisor::new(&config, rx);

        let run_handle = tokio::spawn(async move { supervisor.run().await });
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = tx.send(true);
        // The timeout guards against the supervisor ignoring the signal.
        tokio::time::timeout(Duration::from_secs(2), run_handle)
            .await
            .expect("supervisor should stop on shutdown")
            .expect("task should complete");
    }

    // Sanity check of the derived equality on `ComponentStatus`.
    #[test]
    fn component_status_eq() {
        assert_eq!(ComponentStatus::Running, ComponentStatus::Running);
        assert_eq!(ComponentStatus::Stopped, ComponentStatus::Stopped);
        assert_ne!(ComponentStatus::Running, ComponentStatus::Stopped);
    }

    // The test process itself must be reported as alive.
    #[test]
    fn is_process_alive_current_process() {
        let pid = std::process::id();
        assert!(is_process_alive(pid), "current process must be alive");
    }

    // `u32::MAX` is used as a PID that is expected not to exist.
    #[test]
    fn is_process_alive_nonexistent_pid() {
        assert!(
            !is_process_alive(u32::MAX),
            "PID u32::MAX must not be alive"
        );
    }
}