1use crate::agent::{Agent, AgentError};
7use crate::agent_loop::{LoopConfig, run_loop};
8use crate::context::AgentContext;
9use crate::registry::ToolRegistry;
10use crate::types::Message;
11use std::collections::HashMap;
12use std::fmt;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tokio::sync::{Mutex, mpsc, oneshot};
16use tokio_util::sync::CancellationToken;
17
/// Identifier of an agent managed by the swarm, wrapping its textual form.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AgentId(pub String);

impl Default for AgentId {
    /// Builds a fresh id of the form `agent-<n>`, where `<n>` is the next
    /// value handed out by [`next_id`].
    fn default() -> Self {
        let sequence = next_id();
        AgentId(format!("agent-{}", sequence))
    }
}

impl AgentId {
    /// Allocates a new id; equivalent to [`AgentId::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Borrows the identifier text.
    ///
    /// Despite the name, this currently returns the complete id string.
    pub fn short(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for AgentId {
    /// Writes the raw identifier text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

/// Returns the next value of a process-wide counter that starts at 1 and
/// increases by one per call.
fn next_id() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    static NEXT: AtomicU64 = AtomicU64::new(1);
    NEXT.fetch_add(1, Ordering::Relaxed)
}
49
/// Role assigned to a spawned agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentRole {
    /// Built-in role named `"explorer"`.
    Explorer,
    /// Built-in role named `"worker"`.
    Worker,
    /// Built-in role named `"reviewer"`.
    Reviewer,
    /// Caller-defined role; the payload is used verbatim as the role name.
    Custom(String),
}

impl AgentRole {
    /// Returns the lowercase role name, or the custom label for
    /// [`AgentRole::Custom`].
    pub fn name(&self) -> &str {
        match self {
            AgentRole::Explorer => "explorer",
            AgentRole::Worker => "worker",
            AgentRole::Reviewer => "reviewer",
            AgentRole::Custom(label) => label.as_str(),
        }
    }
}

impl fmt::Display for AgentRole {
    /// Writes the same text as [`AgentRole::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}
79
/// Lifecycle state of a spawned agent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentStatus {
    /// Initial state of every spawned agent.
    Running,
    /// The agent loop returned successfully.
    Completed,
    /// The agent loop returned an error; the payload is the error text.
    Failed(String),
    /// The agent was cancelled.
    Cancelled,
}

impl fmt::Display for AgentStatus {
    /// Writes a lowercase status label; failures include the error text
    /// after `failed: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AgentStatus::Failed(reason) => write!(f, "failed: {}", reason),
            AgentStatus::Running => f.write_str("running"),
            AgentStatus::Completed => f.write_str("completed"),
            AgentStatus::Cancelled => f.write_str("cancelled"),
        }
    }
}
99
100#[derive(Debug, Clone)]
102pub struct SwarmResult {
103 pub id: AgentId,
104 pub role: AgentRole,
105 pub status: AgentStatus,
106 pub summary: String,
108 pub steps: usize,
110 pub events: Vec<String>,
112}
113
114pub struct SpawnConfig {
116 pub role: AgentRole,
118 pub system_prompt: Option<String>,
120 pub tool_names: Option<Vec<String>>,
122 pub cwd: Option<PathBuf>,
124 pub task: String,
126 pub max_steps: usize,
128 pub writable_roots: Option<Vec<PathBuf>>,
130}
131
132impl SpawnConfig {
133 pub fn explorer(task: impl Into<String>) -> Self {
134 Self {
135 role: AgentRole::Explorer,
136 system_prompt: None,
137 tool_names: None,
138 cwd: None,
139 task: task.into(),
140 max_steps: 10,
141 writable_roots: None,
142 }
143 }
144
145 pub fn worker(task: impl Into<String>) -> Self {
146 Self {
147 role: AgentRole::Worker,
148 system_prompt: None,
149 tool_names: None,
150 cwd: None,
151 task: task.into(),
152 max_steps: 30,
153 writable_roots: None,
154 }
155 }
156
157 pub fn reviewer(task: impl Into<String>) -> Self {
158 Self {
159 role: AgentRole::Reviewer,
160 system_prompt: None,
161 tool_names: None,
162 cwd: None,
163 task: task.into(),
164 max_steps: 15,
165 writable_roots: None,
166 }
167 }
168}
169
170#[derive(Debug, thiserror::Error)]
172pub enum SwarmError {
173 #[error("Max agents reached ({0})")]
174 MaxAgents(usize),
175 #[error("Max depth reached ({0})")]
176 MaxDepth(usize),
177 #[error("Agent not found: {0}")]
178 NotFound(AgentId),
179 #[error("Agent already completed: {0}")]
180 AlreadyCompleted(AgentId),
181 #[error("Agent error: {0}")]
182 Agent(#[from] AgentError),
183 #[error("Channel error")]
184 Channel,
185}
186
187struct AgentHandle {
189 id: AgentId,
190 role: AgentRole,
191 cancel: CancellationToken,
192 status: Arc<Mutex<AgentStatus>>,
193 result_rx: Option<oneshot::Receiver<SwarmResult>>,
194}
195
196#[derive(Debug, Clone)]
198pub struct AgentNotification {
199 pub id: AgentId,
200 pub role: AgentRole,
201 pub status: AgentStatus,
202 pub summary: String,
203}
204
205pub struct SwarmManager {
207 agents: HashMap<AgentId, AgentHandle>,
208 notification_tx: mpsc::Sender<AgentNotification>,
210 notification_rx: Arc<Mutex<mpsc::Receiver<AgentNotification>>>,
211 max_agents: usize,
212 max_depth: usize,
213 current_depth: usize,
214}
215
216impl SwarmManager {
217 pub fn new() -> Self {
218 let (tx, rx) = mpsc::channel(64);
219 Self {
220 agents: HashMap::new(),
221 notification_tx: tx,
222 notification_rx: Arc::new(Mutex::new(rx)),
223 max_agents: 8,
224 max_depth: 3,
225 current_depth: 0,
226 }
227 }
228
229 pub fn with_limits(mut self, max_agents: usize, max_depth: usize) -> Self {
230 self.max_agents = max_agents;
231 self.max_depth = max_depth;
232 self
233 }
234
235 pub fn with_depth(mut self, depth: usize) -> Self {
236 self.current_depth = depth;
237 self
238 }
239
240 pub fn spawn(
245 &mut self,
246 config: SpawnConfig,
247 agent: Box<dyn Agent>,
248 tools: ToolRegistry,
249 parent_ctx: &AgentContext,
250 ) -> Result<AgentId, SwarmError> {
251 if self.active_count() >= self.max_agents {
252 return Err(SwarmError::MaxAgents(self.max_agents));
253 }
254 if self.current_depth >= self.max_depth {
255 return Err(SwarmError::MaxDepth(self.max_depth));
256 }
257
258 let id = AgentId::new();
259 let cancel = CancellationToken::new();
260 let status = Arc::new(Mutex::new(AgentStatus::Running));
261 let (result_tx, result_rx) = oneshot::channel();
262
263 let mut ctx = AgentContext::new();
265 ctx.cwd = config.cwd.unwrap_or_else(|| parent_ctx.cwd.clone());
266 ctx.writable_roots = config
267 .writable_roots
268 .unwrap_or_else(|| parent_ctx.writable_roots.clone());
269
270 let system_prompt = config.system_prompt.unwrap_or_else(|| {
272 format!(
273 "You are a {} agent. Complete the assigned task efficiently.",
274 config.role.name()
275 )
276 });
277 let mut messages = vec![Message::system(&system_prompt), Message::user(&config.task)];
278
279 let loop_config = LoopConfig {
280 max_steps: config.max_steps,
281 ..Default::default()
282 };
283
284 let agent_id = id.clone();
285 let agent_role = config.role.clone();
286 let cancel_token = cancel.clone();
287 let status_clone = Arc::clone(&status);
288 let notify_tx = self.notification_tx.clone();
289
290 tokio::spawn(async move {
292 let mut events: Vec<String> = Vec::new();
293
294 let loop_result = tokio::select! {
295 result = run_loop(
296 agent.as_ref(),
297 &tools,
298 &mut ctx,
299 &mut messages,
300 &loop_config,
301 |event| {
302 events.push(format!("{:?}", event));
303 },
304 ) => result,
305 _ = cancel_token.cancelled() => {
306 Err(AgentError::Cancelled)
307 }
308 };
309
310 let (final_status, summary, steps) = match loop_result {
311 Ok(steps) => {
312 let summary = messages
313 .iter()
314 .rev()
315 .find(|m| m.role == crate::types::Role::Assistant)
316 .map(|m| m.content.clone())
317 .unwrap_or_else(|| "Completed".to_string());
318 (AgentStatus::Completed, summary, steps)
319 }
320 Err(AgentError::Cancelled) => (AgentStatus::Cancelled, "Cancelled".to_string(), 0),
321 Err(e) => (AgentStatus::Failed(e.to_string()), e.to_string(), 0),
322 };
323
324 *status_clone.lock().await = final_status.clone();
326
327 let result = SwarmResult {
328 id: agent_id.clone(),
329 role: agent_role.clone(),
330 status: final_status.clone(),
331 summary: summary.clone(),
332 steps,
333 events,
334 };
335
336 let _ = result_tx.send(result);
338
339 let _ = notify_tx
341 .send(AgentNotification {
342 id: agent_id,
343 role: agent_role,
344 status: final_status,
345 summary,
346 })
347 .await;
348 });
349
350 self.agents.insert(
351 id.clone(),
352 AgentHandle {
353 id: id.clone(),
354 role: config.role,
355 cancel,
356 status,
357 result_rx: Some(result_rx),
358 },
359 );
360
361 Ok(id)
362 }
363
364 pub async fn status(&self, id: &AgentId) -> Option<AgentStatus> {
366 if let Some(handle) = self.agents.get(id) {
367 Some(handle.status.lock().await.clone())
368 } else {
369 None
370 }
371 }
372
373 pub async fn status_all(&self) -> Vec<(AgentId, AgentRole, AgentStatus)> {
375 let mut result = Vec::new();
376 for handle in self.agents.values() {
377 let status = handle.status.lock().await.clone();
378 result.push((handle.id.clone(), handle.role.clone(), status));
379 }
380 result
381 }
382
383 pub fn take_receiver(
386 &mut self,
387 id: &AgentId,
388 ) -> Result<oneshot::Receiver<SwarmResult>, SwarmError> {
389 let handle = self
390 .agents
391 .get_mut(id)
392 .ok_or_else(|| SwarmError::NotFound(id.clone()))?;
393
394 handle
395 .result_rx
396 .take()
397 .ok_or_else(|| SwarmError::AlreadyCompleted(id.clone()))
398 }
399
400 pub fn take_all_receivers(&mut self) -> Vec<(AgentId, oneshot::Receiver<SwarmResult>)> {
402 let mut receivers = Vec::new();
403 for (id, handle) in &mut self.agents {
404 if let Some(rx) = handle.result_rx.take() {
405 receivers.push((id.clone(), rx));
406 }
407 }
408 receivers
409 }
410
411 pub async fn wait(&mut self, id: &AgentId) -> Result<SwarmResult, SwarmError> {
414 let rx = self.take_receiver(id)?;
415 let result = rx.await.map_err(|_| SwarmError::Channel)?;
416 self.agents.remove(id); Ok(result)
418 }
419
420 pub async fn wait_all(&mut self) -> Vec<SwarmResult> {
423 let receivers = self.take_all_receivers();
424 let mut results = Vec::new();
425 for (id, rx) in receivers {
426 if let Ok(result) = rx.await {
427 results.push(result);
428 self.agents.remove(&id);
429 }
430 }
431 results
432 }
433
434 pub fn cancel(&self, id: &AgentId) -> Result<(), SwarmError> {
436 let handle = self
437 .agents
438 .get(id)
439 .ok_or_else(|| SwarmError::NotFound(id.clone()))?;
440 handle.cancel.cancel();
441 Ok(())
442 }
443
444 pub fn cancel_all(&self) {
446 for handle in self.agents.values() {
447 handle.cancel.cancel();
448 }
449 }
450
451 pub async fn try_recv_notification(&self) -> Option<AgentNotification> {
453 let mut rx = self.notification_rx.lock().await;
454 rx.try_recv().ok()
455 }
456
457 pub async fn recv_notification(
459 &self,
460 timeout: std::time::Duration,
461 ) -> Option<AgentNotification> {
462 let mut rx = self.notification_rx.lock().await;
463 tokio::time::timeout(timeout, rx.recv())
464 .await
465 .ok()
466 .flatten()
467 }
468
469 pub fn cleanup(&mut self, id: &AgentId) {
471 self.agents.remove(id);
472 }
473
474 pub fn agent_count(&self) -> usize {
476 self.agents.len()
477 }
478
479 pub fn active_count(&self) -> usize {
481 self.agents
482 .values()
483 .filter(|h| h.result_rx.is_some())
484 .count()
485 }
486
487 pub fn all_agent_ids(&self) -> Vec<AgentId> {
489 self.agents.keys().cloned().collect()
490 }
491
492 pub async fn status_all_formatted(&self) -> String {
494 let statuses = self.status_all().await;
495 if statuses.is_empty() {
496 return "No agents.".to_string();
497 }
498 statuses
499 .iter()
500 .map(|(id, role, status)| format!("[{}] {} — {}", id, role, status))
501 .collect::<Vec<_>>()
502 .join("\n")
503 }
504
505 pub async fn wait_with_timeout(
507 &mut self,
508 ids: &[AgentId],
509 timeout: std::time::Duration,
510 ) -> Vec<(AgentId, String)> {
511 let mut results = Vec::new();
512 for id in ids {
513 let rx = match self.take_receiver(id) {
514 Ok(rx) => rx,
515 Err(e) => {
516 results.push((id.clone(), format!("Error: {}", e)));
517 continue;
518 }
519 };
520 match tokio::time::timeout(timeout, rx).await {
521 Ok(Ok(result)) => {
522 let summary = format!(
523 "{} ({}, {} steps): {}",
524 result.status,
525 result.role,
526 result.steps,
527 if result.summary.len() > 500 {
528 format!("{}...", &result.summary[..500])
529 } else {
530 result.summary.clone()
531 }
532 );
533 self.agents.remove(id);
534 results.push((id.clone(), summary));
535 }
536 Ok(Err(_)) => {
537 results.push((id.clone(), "Channel closed".into()));
538 }
539 Err(_) => {
540 results.push((id.clone(), format!("Timeout after {}s", timeout.as_secs())));
541 }
542 }
543 }
544 results
545 }
546
547 pub async fn status_summary(&self) -> String {
549 let mut lines = Vec::new();
550 for handle in self.agents.values() {
551 let status = handle.status.lock().await;
552 lines.push(format!(" {} ({}) — {}", handle.id, handle.role, *status));
553 }
554 if lines.is_empty() {
555 " (none)".to_string()
556 } else {
557 lines.join("\n")
558 }
559 }
560}
561
562impl Default for SwarmManager {
563 fn default() -> Self {
564 Self::new()
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentError, Decision};
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::types::{Message, ToolCall};
    use serde_json::Value;

    /// Agent that declares the task complete on its first decision.
    struct SimpleAgent {}

    #[async_trait::async_trait]
    impl Agent for SimpleAgent {
        async fn decide(
            &self,
            _messages: &[Message],
            _tools: &ToolRegistry,
        ) -> Result<Decision, AgentError> {
            let decision = Decision {
                situation: "Task done.".into(),
                task: vec![],
                tool_calls: vec![],
                completed: true,
            };
            Ok(decision)
        }
    }

    /// Agent that requests the `echo` tool until it has seen `steps` tool
    /// messages, then completes.
    struct StepAgent {
        steps: usize,
    }

    #[async_trait::async_trait]
    impl Agent for StepAgent {
        async fn decide(
            &self,
            msgs: &[Message],
            _tools: &ToolRegistry,
        ) -> Result<Decision, AgentError> {
            let seen = msgs
                .iter()
                .filter(|m| m.role == crate::types::Role::Tool)
                .count();

            if seen >= self.steps {
                return Ok(Decision {
                    situation: "All steps done.".into(),
                    task: vec![],
                    tool_calls: vec![],
                    completed: true,
                });
            }

            let call = ToolCall {
                id: format!("call_{}", seen),
                name: "echo".into(),
                arguments: serde_json::json!({}),
            };
            Ok(Decision {
                situation: format!("Step {}", seen + 1),
                task: vec![],
                tool_calls: vec![call],
                completed: false,
            })
        }
    }

    /// Tool that ignores its arguments and always answers `echoed`.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "echo"
        }

        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("echoed"))
        }
    }

    /// Spawns a `SimpleAgent` with an empty tool registry.
    fn spawn_simple(
        swarm: &mut SwarmManager,
        config: SpawnConfig,
        ctx: &AgentContext,
    ) -> Result<AgentId, SwarmError> {
        swarm.spawn(config, Box::new(SimpleAgent {}), ToolRegistry::new(), ctx)
    }

    #[tokio::test]
    async fn spawn_and_wait() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let id = spawn_simple(&mut swarm, SpawnConfig::explorer("Find all Rust files"), &ctx)
            .unwrap();

        let outcome = swarm.wait(&id).await.unwrap();
        assert_eq!(outcome.status, AgentStatus::Completed);
        assert!(outcome.summary.contains("Task done"));
    }

    #[tokio::test]
    async fn spawn_with_tools() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();
        let tools = ToolRegistry::new().register(EchoTool);

        let agent = Box::new(StepAgent { steps: 2 });
        let id = swarm
            .spawn(SpawnConfig::worker("Do 2 steps"), agent, tools, &ctx)
            .unwrap();

        let outcome = swarm.wait(&id).await.unwrap();
        assert_eq!(outcome.status, AgentStatus::Completed);
        assert!(outcome.steps >= 2);
    }

    #[tokio::test]
    async fn cancel_agent() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let config = SpawnConfig {
            max_steps: 100,
            ..SpawnConfig::worker("Long task")
        };
        let id = swarm
            .spawn(
                config,
                Box::new(StepAgent { steps: 100 }),
                ToolRegistry::new().register(EchoTool),
                &ctx,
            )
            .unwrap();

        swarm.cancel(&id).unwrap();

        // Any terminal status is accepted here.
        let outcome = swarm.wait(&id).await.unwrap();
        let terminal = matches!(
            outcome.status,
            AgentStatus::Cancelled | AgentStatus::Failed(_) | AgentStatus::Completed
        );
        assert!(terminal);
    }

    #[tokio::test]
    async fn max_agents_limit() {
        let mut swarm = SwarmManager::new().with_limits(2, 3);
        let ctx = AgentContext::new();

        let _id1 = spawn_simple(&mut swarm, SpawnConfig::explorer("Task 1"), &ctx).unwrap();
        let _id2 = spawn_simple(&mut swarm, SpawnConfig::explorer("Task 2"), &ctx).unwrap();

        let err = spawn_simple(&mut swarm, SpawnConfig::explorer("Task 3"), &ctx)
            .err()
            .unwrap();
        assert!(matches!(err, SwarmError::MaxAgents(2)));
    }

    #[tokio::test]
    async fn max_depth_limit() {
        let mut swarm = SwarmManager::new().with_limits(8, 3).with_depth(3);
        let ctx = AgentContext::new();

        let err = spawn_simple(&mut swarm, SpawnConfig::explorer("Task"), &ctx)
            .err()
            .unwrap();
        assert!(matches!(err, SwarmError::MaxDepth(3)));
    }

    #[tokio::test]
    async fn status_tracking() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let id = spawn_simple(&mut swarm, SpawnConfig::explorer("Quick task"), &ctx).unwrap();

        let outcome = swarm.wait(&id).await.unwrap();
        assert_eq!(outcome.status, AgentStatus::Completed);

        // `wait` removes the handle, so the status lookup finds nothing.
        assert!(swarm.status(&id).await.is_none());
    }

    #[tokio::test]
    async fn wait_all_returns_results() {
        let mut swarm = SwarmManager::new();
        let ctx = AgentContext::new();

        let _id1 = spawn_simple(&mut swarm, SpawnConfig::explorer("Task 1"), &ctx).unwrap();
        let _id2 = spawn_simple(&mut swarm, SpawnConfig::worker("Task 2"), &ctx).unwrap();

        let outcomes = swarm.wait_all().await;
        assert_eq!(outcomes.len(), 2);
        assert!(outcomes.iter().all(|r| r.status == AgentStatus::Completed));
    }

    #[test]
    fn agent_role_display() {
        assert_eq!(AgentRole::Explorer.name(), "explorer");
        assert_eq!(AgentRole::Worker.name(), "worker");
        assert_eq!(AgentRole::Reviewer.name(), "reviewer");
        assert_eq!(AgentRole::Custom("planner".into()).name(), "planner");
    }

    #[test]
    fn spawn_config_constructors() {
        let explorer = SpawnConfig::explorer("Find files");
        assert_eq!(explorer.role, AgentRole::Explorer);
        assert_eq!(explorer.max_steps, 10);

        let worker = SpawnConfig::worker("Implement feature");
        assert_eq!(worker.role, AgentRole::Worker);
        assert_eq!(worker.max_steps, 30);

        let reviewer = SpawnConfig::reviewer("Review code");
        assert_eq!(reviewer.role, AgentRole::Reviewer);
        assert_eq!(reviewer.max_steps, 15);
    }
}