1use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::thread;
10use std::time::{Duration, Instant};
11
12use anyhow::{Context, Result};
13
14use crate::compose::AgentHandle;
15
/// Everything a [`Supervisor`] needs to start, stop, or inspect one agent.
#[derive(Clone, Debug)]
pub struct AgentSpec {
    /// Project the agent belongs to.
    pub project: String,
    /// Agent name within the project.
    pub agent: String,
    /// Name of the tmux session that hosts the agent.
    pub tmux_session: String,
    /// Path of the wrapper script that is launched for the agent.
    pub wrapper: PathBuf,
    /// Working directory the agent is started in.
    pub cwd: PathBuf,
    /// File whose contents are expanded into `env` assignments at launch.
    pub env_file: PathBuf,
}
25
26impl AgentSpec {
27 pub fn from_handle(h: AgentHandle<'_>, root: &Path, tmux_prefix: &str) -> Self {
28 Self {
29 project: h.project.into(),
30 agent: h.agent.into(),
31 tmux_session: format!("{tmux_prefix}{}-{}", h.project, h.agent),
32 wrapper: root.join("bin/agent-wrapper.sh"),
33 cwd: root.to_path_buf(),
34 env_file: crate::render::env_path(root, h.project, h.agent),
35 }
36 }
37}
38
/// Liveness of an agent as reported by a supervisor's `state` query.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentState {
    /// The agent was observed to be alive.
    Running,
    /// The agent was observed to be absent.
    Stopped,
    /// The state could not be determined (for example, the query itself failed).
    Unknown,
}
46
/// How a drain request ended.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DrainOutcome {
    /// The agent was observed stopped before the drain timeout elapsed.
    Graceful,
    /// The agent was not observed stopped in time and was taken down forcibly.
    TimedOutKilled,
}
57
58pub trait Supervisor {
59 fn up(&self, spec: &AgentSpec) -> Result<()>;
60 fn down(&self, spec: &AgentSpec) -> Result<()>;
61 fn state(&self, spec: &AgentSpec) -> Result<AgentState>;
62
63 fn drain(&self, spec: &AgentSpec, _timeout: Duration) -> Result<DrainOutcome> {
68 self.down(spec)?;
69 Ok(DrainOutcome::TimedOutKilled)
70 }
71
72 fn drain_poll_interval(&self) -> Duration {
79 Duration::from_millis(250)
80 }
81}
82
83pub fn orchestrate_drain<S, F>(
93 supervisor: &S,
94 spec: &AgentSpec,
95 timeout: Duration,
96 signal_fn: F,
97) -> Result<DrainOutcome>
98where
99 S: Supervisor + ?Sized,
100 F: FnOnce(),
101{
102 signal_fn();
103 let outcome = poll_for_stopped(timeout, supervisor.drain_poll_interval(), || {
104 supervisor.state(spec).unwrap_or(AgentState::Unknown)
105 });
106 if outcome == DrainOutcome::TimedOutKilled {
107 supervisor.down(spec)?;
108 }
109 Ok(outcome)
110}
111
112pub struct TmuxSupervisor;
114
115impl Supervisor for TmuxSupervisor {
116 fn up(&self, spec: &AgentSpec) -> Result<()> {
117 if matches!(self.state(spec)?, AgentState::Running) {
118 return Ok(());
119 }
120 let cmd = format!(
121 "env $(cat {env}) {wrapper} {project}:{agent}",
122 env = shlex::try_quote(&spec.env_file.display().to_string())?,
123 wrapper = shlex::try_quote(&spec.wrapper.display().to_string())?,
124 project = spec.project,
125 agent = spec.agent,
126 );
127 let status = Command::new("tmux")
134 .args([
135 "new-session",
136 "-d",
137 "-x",
138 "200",
139 "-y",
140 "50",
141 "-s",
142 &spec.tmux_session,
143 "-c",
144 &spec.cwd.display().to_string(),
145 "sh",
146 "-c",
147 &cmd,
148 ])
149 .status()
150 .context("spawn tmux new-session")?;
151 anyhow::ensure!(status.success(), "tmux new-session exited {status}");
152 let cwd_str = spec.cwd.to_string_lossy();
157 for (key, value) in [
158 ("@teamctl", "1"),
159 ("@teamctl-project", spec.project.as_str()),
160 ("@teamctl-agent", spec.agent.as_str()),
161 ("@teamctl-root", cwd_str.as_ref()),
162 ] {
163 let _ = Command::new("tmux")
164 .args(["set-option", "-q", "-t", &spec.tmux_session, key, value])
165 .status();
166 }
167 Ok(())
168 }
169
170 fn down(&self, spec: &AgentSpec) -> Result<()> {
171 let _ = Command::new("tmux")
172 .args(["kill-session", "-t", &spec.tmux_session])
173 .status();
174 Ok(())
175 }
176
177 fn state(&self, spec: &AgentSpec) -> Result<AgentState> {
178 let out = Command::new("tmux")
179 .args(["has-session", "-t", &spec.tmux_session])
180 .output();
181 Ok(match out {
182 Ok(o) if o.status.success() => AgentState::Running,
183 Ok(_) => AgentState::Stopped,
184 Err(_) => AgentState::Unknown,
185 })
186 }
187
188 fn drain(&self, spec: &AgentSpec, timeout: Duration) -> Result<DrainOutcome> {
195 orchestrate_drain(self, spec, timeout, || {
196 let _ = Command::new("tmux")
197 .args(["send-keys", "-t", &spec.tmux_session, "C-c"])
198 .status();
199 })
200 }
201}
202
203fn poll_for_stopped<F: FnMut() -> AgentState>(
208 timeout: Duration,
209 interval: Duration,
210 mut observe_state: F,
211) -> DrainOutcome {
212 let deadline = Instant::now() + timeout;
213 loop {
214 if observe_state() == AgentState::Stopped {
215 return DrainOutcome::Graceful;
216 }
217 if Instant::now() >= deadline {
218 return DrainOutcome::TimedOutKilled;
219 }
220 thread::sleep(interval);
221 }
222}
223
/// Tests for the drain machinery: the polling loop on its own, and
/// `orchestrate_drain` driven through a recording mock supervisor.
#[cfg(test)]
mod drain_tests {
    use super::*;
    use std::cell::RefCell;

    #[test]
    fn poll_returns_graceful_when_stopped_observed_in_time() {
        let calls = RefCell::new(0u32);
        let outcome = poll_for_stopped(Duration::from_millis(50), Duration::from_millis(1), || {
            let mut n = calls.borrow_mut();
            *n += 1;
            if *n >= 2 {
                AgentState::Stopped
            } else {
                AgentState::Running
            }
        });
        assert_eq!(outcome, DrainOutcome::Graceful);
    }

    #[test]
    fn poll_falls_through_to_kill_when_agent_never_stops() {
        let outcome = poll_for_stopped(Duration::from_millis(8), Duration::from_millis(2), || {
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
    }

    #[test]
    fn poll_zero_timeout_only_checks_once_then_kills() {
        let mut calls: u32 = 0;
        let outcome = poll_for_stopped(Duration::from_millis(0), Duration::from_millis(1), || {
            calls += 1;
            AgentState::Running
        });
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(calls, 1, "single state observation before timeout");
    }

    /// Supervisor double that records every call, in order.
    #[derive(Default)]
    struct MockSupervisor {
        /// Names of the operations invoked so far, oldest first.
        calls: RefCell<Vec<&'static str>>,
        /// `state` reports `Stopped` from this many calls on; 0 means never.
        stop_after: u32,
        /// Number of `state` calls seen so far.
        state_calls: RefCell<u32>,
        /// Value returned from `drain_poll_interval`.
        poll_interval: Duration,
    }

    impl MockSupervisor {
        fn record(&self, op: &'static str) {
            self.calls.borrow_mut().push(op);
        }
    }

    impl Supervisor for MockSupervisor {
        fn up(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("up");
            Ok(())
        }
        fn down(&self, _spec: &AgentSpec) -> Result<()> {
            self.record("down");
            Ok(())
        }
        fn state(&self, _spec: &AgentSpec) -> Result<AgentState> {
            self.record("state");
            let mut n = self.state_calls.borrow_mut();
            *n += 1;
            if self.stop_after > 0 && *n >= self.stop_after {
                Ok(AgentState::Stopped)
            } else {
                Ok(AgentState::Running)
            }
        }
        fn drain_poll_interval(&self) -> Duration {
            self.poll_interval
        }
    }

    fn fake_spec() -> AgentSpec {
        AgentSpec {
            project: "p".into(),
            agent: "a".into(),
            tmux_session: "p-a".into(),
            wrapper: PathBuf::from("/dev/null"),
            cwd: PathBuf::from("/tmp"),
            env_file: PathBuf::from("/dev/null"),
        }
    }

    #[test]
    fn drain_with_zero_timeout_returns_timed_out_killed_and_calls_down() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            ..Default::default()
        };
        let spec = fake_spec();

        // The signal is recorded in the same log as the supervisor calls so
        // that its position relative to the first poll is actually checked.
        let outcome = orchestrate_drain(&mock, &spec, Duration::ZERO, || mock.record("signal"))
            .expect("mock supervisor never fails");

        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(
            mock.calls.borrow().as_slice(),
            &["signal", "state", "down"],
            "zero-timeout: signal first, one state observation, then kill"
        );
    }

    #[test]
    fn drain_with_graceful_stop_does_not_call_down() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(1),
            stop_after: 2,
            ..Default::default()
        };
        let spec = fake_spec();

        let outcome =
            orchestrate_drain(&mock, &spec, Duration::from_millis(100), || mock.record("signal"))
                .expect("mock supervisor never fails");

        assert_eq!(outcome, DrainOutcome::Graceful);
        assert_eq!(
            mock.calls.borrow().first(),
            Some(&"signal"),
            "signal_fn must run before the poll"
        );
        assert!(
            !mock.calls.borrow().contains(&"down"),
            "graceful drain must not call down(); calls: {:?}",
            mock.calls.borrow()
        );
    }

    #[test]
    fn drain_poll_interval_default_is_250ms() {
        // Implements only the required methods so the trait default is used.
        struct Default250;
        impl Supervisor for Default250 {
            fn up(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn down(&self, _: &AgentSpec) -> Result<()> {
                Ok(())
            }
            fn state(&self, _: &AgentSpec) -> Result<AgentState> {
                Ok(AgentState::Stopped)
            }
        }
        assert_eq!(Default250.drain_poll_interval(), Duration::from_millis(250));
    }

    #[test]
    fn drain_poll_interval_override_is_used_by_orchestrator() {
        let mock = MockSupervisor {
            poll_interval: Duration::from_millis(2),
            stop_after: 0,
            ..Default::default()
        };
        let spec = fake_spec();

        let start = Instant::now();
        let outcome = orchestrate_drain(&mock, &spec, Duration::from_millis(8), || {})
            .expect("mock supervisor never fails");
        let elapsed = start.elapsed();

        // The mock never stops, so the drain must end in a kill.
        assert_eq!(outcome, DrainOutcome::TimedOutKilled);
        assert_eq!(
            mock.calls.borrow().last(),
            Some(&"down"),
            "timed-out drain must finish with down()"
        );

        let states = mock
            .calls
            .borrow()
            .iter()
            .filter(|c| **c == "state")
            .count();
        assert!(
            states >= 2,
            "expected several state observations at 2ms cadence, got {states}"
        );
        // With the 250ms default the 8ms drain would take far longer than this.
        assert!(
            elapsed < Duration::from_millis(60),
            "drain with 2ms interval finished too slowly ({elapsed:?})"
        );
    }
}
422
423mod shlex {
424 pub fn try_quote(s: &str) -> anyhow::Result<String> {
426 anyhow::ensure!(!s.contains('\0'), "null byte in shell arg");
427 let escaped = s.replace('\'', r"'\''");
428 Ok(format!("'{escaped}'"))
429 }
430
431 #[cfg(test)]
432 mod tests {
433 use super::*;
434
435 #[test]
436 fn quotes_plain_path() {
437 assert_eq!(try_quote("/a/b.sh").unwrap(), "'/a/b.sh'");
438 }
439
440 #[test]
441 fn escapes_embedded_single_quote() {
442 assert_eq!(try_quote("x'y").unwrap(), r"'x'\''y'");
443 }
444 }
445}