1use std::collections::HashMap;
9use std::sync::{Arc, Mutex, PoisonError};
10
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::sync::CancellationToken;
13use tracing::{info, warn};
14
15use crate::agent::{Agent, AgentOptions};
16use crate::error::AgentError;
17use crate::handle::AgentStatus;
18use crate::task_core::{TaskCore, resolve_status};
19use crate::types::{AgentMessage, AgentResult, ContentBlock, LlmMessage, UserMessage};
20use crate::util::now_timestamp;
21
/// Shared, thread-safe factory producing the `AgentOptions` used to build a new `Agent`.
///
/// Stored behind an `Arc` so each spawned task can keep calling it (on start and on
/// every supervised restart) independently of the orchestrator.
type OptionsFactoryArc = Arc<dyn Fn() -> AgentOptions + Send + Sync>;
25
/// One unit of work queued for an orchestrated agent task.
pub struct AgentRequest {
    /// Messages handed to the agent's `prompt_async` call for this request.
    pub messages: Vec<AgentMessage>,
    /// One-shot channel on which the agent task reports the outcome of this request.
    pub reply: oneshot::Sender<Result<AgentResult, AgentError>>,
}
35
/// What an orchestrated agent task should do after a prompt fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisorAction {
    /// Replace the agent with a fresh one built from its options factory, provided
    /// the restart cap has not been reached; once it has, the agent is stopped.
    Restart,
    /// Stop the agent task; its final result becomes a "stopped by supervisor" error.
    Stop,
    /// Return the error to the caller and keep serving requests with the same agent.
    Escalate,
}
48
/// Decides how an orchestrated agent task reacts to a failed prompt.
pub trait SupervisorPolicy: Send + Sync {
    /// Called with the agent's registered `name` and the `error` its prompt returned.
    /// The agent's task loop applies the returned action.
    fn on_agent_error(&self, name: &str, error: &AgentError) -> SupervisorAction;
}
57
/// Built-in [`SupervisorPolicy`]: restart on retryable errors, stop on everything else.
///
/// NOTE(review): the `max_restarts` value stored here is only reported back by
/// [`DefaultSupervisor::max_restarts`]. The orchestrator's task loop takes its
/// restart cap from the agent's registration entry (set through
/// `AgentOrchestrator::with_max_restarts`), not from this policy — confirm whether
/// that is intended.
#[derive(Clone, Debug)]
pub struct DefaultSupervisor {
    /// Restart cap this supervisor was configured with.
    max_restarts: u32,
}

impl DefaultSupervisor {
    /// Creates a supervisor that records `max_restarts` as its restart cap.
    #[must_use]
    pub const fn new(max_restarts: u32) -> Self {
        DefaultSupervisor { max_restarts }
    }

    /// Returns the cap passed to [`DefaultSupervisor::new`] (3 for [`Default`]).
    #[must_use]
    pub const fn max_restarts(&self) -> u32 {
        let DefaultSupervisor { max_restarts } = self;
        *max_restarts
    }
}

impl Default for DefaultSupervisor {
    /// Equivalent to `DefaultSupervisor::new(3)`.
    fn default() -> Self {
        const DEFAULT_MAX_RESTARTS: u32 = 3;
        Self::new(DEFAULT_MAX_RESTARTS)
    }
}
83
84impl SupervisorPolicy for DefaultSupervisor {
85 fn on_agent_error(&self, _name: &str, error: &AgentError) -> SupervisorAction {
86 if error.is_retryable() {
87 SupervisorAction::Restart
88 } else {
89 SupervisorAction::Stop
90 }
91 }
92}
93
/// Registration record for one agent inside an `AgentOrchestrator`.
struct AgentEntry {
    /// Builds fresh `AgentOptions`; called when the agent is spawned and again on
    /// every supervised restart.
    options_factory: OptionsFactoryArc,
    /// Parent agent's name when registered via `add_child`; `None` for top-level agents.
    parent: Option<String>,
    /// Names of agents registered with this one as their parent, in registration order.
    children: Vec<String>,
    /// Restart cap copied from the orchestrator's default at registration time.
    max_restarts: u32,
}
107
/// Handle to an agent task started by `AgentOrchestrator::spawn`.
///
/// The handle owns the only sender of the task's request channel. Dropping it, or
/// calling `await_result`, closes the channel, and the task then shuts down once idle.
pub struct OrchestratedHandle {
    /// Name the agent was registered under.
    name: String,
    /// Sending half of the agent task's request queue.
    request_tx: mpsc::Sender<AgentRequest>,
    /// Built from the task's join handle, cancellation token and shared status.
    core: TaskCore,
}
119
120impl OrchestratedHandle {
121 #[must_use]
123 pub fn name(&self) -> &str {
124 &self.name
125 }
126
127 pub async fn send_message(&self, text: impl Into<String>) -> Result<AgentResult, AgentError> {
129 let msg = AgentMessage::Llm(LlmMessage::User(UserMessage {
130 content: vec![ContentBlock::Text { text: text.into() }],
131 timestamp: now_timestamp(),
132 cache_hint: None,
133 }));
134 self.send_messages(vec![msg]).await
135 }
136
137 pub async fn send_messages(
139 &self,
140 messages: Vec<AgentMessage>,
141 ) -> Result<AgentResult, AgentError> {
142 let (reply_tx, reply_rx) = oneshot::channel();
143 let request = AgentRequest {
144 messages,
145 reply: reply_tx,
146 };
147 self.request_tx.send(request).await.map_err(|_| {
148 AgentError::plugin(
149 "orchestrator",
150 std::io::Error::other("agent channel closed"),
151 )
152 })?;
153
154 reply_rx.await.map_err(|_| {
155 AgentError::plugin("orchestrator", std::io::Error::other("agent reply dropped"))
156 })?
157 }
158
159 pub async fn await_result(self) -> Result<AgentResult, AgentError> {
164 drop(self.request_tx);
165 self.core.result().await
166 }
167
168 pub fn cancel(&self) {
170 self.core.cancel();
171 }
172
173 pub fn status(&self) -> AgentStatus {
175 self.core.status()
176 }
177
178 pub fn is_done(&self) -> bool {
180 self.core.is_done()
181 }
182}
183
184impl std::fmt::Debug for OrchestratedHandle {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.debug_struct("OrchestratedHandle")
187 .field("name", &self.name)
188 .field("status", &self.status())
189 .finish_non_exhaustive()
190 }
191}
192
/// Registry of named agent factories that spawns each agent as its own Tokio task.
///
/// Agents may be arranged in a parent/child hierarchy. In this module the hierarchy
/// is bookkeeping only: `spawn` starts exactly the named agent.
pub struct AgentOrchestrator {
    /// Registered agents keyed by name.
    entries: HashMap<String, AgentEntry>,
    /// Policy consulted when a spawned agent's prompt fails. With `None`, errors are
    /// only reported back to the caller.
    supervisor: Option<Arc<dyn SupervisorPolicy>>,
    /// Capacity of the request channel created for each spawned agent.
    channel_buffer: usize,
    /// Restart cap assigned to agents registered from now on.
    default_max_restarts: u32,
}
215
216impl AgentOrchestrator {
217 #[must_use]
219 pub fn new() -> Self {
220 Self {
221 entries: HashMap::new(),
222 supervisor: None,
223 channel_buffer: 32,
224 default_max_restarts: 3,
225 }
226 }
227
228 #[must_use]
230 pub fn with_supervisor(mut self, policy: impl SupervisorPolicy + 'static) -> Self {
231 self.supervisor = Some(Arc::new(policy));
232 self
233 }
234
235 #[must_use]
237 pub const fn with_channel_buffer(mut self, size: usize) -> Self {
238 self.channel_buffer = size;
239 self
240 }
241
242 #[must_use]
244 pub const fn with_max_restarts(mut self, max: u32) -> Self {
245 self.default_max_restarts = max;
246 self
247 }
248
249 pub fn add_agent(
257 &mut self,
258 name: impl Into<String>,
259 options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
260 ) {
261 let name = name.into();
262 assert!(
263 !self.entries.contains_key(&name),
264 "agent '{name}' already registered"
265 );
266 self.entries.insert(
267 name,
268 AgentEntry {
269 options_factory: Arc::new(options_factory),
270 parent: None,
271 children: Vec::new(),
272 max_restarts: self.default_max_restarts,
273 },
274 );
275 }
276
277 pub fn add_child(
284 &mut self,
285 name: impl Into<String>,
286 parent: impl Into<String>,
287 options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
288 ) {
289 let name = name.into();
290 let parent = parent.into();
291 assert!(
292 self.entries.contains_key(&parent),
293 "parent agent '{parent}' not registered"
294 );
295 assert!(
296 !self.entries.contains_key(&name),
297 "agent '{name}' already registered"
298 );
299
300 self.entries
301 .get_mut(&parent)
302 .expect("parent checked above")
303 .children
304 .push(name.clone());
305
306 self.entries.insert(
307 name,
308 AgentEntry {
309 options_factory: Arc::new(options_factory),
310 parent: Some(parent),
311 children: Vec::new(),
312 max_restarts: self.default_max_restarts,
313 },
314 );
315 }
316
317 #[must_use]
319 pub fn parent_of(&self, name: &str) -> Option<&str> {
320 self.entries.get(name).and_then(|e| e.parent.as_deref())
321 }
322
323 #[must_use]
325 pub fn children_of(&self, name: &str) -> Option<&[String]> {
326 self.entries.get(name).map(|e| e.children.as_slice())
327 }
328
329 #[must_use]
331 pub fn names(&self) -> Vec<&str> {
332 self.entries.keys().map(String::as_str).collect()
333 }
334
335 #[must_use]
337 pub fn contains(&self, name: &str) -> bool {
338 self.entries.contains_key(name)
339 }
340
341 pub fn spawn(&self, name: &str) -> Result<OrchestratedHandle, AgentError> {
354 let entry = self.entries.get(name).ok_or_else(|| {
355 AgentError::plugin(
356 "orchestrator",
357 std::io::Error::other(format!("agent not registered: {name}")),
358 )
359 })?;
360
361 let factory = Arc::clone(&entry.options_factory);
362 let max_restarts = entry.max_restarts;
363 let agent_name = name.to_owned();
364 let supervisor = self.supervisor.clone();
365
366 let (request_tx, request_rx) = mpsc::channel::<AgentRequest>(self.channel_buffer);
367 let cancellation_token = CancellationToken::new();
368 let status = Arc::new(Mutex::new(AgentStatus::Running));
369
370 let status_clone = Arc::clone(&status);
371 let token_clone = cancellation_token.clone();
372
373 let join_handle = tokio::spawn(run_agent_loop(
374 agent_name,
375 factory,
376 request_rx,
377 token_clone,
378 status_clone,
379 supervisor,
380 max_restarts,
381 ));
382
383 Ok(OrchestratedHandle {
384 name: name.to_owned(),
385 request_tx,
386 core: TaskCore::new(join_handle, cancellation_token, status),
387 })
388 }
389}
390
391async fn run_agent_loop(
396 agent_name: String,
397 factory: OptionsFactoryArc,
398 mut request_rx: mpsc::Receiver<AgentRequest>,
399 cancellation_token: CancellationToken,
400 status: Arc<Mutex<AgentStatus>>,
401 supervisor: Option<Arc<dyn SupervisorPolicy>>,
402 max_restarts: u32,
403) -> Result<AgentResult, AgentError> {
404 let mut agent = Agent::new(factory());
405 let mut restarts: u32 = 0;
406
407 let final_result = loop {
408 tokio::select! {
409 biased;
410
411 () = cancellation_token.cancelled() => {
412 agent.abort();
413 break Err(AgentError::Aborted);
414 }
415
416 maybe_req = request_rx.recv() => {
417 if let Some(req) = maybe_req {
418 let result = tokio::select! {
419 biased;
420 () = cancellation_token.cancelled() => {
421 agent.abort();
422 let _ = req.reply.send(Err(AgentError::Aborted));
423 break Err(AgentError::Aborted);
424 }
425 r = agent.prompt_async(req.messages) => r,
426 };
427
428 match result {
429 Ok(r) => {
430 let _ = req.reply.send(Ok(r));
431 restarts = 0;
433 }
434 Err(err) => {
435 let action = supervisor
436 .as_ref()
437 .map_or(SupervisorAction::Escalate, |s| {
438 s.on_agent_error(&agent_name, &err)
439 });
440
441 match action {
442 SupervisorAction::Restart if restarts < max_restarts => {
443 warn!(
444 agent = %agent_name,
445 restart = restarts + 1,
446 max = max_restarts,
447 "supervisor restarting agent"
448 );
449 restarts += 1;
450 let _ = req.reply.send(Err(err));
451 agent = Agent::new(factory());
452 }
453 SupervisorAction::Escalate => {
454 let _ = req.reply.send(Err(err));
455 }
457 _ => {
458 let _ = req.reply.send(Err(err));
460 break Err(AgentError::plugin(
461 "orchestrator",
462 std::io::Error::other(format!(
463 "agent '{agent_name}' stopped by supervisor"
464 )),
465 ));
466 }
467 }
468 }
469 }
470 } else {
471 info!(agent = %agent_name, "request channel closed, shutting down");
473 break Ok(AgentResult {
474 messages: Vec::new(),
475 stop_reason: crate::types::StopReason::Stop,
476 usage: crate::types::Usage::default(),
477 cost: crate::types::Cost::default(),
478 error: None,
479 transfer_signal: None,
480 });
481 }
482 }
483 }
484 };
485
486 *status.lock().unwrap_or_else(PoisonError::into_inner) = resolve_status(&final_result);
487 final_result
488}
489
490impl Default for AgentOrchestrator {
491 fn default() -> Self {
492 Self::new()
493 }
494}
495
496impl std::fmt::Debug for AgentOrchestrator {
497 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498 f.debug_struct("AgentOrchestrator")
499 .field("agents", &self.entries.keys().collect::<Vec<_>>())
500 .field(
501 "supervisor",
502 &if self.supervisor.is_some() {
503 "Some"
504 } else {
505 "None"
506 },
507 )
508 .field("channel_buffer", &self.channel_buffer)
509 .finish_non_exhaustive()
510 }
511}
512
#[cfg(test)]
mod tests {
    use std::panic::AssertUnwindSafe;

    use super::*;

    #[test]
    fn add_agent_and_names() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("alpha", || panic!("not called"));
        orch.add_agent("beta", || panic!("not called"));

        let mut names = orch.names();
        names.sort_unstable();
        assert_eq!(names, vec!["alpha", "beta"]);
    }

    #[test]
    fn contains_registered() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("a", || panic!("not called"));
        assert!(orch.contains("a"));
        assert!(!orch.contains("b"));
    }

    #[test]
    fn parent_child_hierarchy() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("parent", || panic!("not called"));
        orch.add_child("child1", "parent", || panic!("not called"));
        orch.add_child("child2", "parent", || panic!("not called"));

        assert_eq!(orch.parent_of("child1"), Some("parent"));
        assert_eq!(orch.parent_of("child2"), Some("parent"));
        assert_eq!(orch.parent_of("parent"), None);

        let children = orch.children_of("parent").unwrap();
        assert_eq!(children, &["child1", "child2"]);
        assert!(orch.children_of("child1").unwrap().is_empty());
    }

    #[test]
    #[should_panic(expected = "parent agent 'missing' not registered")]
    fn add_child_missing_parent_panics() {
        let mut orch = AgentOrchestrator::new();
        orch.add_child("child", "missing", || panic!("not called"));
    }

    #[test]
    #[should_panic(expected = "agent 'alpha' already registered")]
    fn add_agent_duplicate_name_panics() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("alpha", || panic!("not called"));
        orch.add_agent("alpha", || panic!("not called"));
    }

    #[test]
    fn duplicate_child_registration_preserves_existing_hierarchy() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("parent1", || panic!("not called"));
        orch.add_agent("parent2", || panic!("not called"));
        orch.add_child("child", "parent1", || panic!("not called"));

        let duplicate = std::panic::catch_unwind(AssertUnwindSafe(|| {
            orch.add_child("child", "parent2", || panic!("not called"));
        }));

        assert!(duplicate.is_err());
        assert_eq!(orch.parent_of("child"), Some("parent1"));
        assert_eq!(orch.children_of("parent1").unwrap(), &["child"]);
        assert!(orch.children_of("parent2").unwrap().is_empty());
    }

    #[test]
    fn duplicate_top_level_registration_preserves_child_link() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("parent", || panic!("not called"));
        orch.add_child("child", "parent", || panic!("not called"));

        let duplicate = std::panic::catch_unwind(AssertUnwindSafe(|| {
            orch.add_agent("child", || panic!("not called"));
        }));

        assert!(duplicate.is_err());
        assert_eq!(orch.parent_of("child"), Some("parent"));
        assert_eq!(orch.children_of("parent").unwrap(), &["child"]);
    }

    #[test]
    fn spawn_unregistered_agent_errors() {
        let orch = AgentOrchestrator::new();
        let result = orch.spawn("nonexistent");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(format!("{err}").contains("orchestrator"));
    }

    #[test]
    fn default_supervisor_retryable_restarts() {
        let supervisor = DefaultSupervisor::default();
        assert_eq!(supervisor.max_restarts(), 3);

        let retryable = AgentError::ModelThrottled;
        assert_eq!(
            supervisor.on_agent_error("test", &retryable),
            SupervisorAction::Restart
        );

        let non_retryable = AgentError::Aborted;
        assert_eq!(
            supervisor.on_agent_error("test", &non_retryable),
            SupervisorAction::Stop
        );
    }

    #[test]
    fn default_supervisor_new_keeps_configured_cap() {
        assert_eq!(DefaultSupervisor::new(7).max_restarts(), 7);
        assert_eq!(DefaultSupervisor::new(0).max_restarts(), 0);
    }

    #[test]
    fn supervisor_action_variants() {
        assert_eq!(format!("{:?}", SupervisorAction::Restart), "Restart");
        assert_eq!(format!("{:?}", SupervisorAction::Stop), "Stop");
        assert_eq!(format!("{:?}", SupervisorAction::Escalate), "Escalate");
    }

    #[test]
    fn orchestrator_debug_format() {
        let orch = AgentOrchestrator::new();
        let debug = format!("{orch:?}");
        assert!(debug.contains("AgentOrchestrator"));
        assert!(debug.contains("channel_buffer"));
    }

    #[test]
    fn orchestrator_debug_reports_supervisor_presence() {
        let without = format!("{:?}", AgentOrchestrator::new());
        assert!(without.contains("supervisor: \"None\""));

        let with = format!(
            "{:?}",
            AgentOrchestrator::new().with_supervisor(DefaultSupervisor::default())
        );
        assert!(with.contains("supervisor: \"Some\""));
    }

    #[test]
    fn with_supervisor_sets_policy() {
        let orch = AgentOrchestrator::new().with_supervisor(DefaultSupervisor::default());
        assert!(orch.supervisor.is_some());
    }

    #[test]
    fn with_channel_buffer_sets_size() {
        let orch = AgentOrchestrator::new().with_channel_buffer(64);
        assert_eq!(orch.channel_buffer, 64);
    }

    #[test]
    fn with_max_restarts_sets_default() {
        let mut orch = AgentOrchestrator::new().with_max_restarts(5);
        orch.add_agent("a", || panic!("not called"));
        assert_eq!(orch.entries["a"].max_restarts, 5);
    }

    #[test]
    fn max_restarts_applies_only_to_later_registrations() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("early", || panic!("not called"));
        let mut orch = orch.with_max_restarts(9);
        orch.add_agent("late", || panic!("not called"));

        assert_eq!(orch.entries["early"].max_restarts, 3);
        assert_eq!(orch.entries["late"].max_restarts, 9);
    }

    #[test]
    fn default_impl() {
        let orch = AgentOrchestrator::default();
        assert!(orch.entries.is_empty());
        assert!(orch.supervisor.is_none());
    }

    #[test]
    fn lookups_for_unknown_agent_return_none() {
        let orch = AgentOrchestrator::new();
        assert_eq!(orch.parent_of("ghost"), None);
        assert!(orch.children_of("ghost").is_none());
        assert!(!orch.contains("ghost"));
        assert!(orch.names().is_empty());
    }

    #[test]
    fn custom_supervisor_policy() {
        struct AlwaysEscalate;
        impl SupervisorPolicy for AlwaysEscalate {
            fn on_agent_error(&self, _name: &str, _error: &AgentError) -> SupervisorAction {
                SupervisorAction::Escalate
            }
        }

        let supervisor = AlwaysEscalate;
        assert_eq!(
            supervisor.on_agent_error("x", &AgentError::ModelThrottled),
            SupervisorAction::Escalate
        );
    }

    #[test]
    fn grandchild_hierarchy() {
        let mut orch = AgentOrchestrator::new();
        orch.add_agent("root", || panic!("not called"));
        orch.add_child("mid", "root", || panic!("not called"));
        orch.add_child("leaf", "mid", || panic!("not called"));

        assert_eq!(orch.parent_of("leaf"), Some("mid"));
        assert_eq!(orch.parent_of("mid"), Some("root"));
        assert_eq!(orch.parent_of("root"), None);

        assert_eq!(orch.children_of("root").unwrap(), &["mid"]);
        assert_eq!(orch.children_of("mid").unwrap(), &["leaf"]);
        assert!(orch.children_of("leaf").unwrap().is_empty());
    }
}
702}