1use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::thread;
10use std::time::{Duration, Instant};
11
12use anyhow::{Context, Result};
13
14use crate::compose::AgentHandle;
15
/// Everything a [`Supervisor`] needs to start, locate, and stop one agent.
#[derive(Clone, Debug)]
pub struct AgentSpec {
    /// Project name; first half of the `project:agent` argument handed to the wrapper.
    pub project: String,
    /// Agent name; second half of the `project:agent` argument handed to the wrapper.
    pub agent: String,
    /// Name of the tmux session that hosts this agent.
    pub tmux_session: String,
    /// Script executed inside the session to run the agent.
    pub wrapper: PathBuf,
    /// Working directory the session is started in.
    pub cwd: PathBuf,
    /// File whose contents are expanded into the agent's environment at launch.
    pub env_file: PathBuf,
}
25
26impl AgentSpec {
27 pub fn from_handle(h: AgentHandle<'_>, root: &Path, tmux_prefix: &str) -> Self {
28 Self {
29 project: h.project.into(),
30 agent: h.agent.into(),
31 tmux_session: format!("{tmux_prefix}{}-{}", h.project, h.agent),
32 wrapper: root.join("bin/agent-wrapper.sh"),
33 cwd: root.to_path_buf(),
34 env_file: crate::render::env_path(root, h.project, h.agent),
35 }
36 }
37}
38
/// Liveness of an agent as reported by [`Supervisor::state`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The supervisor found the agent alive.
    Running,
    /// The supervisor looked and found no live agent.
    Stopped,
    /// The supervisor could not tell (for example, the probe itself failed to run).
    Unknown,
}
46
/// How a drain ended.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainOutcome {
    /// The agent was observed stopped before the drain deadline passed.
    Graceful,
    /// The agent was not observed stopped in time and was taken down forcibly.
    TimedOutKilled,
}
57
58pub trait Supervisor {
59 fn up(&self, spec: &AgentSpec) -> Result<()>;
60 fn down(&self, spec: &AgentSpec) -> Result<()>;
61 fn state(&self, spec: &AgentSpec) -> Result<AgentState>;
62
63 fn drain(&self, spec: &AgentSpec, _timeout: Duration) -> Result<DrainOutcome> {
68 self.down(spec)?;
69 Ok(DrainOutcome::TimedOutKilled)
70 }
71
72 fn drain_poll_interval(&self) -> Duration {
79 Duration::from_millis(250)
80 }
81}
82
83pub fn orchestrate_drain<S, F>(
93 supervisor: &S,
94 spec: &AgentSpec,
95 timeout: Duration,
96 signal_fn: F,
97) -> Result<DrainOutcome>
98where
99 S: Supervisor + ?Sized,
100 F: FnOnce(),
101{
102 signal_fn();
103 let outcome = poll_for_stopped(timeout, supervisor.drain_poll_interval(), || {
104 supervisor.state(spec).unwrap_or(AgentState::Unknown)
105 });
106 if outcome == DrainOutcome::TimedOutKilled {
107 supervisor.down(spec)?;
108 }
109 Ok(outcome)
110}
111
112pub struct TmuxSupervisor;
114
115impl Supervisor for TmuxSupervisor {
116 fn up(&self, spec: &AgentSpec) -> Result<()> {
117 if matches!(self.state(spec)?, AgentState::Running) {
118 return Ok(());
119 }
120 let cmd = format!(
121 "env $(cat {env}) {wrapper} {project}:{agent}",
122 env = shlex::try_quote(&spec.env_file.display().to_string())?,
123 wrapper = shlex::try_quote(&spec.wrapper.display().to_string())?,
124 project = spec.project,
125 agent = spec.agent,
126 );
127 let status = Command::new("tmux")
128 .args([
129 "new-session",
130 "-d",
131 "-s",
132 &spec.tmux_session,
133 "-c",
134 &spec.cwd.display().to_string(),
135 "sh",
136 "-c",
137 &cmd,
138 ])
139 .status()
140 .context("spawn tmux new-session")?;
141 anyhow::ensure!(status.success(), "tmux new-session exited {status}");
142 Ok(())
143 }
144
145 fn down(&self, spec: &AgentSpec) -> Result<()> {
146 let _ = Command::new("tmux")
147 .args(["kill-session", "-t", &spec.tmux_session])
148 .status();
149 Ok(())
150 }
151
152 fn state(&self, spec: &AgentSpec) -> Result<AgentState> {
153 let out = Command::new("tmux")
154 .args(["has-session", "-t", &spec.tmux_session])
155 .output();
156 Ok(match out {
157 Ok(o) if o.status.success() => AgentState::Running,
158 Ok(_) => AgentState::Stopped,
159 Err(_) => AgentState::Unknown,
160 })
161 }
162
163 fn drain(&self, spec: &AgentSpec, timeout: Duration) -> Result<DrainOutcome> {
170 orchestrate_drain(self, spec, timeout, || {
171 let _ = Command::new("tmux")
172 .args(["send-keys", "-t", &spec.tmux_session, "C-c"])
173 .status();
174 })
175 }
176}
177
178fn poll_for_stopped<F: FnMut() -> AgentState>(
183 timeout: Duration,
184 interval: Duration,
185 mut observe_state: F,
186) -> DrainOutcome {
187 let deadline = Instant::now() + timeout;
188 loop {
189 if observe_state() == AgentState::Stopped {
190 return DrainOutcome::Graceful;
191 }
192 if Instant::now() >= deadline {
193 return DrainOutcome::TimedOutKilled;
194 }
195 thread::sleep(interval);
196 }
197}
198
#[cfg(test)]
mod drain_tests {
    use super::*;
    use std::cell::{Cell, RefCell};

    /// `Stopped` seen on the second observation, well inside the timeout, is a graceful exit.
    #[test]
    fn poll_returns_graceful_when_stopped_observed_in_time() {
        let observations = Cell::new(0u32);
        let observe = || {
            observations.set(observations.get() + 1);
            if observations.get() < 2 {
                AgentState::Running
            } else {
                AgentState::Stopped
            }
        };
        let outcome =
            poll_for_stopped(Duration::from_millis(50), Duration::from_millis(1), observe);
        assert_eq!(outcome, DrainOutcome::Graceful);
    }

    /// An agent that keeps reporting `Running` exhausts the timeout.
    #[test]
    fn poll_falls_through_to_kill_when_agent_never_stops() {
        let timeout = Duration::from_millis(8);
        let interval = Duration::from_millis(2);
        let outcome = poll_for_stopped(timeout, interval, || AgentState::Running);
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
    }

    /// With no time budget the state is still observed exactly once.
    #[test]
    fn poll_zero_timeout_only_checks_once_then_kills() {
        let calls = Cell::new(0u32);
        let outcome = poll_for_stopped(Duration::ZERO, Duration::from_millis(1), || {
            calls.set(calls.get() + 1);
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(calls.get(), 1, "single state observation before timeout");
    }

    /// Scripted [`Supervisor`] that logs every call and can flip to `Stopped` on cue.
    #[derive(Default)]
    struct MockSupervisor {
        /// Names of the trait methods invoked so far, in call order.
        calls: RefCell<Vec<&'static str>>,
        /// `state()` reports `Stopped` from this many calls onward; `0` means never.
        stop_after: u32,
        /// How many times `state()` has been called.
        state_calls: Cell<u32>,
        /// Value handed back by `drain_poll_interval()`.
        poll_interval: Duration,
    }

    impl MockSupervisor {
        /// Appends `op` to the call log.
        fn record(&self, op: &'static str) {
            self.calls.borrow_mut().push(op);
        }
    }

    impl Supervisor for MockSupervisor {
        fn up(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("up");
            Ok(())
        }

        fn down(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("down");
            Ok(())
        }

        fn state(&self, _spec: &AgentSpec) -> Result<AgentState> {
            self.record("state");
            let seen = self.state_calls.get() + 1;
            self.state_calls.set(seen);
            let stopped = self.stop_after > 0 && seen >= self.stop_after;
            Ok(if stopped {
                AgentState::Stopped
            } else {
                AgentState::Running
            })
        }

        fn drain_poll_interval(&self) -> Duration {
            self.poll_interval
        }
    }

    /// Throwaway spec; the mock never looks at its contents.
    fn fake_spec() -> AgentSpec {
        AgentSpec {
            project: String::from("p"),
            agent: String::from("a"),
            tmux_session: String::from("p-a"),
            wrapper: PathBuf::from("/dev/null"),
            cwd: PathBuf::from("/tmp"),
            env_file: PathBuf::from("/dev/null"),
        }
    }

    /// Zero timeout: signal, observe once, then kill.
    #[test]
    fn drain_with_zero_timeout_returns_timed_out_killed_and_calls_down() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            ..Default::default()
        };
        let spec = fake_spec();
        let signaled = Cell::new(false);

        let result = orchestrate_drain(&mock, &spec, Duration::ZERO, || signaled.set(true));
        let outcome = result.unwrap();

        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert!(signaled.get(), "signal_fn must run before the poll");
        assert_eq!(
            mock.calls.borrow().as_slice(),
            &["state", "down"],
            "zero-timeout: one state observation then kill"
        );
    }

    /// An agent that stops by itself must never be killed.
    #[test]
    fn drain_with_graceful_stop_does_not_call_down() {
        let mock = MockSupervisor {
            stop_after: 2,
            poll_interval: Duration::from_millis(1),
            ..Default::default()
        };
        let spec = fake_spec();

        let outcome = orchestrate_drain(&mock, &spec, Duration::from_millis(100), || {}).unwrap();

        assert_eq!(outcome, DrainOutcome::Graceful);
        let log = mock.calls.borrow();
        assert!(
            log.iter().all(|op| *op != "down"),
            "graceful drain must not call down(); calls: {:?}",
            log
        );
    }

    /// A supervisor that does not override the poll interval gets 250 ms.
    #[test]
    fn drain_poll_interval_default_is_250ms() {
        struct Default250;

        impl Supervisor for Default250 {
            fn up(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn down(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn state(&self, _: &AgentSpec) -> Result<AgentState> {
                Ok(AgentState::Stopped)
            }
        }

        let interval = Default250.drain_poll_interval();
        assert_eq!(interval, Duration::from_millis(250));
    }

    /// The orchestrator polls at the supervisor's own cadence, not the 250 ms default.
    ///
    /// NOTE(review): both assertions depend on wall-clock timing and may be sensitive to a
    /// heavily loaded machine.
    #[test]
    fn drain_poll_interval_override_is_used_by_orchestrator() {
        let mock = MockSupervisor {
            stop_after: 0,
            poll_interval: Duration::from_millis(2),
            ..Default::default()
        };
        let spec = fake_spec();

        let started = Instant::now();
        let _ = orchestrate_drain(&mock, &spec, Duration::from_millis(8), || {});
        let elapsed = started.elapsed();

        let states = mock
            .calls
            .borrow()
            .iter()
            .filter(|&&op| op == "state")
            .count();
        assert!(
            states >= 2,
            "expected several state observations at 2ms cadence, got {states}"
        );
        assert!(
            elapsed < Duration::from_millis(60),
            "drain with 2ms interval finished too slowly ({elapsed:?})"
        );
    }
}
397
398mod shlex {
399 pub fn try_quote(s: &str) -> anyhow::Result<String> {
401 anyhow::ensure!(!s.contains('\0'), "null byte in shell arg");
402 let escaped = s.replace('\'', r"'\''");
403 Ok(format!("'{escaped}'"))
404 }
405
406 #[cfg(test)]
407 mod tests {
408 use super::*;
409
410 #[test]
411 fn quotes_plain_path() {
412 assert_eq!(try_quote("/a/b.sh").unwrap(), "'/a/b.sh'");
413 }
414
415 #[test]
416 fn escapes_embedded_single_quote() {
417 assert_eq!(try_quote("x'y").unwrap(), r"'x'\''y'");
418 }
419 }
420}