Skip to main content

synaps_cli/runtime/
subagent.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use tokio::sync::{mpsc, oneshot};
4use serde_json::Value;
5
6// ── SubagentResult ───────────────────────────────────────────────────────────────
7
/// Final payload of a subagent run.
///
/// This is the value type carried by the oneshot result channel whose
/// receiving end is stored on `SubagentHandle`.
///
/// NOTE(review): the per-field descriptions below are inferred from the field
/// names only — confirm against the code that constructs this value.
#[derive(Debug)]
pub struct SubagentResult {
    /// Final response text.
    pub text: String,
    /// Model identifier associated with the run.
    pub model: String,
    /// Input token count (presumably totalled over the run).
    pub input_tokens: u64,
    /// Output token count (presumably totalled over the run).
    pub output_tokens: u64,
    /// Presumably tokens served from a prompt cache — TODO confirm.
    pub cache_read: u64,
    /// Presumably tokens written to a prompt cache — TODO confirm.
    pub cache_creation: u64,
    /// Presumably the number of tool invocations — TODO confirm.
    pub tool_count: u32,
}
18
19// ── SubagentStatus ───────────────────────────────────────────────────────────────
20
/// Lifecycle state of a subagent as seen through its handle.
///
/// `Running` is the only state treated as "not finished"; every other variant
/// is terminal from the handle's point of view.
///
/// Equality is total (all payloads are `Eq`), so `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SubagentStatus {
    /// The subagent has not reported a terminal state yet. This is the
    /// initial status of a fresh `SubagentState`.
    Running,
    /// The subagent finished normally.
    Completed,
    /// The subagent was stopped because it ran out of time.
    /// NOTE(review): the code that sets this is not in this file — confirm.
    TimedOut,
    /// The subagent failed; the payload is presumably a human-readable
    /// error description — TODO confirm against the producer.
    Failed(String),
}
28
29// ── SubagentState ────────────────────────────────────────────────────────────────
30
31/// All mutable state shared between the subagent thread and its handle.
32/// Collapsed behind a single RwLock so a status poll takes exactly one lock.
33#[derive(Debug)]
34pub struct SubagentState {
35    pub status: SubagentStatus,
36    pub partial_text: String,
37    pub tool_log: Vec<String>,
38    pub conversation_state: Vec<Value>,
39}
40
41impl SubagentState {
42    pub fn new() -> Self {
43        Self {
44            status: SubagentStatus::Running,
45            partial_text: String::new(),
46            tool_log: Vec::new(),
47            conversation_state: Vec::new(),
48        }
49    }
50}
51
52impl Default for SubagentState {
53    fn default() -> Self { Self::new() }
54}
55
56// ── SubagentHandle ───────────────────────────────────────────────────────────────
57
58pub struct SubagentHandle {
59    pub id: String,
60    pub agent_name: String,
61    pub task_preview: String,
62    pub model: String,
63    pub system_prompt: String,
64    pub started_at: std::time::Instant,
65    pub timeout_secs: u64,
66
67    // Shared state updated by the subagent thread — one lock for everything.
68    state: Arc<RwLock<SubagentState>>,
69
70    // Channels
71    steer_tx: Option<mpsc::UnboundedSender<String>>,
72    shutdown_tx: Option<oneshot::Sender<()>>,
73    /// OS thread running the subagent. Stored for graceful shutdown (join).
74    // OS thread handle for graceful shutdown
75    thread_handle: Option<std::thread::JoinHandle<()>>,
76
77    // Final result
78    result_rx: Option<oneshot::Receiver<SubagentResult>>,
79}
80
81impl std::fmt::Debug for SubagentHandle {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("SubagentHandle")
84            .field("id", &self.id)
85            .field("agent_name", &self.agent_name)
86            .field("model", &self.model)
87            .finish_non_exhaustive()
88    }
89}
90
91impl SubagentHandle {
92    /// Construct a new handle. The state Arc is shared with the spawned subagent thread.
93    #[allow(clippy::too_many_arguments)]
94    pub fn new(
95        id: String,
96        agent_name: String,
97        task_preview: String,
98        model: String,
99        system_prompt: String,
100        timeout_secs: u64,
101        state: Arc<RwLock<SubagentState>>,
102        steer_tx: Option<mpsc::UnboundedSender<String>>,
103        shutdown_tx: Option<oneshot::Sender<()>>,
104        result_rx: Option<oneshot::Receiver<SubagentResult>>,
105    ) -> Self {
106        Self {
107            id,
108            agent_name,
109            task_preview,
110            model,
111            system_prompt,
112            started_at: std::time::Instant::now(),
113            timeout_secs,
114            state,
115            steer_tx,
116            shutdown_tx,
117            thread_handle: None,
118            result_rx,
119        }
120    }
121
122    /// Current status snapshot.
123    pub fn status(&self) -> SubagentStatus {
124        self.state.read().unwrap().status.clone()
125    }
126
127    /// Partial output accumulated so far.
128    pub fn partial_output(&self) -> String {
129        self.state.read().unwrap().partial_text.clone()
130    }
131
132    /// Snapshot of the tool log.
133    pub fn tool_log(&self) -> Vec<String> {
134        self.state.read().unwrap().tool_log.clone()
135    }
136
137    /// Snapshot of conversation state (for resume).
138    pub fn conversation_state(&self) -> Vec<Value> {
139        self.state.read().unwrap().conversation_state.clone()
140    }
141
142    /// Seconds since this handle was created.
143    pub fn elapsed_secs(&self) -> f64 {
144        self.started_at.elapsed().as_secs_f64()
145    }
146
147    /// Send a steering message into the running subagent.
148    pub fn steer(&self, message: &str) -> Result<(), String> {
149        match &self.steer_tx {
150            Some(tx) => tx
151                .send(message.to_string())
152                .map_err(|e| format!("steer channel closed: {e}")),
153            None => Err("no steer channel on this handle".to_string()),
154        }
155    }
156
157    /// Signal the subagent to shut down.
158    /// Store the OS thread handle for graceful shutdown.
159    pub fn set_thread_handle(&mut self, handle: std::thread::JoinHandle<()>) {
160        self.thread_handle = Some(handle);
161    }
162
163    pub fn cancel(&mut self) {
164        if let Some(tx) = self.shutdown_tx.take() {
165            let _ = tx.send(());
166        }
167    }
168
169    /// True if the subagent is no longer running.
170    pub fn is_finished(&self) -> bool {
171        !matches!(self.status(), SubagentStatus::Running)
172    }
173
174    /// Consume the handle and wait for the final result.
175    pub async fn collect(mut self) -> Result<SubagentResult, String> {
176        match self.result_rx.take() {
177            Some(rx) => rx.await.map_err(|_| "subagent result channel dropped".to_string()),
178            None => Err("no result receiver — already collected or never set".to_string()),
179        }
180    }
181}
182
183// ── SubagentRegistry ─────────────────────────────────────────────────────────────
184
185#[derive(Debug)]
186pub struct SubagentRegistry {
187    handles: HashMap<String, SubagentHandle>,
188}
189
190impl SubagentRegistry {
191    pub fn new() -> Self {
192        Self {
193            handles: HashMap::new(),
194        }
195    }
196
197    /// Register a handle and return its id.
198    pub fn register(&mut self, handle: SubagentHandle) -> String {
199        let id = handle.id.clone();
200        self.handles.insert(id.clone(), handle);
201        id
202    }
203
204    pub fn get(&self, id: &str) -> Option<&SubagentHandle> {
205        self.handles.get(id)
206    }
207
208    pub fn get_mut(&mut self, id: &str) -> Option<&mut SubagentHandle> {
209        self.handles.get_mut(id)
210    }
211
212    pub fn remove(&mut self, id: &str) -> Option<SubagentHandle> {
213        self.handles.remove(id)
214    }
215
216    /// Returns (id, agent_name, status) for every tracked handle.
217    pub fn list_active(&self) -> Vec<(String, String, SubagentStatus)> {
218        self.handles
219            .values()
220            .map(|h| (h.id.clone(), h.agent_name.clone(), h.status()))
221            .collect()
222    }
223
224    /// Drop handles that are no longer running.
225    /// Iterate over all handles mutably (for bulk operations like cancel-all).
226    pub fn iter_mut_handles(&mut self) -> impl Iterator<Item = &mut SubagentHandle> {
227        self.handles.values_mut()
228    }
229
230    pub fn cleanup_finished(&mut self) {
231        let finished_ids: Vec<String> = self.handles.iter()
232            .filter(|(_, h)| h.is_finished())
233            .map(|(id, _)| id.clone())
234            .collect();
235        for id in finished_ids {
236            if let Some(mut handle) = self.handles.remove(&id) {
237                // Join the thread to avoid zombies/resource leaks
238                if let Some(th) = handle.thread_handle.take() {
239                    let _ = th.join();
240                }
241            }
242        }
243    }
244}
245
246impl Default for SubagentRegistry {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
252impl SubagentStatus {
253    pub fn as_str(&self) -> &str {
254        match self {
255            SubagentStatus::Running => "running",
256            SubagentStatus::Completed => "completed",
257            SubagentStatus::TimedOut => "timed_out",
258            SubagentStatus::Failed(_) => "failed",
259        }
260    }
261}
262
263
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::{mpsc, oneshot};

    /// A handle bundled with the receiving ends of its steer and shutdown
    /// channels, so tests can keep the channels open and observe what the
    /// handle sends.
    struct TestHandle {
        handle: SubagentHandle,
        steer_rx: mpsc::UnboundedReceiver<String>,
        shutdown_rx: oneshot::Receiver<()>,
    }

    /// Builds a handle with all channels wired up and a fresh `Running` state.
    fn make_test_handle(id: &str) -> TestHandle {
        let state = Arc::new(RwLock::new(SubagentState::new()));
        let (steer_tx, steer_rx) = mpsc::unbounded_channel();
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (_result_tx, result_rx) = oneshot::channel();
        TestHandle {
            handle: SubagentHandle::new(
                id.to_string(),
                "test-agent".to_string(),
                "test task".to_string(),
                "claude-sonnet-4-6".to_string(),
                "You are a test agent.".to_string(),
                300,
                state,
                Some(steer_tx),
                Some(shutdown_tx),
                Some(result_rx),
            ),
            steer_rx,
            shutdown_rx,
        }
    }

    /// Builds just the handle; the receiving ends of its channels are dropped.
    fn make_handle(id: &str) -> SubagentHandle {
        make_test_handle(id).handle
    }

    #[test]
    fn handle_initial_status_is_running() {
        let h = make_handle("sa_1");
        assert_eq!(h.status(), SubagentStatus::Running);
        assert!(!h.is_finished());
    }

    #[test]
    fn handle_partial_output_empty_initially() {
        let h = make_handle("sa_1");
        assert_eq!(h.partial_output(), "");
        assert!(h.tool_log().is_empty());
        assert!(h.conversation_state().is_empty());
    }

    #[test]
    fn handle_status_reflects_state_change() {
        let h = make_handle("sa_1");
        {
            let mut s = h.state.write().unwrap();
            s.status = SubagentStatus::Completed;
            s.partial_text = "done!".to_string();
        }
        assert_eq!(h.status(), SubagentStatus::Completed);
        assert!(h.is_finished());
        assert_eq!(h.partial_output(), "done!");
    }

    #[test]
    fn handle_failed_and_timed_out_count_as_finished() {
        let h = make_handle("sa_1");
        h.state.write().unwrap().status = SubagentStatus::TimedOut;
        assert!(h.is_finished());
        h.state.write().unwrap().status = SubagentStatus::Failed("boom".to_string());
        assert!(h.is_finished());
        assert_eq!(h.status(), SubagentStatus::Failed("boom".to_string()));
    }

    #[test]
    fn handle_steer_sends_message() {
        let mut th = make_test_handle("sa_1");
        assert!(th.handle.steer("redirect").is_ok());
        // The message must actually arrive on the receiving side.
        assert_eq!(th.steer_rx.try_recv().ok().as_deref(), Some("redirect"));
    }

    #[test]
    fn handle_steer_fails_without_channel() {
        let state = Arc::new(RwLock::new(SubagentState::new()));
        let (_, result_rx) = oneshot::channel();
        let h = SubagentHandle::new(
            "sa_1".into(), "test".into(), "task".into(),
            "model".into(), "prompt".into(), 300, state, None, None, Some(result_rx),
        );
        assert!(h.steer("msg").is_err());
    }

    #[test]
    fn handle_steer_fails_when_receiver_dropped() {
        // `make_handle` drops the steer receiver, which closes the channel.
        let h = make_handle("sa_1");
        assert!(h.steer("msg").is_err());
    }

    #[test]
    fn handle_cancel_consumes_shutdown() {
        let mut th = make_test_handle("sa_1");
        th.handle.cancel(); // first call sends
        assert!(th.shutdown_rx.try_recv().is_ok());
        th.handle.cancel(); // second call is no-op (already taken)
    }

    #[test]
    fn handle_cancel_with_dropped_receiver_does_not_panic() {
        let mut h = make_handle("sa_1");
        h.cancel();
        h.cancel();
    }

    #[test]
    fn handle_elapsed_increases() {
        let h = make_handle("sa_1");
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(h.elapsed_secs() > 0.0);
    }

    #[test]
    fn registry_register_and_get() {
        let mut reg = SubagentRegistry::new();
        let h = make_handle("sa_1");
        assert_eq!(reg.register(h), "sa_1");
        assert!(reg.get("sa_1").is_some());
        assert!(reg.get_mut("sa_1").is_some());
        assert!(reg.get("sa_99").is_none());
    }

    #[test]
    fn registry_remove() {
        let mut reg = SubagentRegistry::new();
        reg.register(make_handle("sa_1"));
        assert!(reg.remove("sa_1").is_some());
        assert!(reg.get("sa_1").is_none());
        assert!(reg.remove("sa_1").is_none());
    }

    #[test]
    fn registry_list_active() {
        let mut reg = SubagentRegistry::new();
        reg.register(make_handle("sa_1"));
        reg.register(make_handle("sa_2"));
        let mut active = reg.list_active();
        // HashMap iteration order is unspecified; sort by id before comparing.
        active.sort_by(|a, b| a.0.cmp(&b.0));
        assert_eq!(
            active,
            vec![
                ("sa_1".to_string(), "test-agent".to_string(), SubagentStatus::Running),
                ("sa_2".to_string(), "test-agent".to_string(), SubagentStatus::Running),
            ]
        );
    }

    #[test]
    fn registry_iter_mut_handles_allows_cancel_all() {
        let mut reg = SubagentRegistry::new();
        let TestHandle { handle, mut shutdown_rx, .. } = make_test_handle("sa_1");
        reg.register(handle);
        for h in reg.iter_mut_handles() {
            h.cancel();
        }
        assert!(shutdown_rx.try_recv().is_ok());
    }

    #[test]
    fn registry_cleanup_finished() {
        let mut reg = SubagentRegistry::new();
        let h = make_handle("sa_1");
        {
            let mut s = h.state.write().unwrap();
            s.status = SubagentStatus::Completed;
        }
        reg.register(h);
        reg.register(make_handle("sa_2")); // still running
        reg.cleanup_finished();
        assert!(reg.get("sa_1").is_none()); // completed, cleaned up
        assert!(reg.get("sa_2").is_some()); // still running, kept
    }

    #[test]
    fn registry_cleanup_joins_attached_thread() {
        let mut reg = SubagentRegistry::new();
        let mut h = make_handle("sa_1");
        h.set_thread_handle(std::thread::spawn(|| {}));
        h.state.write().unwrap().status = SubagentStatus::Completed;
        reg.register(h);
        reg.cleanup_finished();
        assert!(reg.get("sa_1").is_none());
    }

    #[test]
    fn subagent_state_new_defaults() {
        let s = SubagentState::new();
        assert_eq!(s.status, SubagentStatus::Running);
        assert!(s.partial_text.is_empty());
        assert!(s.tool_log.is_empty());
        assert!(s.conversation_state.is_empty());
    }

    #[test]
    fn subagent_status_as_str() {
        assert_eq!(SubagentStatus::Running.as_str(), "running");
        assert_eq!(SubagentStatus::Completed.as_str(), "completed");
        assert_eq!(SubagentStatus::TimedOut.as_str(), "timed_out");
        assert_eq!(SubagentStatus::Failed("oops".into()).as_str(), "failed");
    }
}