Skip to main content

sgr_agent/
swarm_tools.rs

1//! Swarm tools — tools for the parent agent to manage sub-agents.
2//!
3//! These tools are registered in the parent's ToolRegistry, allowing the agent
4//! to spawn, wait, query status, and cancel sub-agents via normal tool calls.
5
6use crate::agent_tool::{Tool, ToolError, ToolOutput, parse_args};
7use crate::context::AgentContext;
8use crate::swarm::{AgentId, AgentRole, SwarmManager};
9use serde::Deserialize;
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
14/// Shared swarm manager reference for tools.
15pub type SharedSwarm = Arc<Mutex<SwarmManager>>;
16
17/// Create a shared swarm manager.
18pub fn shared_swarm(manager: SwarmManager) -> SharedSwarm {
19    Arc::new(Mutex::new(manager))
20}
21
22// --- SpawnAgentTool ---
23
24#[derive(Deserialize)]
25struct SpawnArgs {
26    /// Role: "explorer", "worker", "reviewer", or custom name
27    role: String,
28    /// Task description for the sub-agent
29    task: String,
30    /// Optional system prompt override
31    system_prompt: Option<String>,
32    /// Optional max steps (default: role-dependent)
33    max_steps: Option<usize>,
34    /// Optional working directory
35    cwd: Option<String>,
36}
37
38/// Tool for spawning sub-agents.
39pub struct SpawnAgentTool {
40    swarm: SharedSwarm,
41    /// Factory function to create agent + tools for a given role.
42    /// The parent must provide this to wire up LlmClient, tools, etc.
43    factory: Arc<dyn AgentFactory>,
44}
45
46/// Factory for creating agent + tool registry based on role.
47///
48/// Implementors provide the actual LlmClient and tools for each role.
49#[async_trait::async_trait]
50pub trait AgentFactory: Send + Sync {
51    /// Create an agent and its tool registry for the given role.
52    async fn create(
53        &self,
54        role: &AgentRole,
55        system_prompt: Option<&str>,
56    ) -> Result<(Box<dyn crate::agent::Agent>, crate::registry::ToolRegistry), String>;
57}
58
59impl SpawnAgentTool {
60    pub fn new(swarm: SharedSwarm, factory: Arc<dyn AgentFactory>) -> Self {
61        Self { swarm, factory }
62    }
63}
64
65#[async_trait::async_trait]
66impl Tool for SpawnAgentTool {
67    fn name(&self) -> &str {
68        "spawn_agent"
69    }
70
71    fn description(&self) -> &str {
72        "Spawn a sub-agent with a specific role and task. Roles: explorer (fast, read-only), worker (smart, read-write), reviewer (read-only, thorough)."
73    }
74
75    fn parameters_schema(&self) -> Value {
76        serde_json::json!({
77            "type": "object",
78            "required": ["role", "task"],
79            "properties": {
80                "role": {
81                    "type": "string",
82                    "description": "Agent role: explorer, worker, reviewer, or custom name"
83                },
84                "task": {
85                    "type": "string",
86                    "description": "Task description for the sub-agent"
87                },
88                "system_prompt": {
89                    "type": "string",
90                    "description": "Optional system prompt override"
91                },
92                "max_steps": {
93                    "type": "integer",
94                    "description": "Optional max steps for the agent loop"
95                },
96                "cwd": {
97                    "type": "string",
98                    "description": "Optional working directory"
99                }
100            }
101        })
102    }
103
104    async fn execute(&self, args: Value, ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
105        let args: SpawnArgs = parse_args(&args)?;
106
107        let role = match args.role.as_str() {
108            "explorer" => AgentRole::Explorer,
109            "worker" => AgentRole::Worker,
110            "reviewer" => AgentRole::Reviewer,
111            other => AgentRole::Custom(other.to_string()),
112        };
113
114        let (agent, tools) = self
115            .factory
116            .create(&role, args.system_prompt.as_deref())
117            .await
118            .map_err(ToolError::Execution)?;
119
120        let config = crate::swarm::SpawnConfig {
121            role: role.clone(),
122            system_prompt: args.system_prompt,
123            tool_names: None,
124            cwd: args.cwd.map(std::path::PathBuf::from),
125            task: args.task.clone(),
126            max_steps: args.max_steps.unwrap_or(match &role {
127                AgentRole::Explorer => 10,
128                AgentRole::Worker => 30,
129                AgentRole::Reviewer => 15,
130                AgentRole::Custom(_) => 20,
131            }),
132            writable_roots: None,
133        };
134
135        let mut swarm = self.swarm.lock().await;
136        let id = swarm
137            .spawn(config, agent, tools, ctx)
138            .map_err(|e| ToolError::Execution(e.to_string()))?;
139
140        Ok(ToolOutput::text(format!(
141            "Spawned {} agent (id: {}): {}",
142            role.name(),
143            id,
144            args.task
145        )))
146    }
147}
148
149// --- WaitAgentsTool ---
150
151#[derive(Deserialize)]
152struct WaitArgs {
153    /// Agent IDs to wait for. If empty, wait for all.
154    #[serde(default)]
155    ids: Vec<String>,
156    /// Timeout in seconds (default: 300)
157    timeout_secs: Option<u64>,
158}
159
160/// Tool for waiting on sub-agents to complete.
161pub struct WaitAgentsTool {
162    swarm: SharedSwarm,
163}
164
165impl WaitAgentsTool {
166    pub fn new(swarm: SharedSwarm) -> Self {
167        Self { swarm }
168    }
169}
170
171#[async_trait::async_trait]
172impl Tool for WaitAgentsTool {
173    fn name(&self) -> &str {
174        "wait_agents"
175    }
176
177    fn description(&self) -> &str {
178        "Wait for sub-agents to complete. Provide specific IDs or wait for all."
179    }
180
181    fn parameters_schema(&self) -> Value {
182        serde_json::json!({
183            "type": "object",
184            "properties": {
185                "ids": {
186                    "type": "array",
187                    "items": {"type": "string"},
188                    "description": "Agent IDs to wait for. Empty = wait all."
189                },
190                "timeout_secs": {
191                    "type": "integer",
192                    "description": "Timeout in seconds (default: 300)"
193                }
194            }
195        })
196    }
197
198    async fn execute(&self, args: Value, _ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
199        let args: WaitArgs = parse_args(&args)?;
200        let timeout = std::time::Duration::from_secs(args.timeout_secs.unwrap_or(300));
201
202        // Take receivers under lock, then drop lock before awaiting (avoid deadlock)
203        let receivers = {
204            let mut swarm = self.swarm.lock().await;
205            if args.ids.is_empty() {
206                swarm.take_all_receivers()
207            } else {
208                let mut rxs = Vec::new();
209                for id_str in &args.ids {
210                    let id = AgentId::from(id_str.as_str());
211                    match swarm.take_receiver(&id) {
212                        Ok(rx) => rxs.push((id, rx)),
213                        Err(e) => {
214                            return Err(ToolError::Execution(format!(
215                                "Error for {}: {}",
216                                id_str, e
217                            )));
218                        }
219                    }
220                }
221                rxs
222            }
223        }; // lock dropped here
224
225        // Await results without holding the lock
226        let mut results = Vec::new();
227        for (id, rx) in receivers {
228            match tokio::time::timeout(timeout, rx).await {
229                Ok(Ok(result)) => results.push(result),
230                Ok(Err(_)) => {
231                    return Err(ToolError::Execution(format!("Channel closed for {}", id)));
232                }
233                Err(_) => return Err(ToolError::Execution(format!("Timeout waiting for {}", id))),
234            }
235        }
236
237        // Cleanup completed agents
238        {
239            let mut swarm = self.swarm.lock().await;
240            for r in &results {
241                swarm.cleanup(&r.id);
242            }
243        }
244
245        let mut output = String::new();
246        for r in &results {
247            output.push_str(&format!(
248                "<agent_result id=\"{}\" role=\"{}\" status=\"{}\">\n{}\n</agent_result>\n",
249                r.id, r.role, r.status, r.summary
250            ));
251        }
252
253        if output.is_empty() {
254            output = "No agents to wait for.".to_string();
255        }
256
257        Ok(ToolOutput::text(output))
258    }
259}
260
261// --- GetStatusTool ---
262
263/// Tool for checking status of sub-agents.
264pub struct GetStatusTool {
265    swarm: SharedSwarm,
266}
267
268impl GetStatusTool {
269    pub fn new(swarm: SharedSwarm) -> Self {
270        Self { swarm }
271    }
272}
273
274#[async_trait::async_trait]
275impl Tool for GetStatusTool {
276    fn name(&self) -> &str {
277        "agent_status"
278    }
279
280    fn description(&self) -> &str {
281        "Get status of all active sub-agents."
282    }
283
284    fn parameters_schema(&self) -> Value {
285        serde_json::json!({"type": "object"})
286    }
287
288    async fn execute(
289        &self,
290        _args: Value,
291        _ctx: &mut AgentContext,
292    ) -> Result<ToolOutput, ToolError> {
293        let swarm = self.swarm.lock().await;
294        let statuses = swarm.status_all().await;
295
296        if statuses.is_empty() {
297            return Ok(ToolOutput::text("No active agents."));
298        }
299
300        let mut output = String::new();
301        for (id, role, status) in &statuses {
302            output.push_str(&format!("- {} ({}) — {}\n", id, role, status));
303        }
304
305        Ok(ToolOutput::text(output))
306    }
307}
308
309// --- CancelAgentTool ---
310
311#[derive(Deserialize)]
312struct CancelArgs {
313    /// Agent ID to cancel. "all" to cancel all.
314    id: String,
315}
316
317/// Tool for cancelling sub-agents.
318pub struct CancelAgentTool {
319    swarm: SharedSwarm,
320}
321
322impl CancelAgentTool {
323    pub fn new(swarm: SharedSwarm) -> Self {
324        Self { swarm }
325    }
326}
327
328#[async_trait::async_trait]
329impl Tool for CancelAgentTool {
330    fn name(&self) -> &str {
331        "cancel_agent"
332    }
333
334    fn description(&self) -> &str {
335        "Cancel a running sub-agent by ID, or 'all' to cancel all agents."
336    }
337
338    fn parameters_schema(&self) -> Value {
339        serde_json::json!({
340            "type": "object",
341            "required": ["id"],
342            "properties": {
343                "id": {
344                    "type": "string",
345                    "description": "Agent ID to cancel, or 'all'"
346                }
347            }
348        })
349    }
350
351    async fn execute(&self, args: Value, _ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
352        let args: CancelArgs = parse_args(&args)?;
353
354        let swarm = self.swarm.lock().await;
355
356        if args.id == "all" {
357            swarm.cancel_all();
358            Ok(ToolOutput::text("Cancelled all agents."))
359        } else {
360            let id = AgentId::from(args.id.as_str());
361            swarm
362                .cancel(&id)
363                .map_err(|e| ToolError::Execution(e.to_string()))?;
364            Ok(ToolOutput::text(format!("Cancelled agent {}.", args.id)))
365        }
366    }
367}
368
369// Helper: construct AgentId from string
370impl From<&str> for AgentId {
371    fn from(s: &str) -> Self {
372        Self(s.to_string())
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn agent_id_from_str() {
        let raw = "abc123";
        let id = AgentId::from(raw);
        assert_eq!(id.short(), raw);
        assert_eq!(id.to_string(), raw);
    }

    #[test]
    fn agent_role_names() {
        let custom = AgentRole::Custom(String::from("planner"));
        assert_eq!(AgentRole::Explorer.name(), "explorer");
        assert_eq!(custom.name(), "planner");
    }

    #[test]
    fn shared_swarm_creates() {
        // Building the handle must succeed, and dropping it at scope end must not panic.
        let _swarm: SharedSwarm = shared_swarm(SwarmManager::new());
    }
}