Skip to main content

team_core/
supervisor.rs

1//! Process supervision.
2//!
3//! The default back-end is a portable `TmuxSupervisor` that works on macOS
4//! and Linux. `SystemdSupervisor` and `LaunchdSupervisor` plug in behind
5//! the same trait when the host supports them.
6
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::thread;
10use std::time::{Duration, Instant};
11
12use anyhow::{Context, Result};
13
14use crate::compose::AgentHandle;
15
16#[derive(Debug, Clone)]
17pub struct AgentSpec {
18    pub project: String,
19    pub agent: String,
20    pub tmux_session: String,
21    pub wrapper: PathBuf,
22    pub cwd: PathBuf,
23    pub env_file: PathBuf,
24}
25
26impl AgentSpec {
27    pub fn from_handle(h: AgentHandle<'_>, root: &Path, tmux_prefix: &str) -> Self {
28        Self {
29            project: h.project.into(),
30            agent: h.agent.into(),
31            tmux_session: format!("{tmux_prefix}{}-{}", h.project, h.agent),
32            wrapper: root.join("bin/agent-wrapper.sh"),
33            cwd: root.to_path_buf(),
34            env_file: crate::render::env_path(root, h.project, h.agent),
35        }
36    }
37}
38
/// What a supervisor last observed about an agent's supervising process.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentState {
    /// The supervising process is present.
    Running,
    /// The supervising process is absent.
    Stopped,
    /// The back-end could not determine the state.
    Unknown,
}
46
/// How a graceful drain ended.
///
/// This is returned to the caller so it can report which agents had to be
/// force-stopped. Frequent `TimedOutKilled` results suggest the drain
/// budget needs tuning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DrainOutcome {
    /// The agent was observed as `Stopped` before the timeout elapsed.
    Graceful,
    /// The poll ran out of time, or the back-end has no graceful path,
    /// and `down()` was used as a hard stop.
    TimedOutKilled,
}
57
58pub trait Supervisor {
59    fn up(&self, spec: &AgentSpec) -> Result<()>;
60    fn down(&self, spec: &AgentSpec) -> Result<()>;
61    fn state(&self, spec: &AgentSpec) -> Result<AgentState>;
62
63    /// Stop an agent gracefully. The default implementation falls
64    /// back to `down()` for back-ends that don't implement signal
65    /// delivery (or where graceful shutdown isn't meaningful — e.g.
66    /// a `MockSupervisor` in tests).
67    fn drain(&self, spec: &AgentSpec, _timeout: Duration) -> Result<DrainOutcome> {
68        self.down(spec)?;
69        Ok(DrainOutcome::TimedOutKilled)
70    }
71
72    /// Cadence at which `drain` polls for `Stopped` after the
73    /// graceful-stop signal is sent. Default 250ms — fine on every
74    /// host we've tested. The hook exists so tests can inject a
75    /// shorter cadence (no real-time waits) without going through
76    /// the OS, and so a future slow-tmux host has an escape valve
77    /// without forking the orchestration.
78    fn drain_poll_interval(&self) -> Duration {
79        Duration::from_millis(250)
80    }
81}
82
83/// Generic graceful-drain orchestration used by `Supervisor` impls
84/// that have a "signal a graceful stop" primitive (e.g. tmux's
85/// `send-keys C-c`). Calls `signal_fn`, polls
86/// `supervisor.state(spec)` for `Stopped` up to `timeout` at the
87/// supervisor's `drain_poll_interval`, falls through to
88/// `supervisor.down(spec)` if the agent doesn't exit in time.
89///
90/// Pulled out so the orchestration contract is testable end-to-end
91/// against a `MockSupervisor` without a real tmux runtime.
92pub fn orchestrate_drain<S, F>(
93    supervisor: &S,
94    spec: &AgentSpec,
95    timeout: Duration,
96    signal_fn: F,
97) -> Result<DrainOutcome>
98where
99    S: Supervisor + ?Sized,
100    F: FnOnce(),
101{
102    signal_fn();
103    let outcome = poll_for_stopped(timeout, supervisor.drain_poll_interval(), || {
104        supervisor.state(spec).unwrap_or(AgentState::Unknown)
105    });
106    if outcome == DrainOutcome::TimedOutKilled {
107        supervisor.down(spec)?;
108    }
109    Ok(outcome)
110}
111
112/// Portable supervisor: one detached `tmux` session per agent.
113pub struct TmuxSupervisor;
114
115impl Supervisor for TmuxSupervisor {
116    fn up(&self, spec: &AgentSpec) -> Result<()> {
117        if matches!(self.state(spec)?, AgentState::Running) {
118            return Ok(());
119        }
120        let cmd = format!(
121            "env $(cat {env}) {wrapper} {project}:{agent}",
122            env = shlex::try_quote(&spec.env_file.display().to_string())?,
123            wrapper = shlex::try_quote(&spec.wrapper.display().to_string())?,
124            project = spec.project,
125            agent = spec.agent,
126        );
127        let status = Command::new("tmux")
128            .args([
129                "new-session",
130                "-d",
131                "-s",
132                &spec.tmux_session,
133                "-c",
134                &spec.cwd.display().to_string(),
135                "sh",
136                "-c",
137                &cmd,
138            ])
139            .status()
140            .context("spawn tmux new-session")?;
141        anyhow::ensure!(status.success(), "tmux new-session exited {status}");
142        Ok(())
143    }
144
145    fn down(&self, spec: &AgentSpec) -> Result<()> {
146        let _ = Command::new("tmux")
147            .args(["kill-session", "-t", &spec.tmux_session])
148            .status();
149        Ok(())
150    }
151
152    fn state(&self, spec: &AgentSpec) -> Result<AgentState> {
153        let out = Command::new("tmux")
154            .args(["has-session", "-t", &spec.tmux_session])
155            .output();
156        Ok(match out {
157            Ok(o) if o.status.success() => AgentState::Running,
158            Ok(_) => AgentState::Stopped,
159            Err(_) => AgentState::Unknown,
160        })
161    }
162
163    /// Send Ctrl-C to the pane (kernel delivers SIGINT to the
164    /// foreground process), then poll for `Stopped` up to `timeout`.
165    /// Falls through to `kill-session` if the agent doesn't exit in
166    /// time. Used by `reload` so in-flight tool calls and partial
167    /// assistant responses get a chance to flush instead of being
168    /// SIGKILL'd by the prior `down()`.
169    fn drain(&self, spec: &AgentSpec, timeout: Duration) -> Result<DrainOutcome> {
170        orchestrate_drain(self, spec, timeout, || {
171            let _ = Command::new("tmux")
172                .args(["send-keys", "-t", &spec.tmux_session, "C-c"])
173                .status();
174        })
175    }
176}
177
178/// Poll `observe_state` every `interval` for up to `timeout`, returning
179/// `Graceful` if `Stopped` is observed in time and `TimedOutKilled`
180/// otherwise. Pulled out as a free function so it can be tested with
181/// fake observers — neither tmux nor real time is involved.
182fn poll_for_stopped<F: FnMut() -> AgentState>(
183    timeout: Duration,
184    interval: Duration,
185    mut observe_state: F,
186) -> DrainOutcome {
187    let deadline = Instant::now() + timeout;
188    loop {
189        if observe_state() == AgentState::Stopped {
190            return DrainOutcome::Graceful;
191        }
192        if Instant::now() >= deadline {
193            return DrainOutcome::TimedOutKilled;
194        }
195        thread::sleep(interval);
196    }
197}
198
#[cfg(test)]
mod drain_tests {
    use super::*;
    use std::cell::{Cell, RefCell};

    #[test]
    fn poll_returns_graceful_when_stopped_observed_in_time() {
        let mut calls: u32 = 0;
        let outcome = poll_for_stopped(Duration::from_millis(50), Duration::from_millis(1), || {
            calls += 1;
            if calls >= 2 {
                AgentState::Stopped
            } else {
                AgentState::Running
            }
        });
        assert_eq!(outcome, DrainOutcome::Graceful);
        assert_eq!(calls, 2, "polling stops at the first Stopped observation");
    }

    #[test]
    fn poll_falls_through_to_kill_when_agent_never_stops() {
        let outcome = poll_for_stopped(Duration::from_millis(8), Duration::from_millis(2), || {
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
    }

    #[test]
    fn poll_zero_timeout_only_checks_once_then_kills() {
        let mut calls: u32 = 0;
        let outcome = poll_for_stopped(Duration::from_millis(0), Duration::from_millis(1), || {
            calls += 1;
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(calls, 1, "single state observation before timeout");
    }

    /// Test supervisor that records each `up`/`down`/`state` call in
    /// order. `drain` is not overridden, so it is never recorded; tests can
    /// push their own markers via `record`. It reports `Stopped` from the
    /// `stop_after`-th `state()` call onward and exposes a tunable
    /// `drain_poll_interval`, so tests only wait milliseconds of real time.
    #[derive(Default)]
    struct MockSupervisor {
        calls: RefCell<Vec<&'static str>>,
        /// From the Nth state() call (1-indexed) onward, return Stopped.
        /// 0 = always Running.
        stop_after: u32,
        state_calls: Cell<u32>,
        poll_interval: Duration,
    }

    impl MockSupervisor {
        fn record(&self, op: &'static str) {
            self.calls.borrow_mut().push(op);
        }
    }

    impl Supervisor for MockSupervisor {
        fn up(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("up");
            Ok(())
        }
        fn down(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("down");
            Ok(())
        }
        fn state(&self, _spec: &AgentSpec) -> Result<AgentState> {
            self.record("state");
            let n = self.state_calls.get() + 1;
            self.state_calls.set(n);
            Ok(if self.stop_after > 0 && n >= self.stop_after {
                AgentState::Stopped
            } else {
                AgentState::Running
            })
        }
        fn drain_poll_interval(&self) -> Duration {
            self.poll_interval
        }
    }

    fn fake_spec() -> AgentSpec {
        AgentSpec {
            project: "p".into(),
            agent: "a".into(),
            tmux_session: "p-a".into(),
            wrapper: PathBuf::from("/dev/null"),
            cwd: PathBuf::from("/tmp"),
            env_file: PathBuf::from("/dev/null"),
        }
    }

    #[test]
    fn drain_with_zero_timeout_returns_timed_out_killed_and_calls_down() {
        // Contract: timeout=0 → signal first, then a single state
        // observation, then fall-through to down(). No graceful path,
        // no double-kill, no other side effects.
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            ..Default::default()
        };
        let spec = fake_spec();

        // Record the signal in the same log so the ordering is asserted,
        // not just the fact that it ran.
        let outcome =
            orchestrate_drain(&mock, &spec, Duration::ZERO, || mock.record("signal")).unwrap();

        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(
            mock.calls.borrow().as_slice(),
            &["signal", "state", "down"],
            "zero-timeout: signal, one state observation, then kill"
        );
    }

    #[test]
    fn drain_with_graceful_stop_does_not_call_down() {
        // Contract: agent observed `Stopped` within timeout → polling stops
        // there and there is no fall-through kill. The down() side effect
        // is reserved for forced terminations.
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            stop_after: 2, // Stopped on 2nd state() call.
            ..Default::default()
        };
        let spec = fake_spec();

        let outcome = orchestrate_drain(&mock, &spec, Duration::from_millis(100), || {}).unwrap();

        assert_eq!(outcome, DrainOutcome::Graceful);
        assert_eq!(
            mock.calls.borrow().as_slice(),
            &["state", "state"],
            "graceful drain must stop polling at Stopped and never call down()"
        );
    }

    #[test]
    fn drain_poll_interval_default_is_250ms() {
        // Pin the documented default so a future "tighten the
        // default" change has to update the docstring + this test
        // together.
        struct Default250;
        impl Supervisor for Default250 {
            fn up(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn down(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn state(&self, _: &AgentSpec) -> Result<AgentState> {
                Ok(AgentState::Stopped)
            }
        }
        assert_eq!(Default250.drain_poll_interval(), Duration::from_millis(250));
    }

    #[test]
    fn drain_poll_interval_override_is_used_by_orchestrator() {
        // Sanity check that the trait method's value flows into
        // poll_for_stopped — without this, a host-specific override
        // would silently no-op.
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(2),
            stop_after: 0,
            ..Default::default()
        };
        let spec = fake_spec();

        let start = Instant::now();
        let _ = orchestrate_drain(&mock, &spec, Duration::from_millis(8), || {});
        let elapsed = start.elapsed();

        // With a 2ms poll interval and an 8ms timeout, we expect a
        // handful of state observations, not 0 and not 100. Loose
        // bound — enough to catch a 250ms default leaking in.
        let states = mock
            .calls
            .borrow()
            .iter()
            .filter(|c| **c == "state")
            .count();
        assert!(
            states >= 2,
            "expected several state observations at 2ms cadence, got {states}"
        );
        assert!(
            elapsed < Duration::from_millis(60),
            "drain with 2ms interval finished too slowly ({elapsed:?})"
        );
    }
}
397
398mod shlex {
399    /// Minimal POSIX shell single-quote escaper so we don't pull a full dep.
400    pub fn try_quote(s: &str) -> anyhow::Result<String> {
401        anyhow::ensure!(!s.contains('\0'), "null byte in shell arg");
402        let escaped = s.replace('\'', r"'\''");
403        Ok(format!("'{escaped}'"))
404    }
405
406    #[cfg(test)]
407    mod tests {
408        use super::*;
409
410        #[test]
411        fn quotes_plain_path() {
412            assert_eq!(try_quote("/a/b.sh").unwrap(), "'/a/b.sh'");
413        }
414
415        #[test]
416        fn escapes_embedded_single_quote() {
417            assert_eq!(try_quote("x'y").unwrap(), r"'x'\''y'");
418        }
419    }
420}