1use crate::agent_tool::{Tool, ToolError, ToolOutput, parse_args};
7use crate::context::AgentContext;
8use crate::swarm::{AgentId, AgentRole, SwarmManager};
9use serde::Deserialize;
10use serde_json::Value;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13
14pub type SharedSwarm = Arc<Mutex<SwarmManager>>;
16
17pub fn shared_swarm(manager: SwarmManager) -> SharedSwarm {
19 Arc::new(Mutex::new(manager))
20}
21
22#[derive(Deserialize)]
25struct SpawnArgs {
26 role: String,
28 task: String,
30 system_prompt: Option<String>,
32 max_steps: Option<usize>,
34 cwd: Option<String>,
36}
37
38pub struct SpawnAgentTool {
40 swarm: SharedSwarm,
41 factory: Arc<dyn AgentFactory>,
44}
45
46#[async_trait::async_trait]
50pub trait AgentFactory: Send + Sync {
51 async fn create(
53 &self,
54 role: &AgentRole,
55 system_prompt: Option<&str>,
56 ) -> Result<(Box<dyn crate::agent::Agent>, crate::registry::ToolRegistry), String>;
57}
58
59impl SpawnAgentTool {
60 pub fn new(swarm: SharedSwarm, factory: Arc<dyn AgentFactory>) -> Self {
61 Self { swarm, factory }
62 }
63}
64
65#[async_trait::async_trait]
66impl Tool for SpawnAgentTool {
67 fn name(&self) -> &str {
68 "spawn_agent"
69 }
70
71 fn description(&self) -> &str {
72 "Spawn a sub-agent with a specific role and task. Roles: explorer (fast, read-only), worker (smart, read-write), reviewer (read-only, thorough)."
73 }
74
75 fn parameters_schema(&self) -> Value {
76 serde_json::json!({
77 "type": "object",
78 "required": ["role", "task"],
79 "properties": {
80 "role": {
81 "type": "string",
82 "description": "Agent role: explorer, worker, reviewer, or custom name"
83 },
84 "task": {
85 "type": "string",
86 "description": "Task description for the sub-agent"
87 },
88 "system_prompt": {
89 "type": "string",
90 "description": "Optional system prompt override"
91 },
92 "max_steps": {
93 "type": "integer",
94 "description": "Optional max steps for the agent loop"
95 },
96 "cwd": {
97 "type": "string",
98 "description": "Optional working directory"
99 }
100 }
101 })
102 }
103
104 async fn execute(&self, args: Value, ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
105 let args: SpawnArgs = parse_args(&args)?;
106
107 let role = match args.role.as_str() {
108 "explorer" => AgentRole::Explorer,
109 "worker" => AgentRole::Worker,
110 "reviewer" => AgentRole::Reviewer,
111 other => AgentRole::Custom(other.to_string()),
112 };
113
114 let (agent, tools) = self
115 .factory
116 .create(&role, args.system_prompt.as_deref())
117 .await
118 .map_err(ToolError::Execution)?;
119
120 let config = crate::swarm::SpawnConfig {
121 role: role.clone(),
122 system_prompt: args.system_prompt,
123 tool_names: None,
124 cwd: args.cwd.map(std::path::PathBuf::from),
125 task: args.task.clone(),
126 max_steps: args.max_steps.unwrap_or(match &role {
127 AgentRole::Explorer => 10,
128 AgentRole::Worker => 30,
129 AgentRole::Reviewer => 15,
130 AgentRole::Custom(_) => 20,
131 }),
132 writable_roots: None,
133 };
134
135 let mut swarm = self.swarm.lock().await;
136 let id = swarm
137 .spawn(config, agent, tools, ctx)
138 .map_err(|e| ToolError::Execution(e.to_string()))?;
139
140 Ok(ToolOutput::text(format!(
141 "Spawned {} agent (id: {}): {}",
142 role.name(),
143 id,
144 args.task
145 )))
146 }
147}
148
149#[derive(Deserialize)]
152struct WaitArgs {
153 #[serde(default)]
155 ids: Vec<String>,
156 timeout_secs: Option<u64>,
158}
159
160pub struct WaitAgentsTool {
162 swarm: SharedSwarm,
163}
164
165impl WaitAgentsTool {
166 pub fn new(swarm: SharedSwarm) -> Self {
167 Self { swarm }
168 }
169}
170
171#[async_trait::async_trait]
172impl Tool for WaitAgentsTool {
173 fn name(&self) -> &str {
174 "wait_agents"
175 }
176
177 fn description(&self) -> &str {
178 "Wait for sub-agents to complete. Provide specific IDs or wait for all."
179 }
180
181 fn parameters_schema(&self) -> Value {
182 serde_json::json!({
183 "type": "object",
184 "properties": {
185 "ids": {
186 "type": "array",
187 "items": {"type": "string"},
188 "description": "Agent IDs to wait for. Empty = wait all."
189 },
190 "timeout_secs": {
191 "type": "integer",
192 "description": "Timeout in seconds (default: 300)"
193 }
194 }
195 })
196 }
197
198 async fn execute(&self, args: Value, _ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
199 let args: WaitArgs = parse_args(&args)?;
200 let timeout = std::time::Duration::from_secs(args.timeout_secs.unwrap_or(300));
201
202 let receivers = {
204 let mut swarm = self.swarm.lock().await;
205 if args.ids.is_empty() {
206 swarm.take_all_receivers()
207 } else {
208 let mut rxs = Vec::new();
209 for id_str in &args.ids {
210 let id = AgentId::from(id_str.as_str());
211 match swarm.take_receiver(&id) {
212 Ok(rx) => rxs.push((id, rx)),
213 Err(e) => {
214 return Err(ToolError::Execution(format!(
215 "Error for {}: {}",
216 id_str, e
217 )));
218 }
219 }
220 }
221 rxs
222 }
223 }; let mut results = Vec::new();
227 for (id, rx) in receivers {
228 match tokio::time::timeout(timeout, rx).await {
229 Ok(Ok(result)) => results.push(result),
230 Ok(Err(_)) => {
231 return Err(ToolError::Execution(format!("Channel closed for {}", id)));
232 }
233 Err(_) => return Err(ToolError::Execution(format!("Timeout waiting for {}", id))),
234 }
235 }
236
237 {
239 let mut swarm = self.swarm.lock().await;
240 for r in &results {
241 swarm.cleanup(&r.id);
242 }
243 }
244
245 let mut output = String::new();
246 for r in &results {
247 output.push_str(&format!(
248 "<agent_result id=\"{}\" role=\"{}\" status=\"{}\">\n{}\n</agent_result>\n",
249 r.id, r.role, r.status, r.summary
250 ));
251 }
252
253 if output.is_empty() {
254 output = "No agents to wait for.".to_string();
255 }
256
257 Ok(ToolOutput::text(output))
258 }
259}
260
261pub struct GetStatusTool {
265 swarm: SharedSwarm,
266}
267
268impl GetStatusTool {
269 pub fn new(swarm: SharedSwarm) -> Self {
270 Self { swarm }
271 }
272}
273
274#[async_trait::async_trait]
275impl Tool for GetStatusTool {
276 fn name(&self) -> &str {
277 "agent_status"
278 }
279
280 fn description(&self) -> &str {
281 "Get status of all active sub-agents."
282 }
283
284 fn parameters_schema(&self) -> Value {
285 serde_json::json!({"type": "object"})
286 }
287
288 async fn execute(
289 &self,
290 _args: Value,
291 _ctx: &mut AgentContext,
292 ) -> Result<ToolOutput, ToolError> {
293 let swarm = self.swarm.lock().await;
294 let statuses = swarm.status_all().await;
295
296 if statuses.is_empty() {
297 return Ok(ToolOutput::text("No active agents."));
298 }
299
300 let mut output = String::new();
301 for (id, role, status) in &statuses {
302 output.push_str(&format!("- {} ({}) — {}\n", id, role, status));
303 }
304
305 Ok(ToolOutput::text(output))
306 }
307}
308
309#[derive(Deserialize)]
312struct CancelArgs {
313 id: String,
315}
316
317pub struct CancelAgentTool {
319 swarm: SharedSwarm,
320}
321
322impl CancelAgentTool {
323 pub fn new(swarm: SharedSwarm) -> Self {
324 Self { swarm }
325 }
326}
327
328#[async_trait::async_trait]
329impl Tool for CancelAgentTool {
330 fn name(&self) -> &str {
331 "cancel_agent"
332 }
333
334 fn description(&self) -> &str {
335 "Cancel a running sub-agent by ID, or 'all' to cancel all agents."
336 }
337
338 fn parameters_schema(&self) -> Value {
339 serde_json::json!({
340 "type": "object",
341 "required": ["id"],
342 "properties": {
343 "id": {
344 "type": "string",
345 "description": "Agent ID to cancel, or 'all'"
346 }
347 }
348 })
349 }
350
351 async fn execute(&self, args: Value, _ctx: &mut AgentContext) -> Result<ToolOutput, ToolError> {
352 let args: CancelArgs = parse_args(&args)?;
353
354 let swarm = self.swarm.lock().await;
355
356 if args.id == "all" {
357 swarm.cancel_all();
358 Ok(ToolOutput::text("Cancelled all agents."))
359 } else {
360 let id = AgentId::from(args.id.as_str());
361 swarm
362 .cancel(&id)
363 .map_err(|e| ToolError::Execution(e.to_string()))?;
364 Ok(ToolOutput::text(format!("Cancelled agent {}.", args.id)))
365 }
366 }
367}
368
369impl From<&str> for AgentId {
371 fn from(s: &str) -> Self {
372 Self(s.to_string())
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// `AgentId::from(&str)` keeps the text verbatim for `short()` and `Display`.
    #[test]
    fn agent_id_from_str() {
        let raw = "abc123";
        let id = AgentId::from(raw);
        assert_eq!(id.short(), raw);
        assert_eq!(id.to_string(), raw);
    }

    /// Built-in and custom roles report the expected display names.
    #[test]
    fn agent_role_names() {
        let cases = [
            (AgentRole::Explorer, "explorer"),
            (AgentRole::Custom("planner".into()), "planner"),
        ];
        for (role, expected) in &cases {
            assert_eq!(role.name(), *expected);
        }
    }

    /// Building and dropping a shared swarm handle must not panic.
    #[test]
    fn shared_swarm_creates() {
        drop(shared_swarm(SwarmManager::new()));
    }
}