Skip to main content

swink_agent/
orchestrator.rs

1//! Multi-agent orchestration with parent/child hierarchies and supervision.
2//!
3//! [`AgentOrchestrator`] manages a set of named agents, tracks parent/child
4//! relationships, and applies a [`SupervisorPolicy`] when agents fail. Each
5//! spawned agent is represented by an [`OrchestratedHandle`] that supports
6//! request/response messaging, result retrieval, and cancellation.
7
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex, PoisonError};
10
11use tokio::sync::{mpsc, oneshot};
12use tokio_util::sync::CancellationToken;
13use tracing::{info, warn};
14
15use crate::agent::{Agent, AgentOptions};
16use crate::error::AgentError;
17use crate::handle::AgentStatus;
18use crate::task_core::{TaskCore, resolve_status};
19use crate::types::{AgentMessage, AgentResult, ContentBlock, LlmMessage, UserMessage};
20use crate::util::now_timestamp;
21
22// ─── Type aliases ───────────────────────────────────────────────────────────
23
24type OptionsFactoryArc = Arc<dyn Fn() -> AgentOptions + Send + Sync>;
25
26// ─── Request / Response channel ─────────────────────────────────────────────
27
28/// A message sent to a running agent via its request channel.
29pub struct AgentRequest {
30    /// The messages to inject into the agent.
31    pub messages: Vec<AgentMessage>,
32    /// A one-shot channel for the agent's response.
33    pub reply: oneshot::Sender<Result<AgentResult, AgentError>>,
34}
35
36// ─── Supervisor ─────────────────────────────────────────────────────────────
37
/// The recovery decision a [`SupervisorPolicy`] returns after a request fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SupervisorAction {
    /// Replace the agent with a fresh instance built by calling its options
    /// factory again.
    ///
    /// The failed request still receives its error; the restart only affects
    /// later requests. Once the agent's restart budget is exhausted this is
    /// handled like [`SupervisorAction::Stop`].
    Restart,
    /// Stop the agent permanently; its background task ends with an error.
    Stop,
    /// Return the error to the requester but keep the current agent running.
    Escalate,
}
48
49/// Policy that determines how to handle agent failures.
50///
51/// Implement this trait and pass it to
52/// [`AgentOrchestrator::with_supervisor`] to customise recovery behaviour.
53pub trait SupervisorPolicy: Send + Sync {
54    /// Called when a spawned agent terminates with an error.
55    fn on_agent_error(&self, name: &str, error: &AgentError) -> SupervisorAction;
56}
57
/// A supervisor that restarts on retryable errors and stops otherwise.
///
/// Note: in this module, [`AgentOrchestrator`] does not consult
/// [`DefaultSupervisor::max_restarts`]. The restart budget it actually enforces
/// is configured with [`AgentOrchestrator::with_max_restarts`]. The value
/// stored here is only reported back through the getter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DefaultSupervisor {
    max_restarts: u32,
}

impl DefaultSupervisor {
    /// Limit used by [`Default`]; matches the orchestrator's default budget.
    const DEFAULT_MAX_RESTARTS: u32 = 3;

    /// Create a supervisor that records `max_restarts` as its restart limit.
    #[must_use]
    pub const fn new(max_restarts: u32) -> Self {
        Self { max_restarts }
    }

    /// The restart limit this supervisor was created with.
    #[must_use]
    pub const fn max_restarts(&self) -> u32 {
        self.max_restarts
    }
}

impl Default for DefaultSupervisor {
    /// Equivalent to `DefaultSupervisor::new(3)`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_MAX_RESTARTS)
    }
}
83
84impl SupervisorPolicy for DefaultSupervisor {
85    fn on_agent_error(&self, _name: &str, error: &AgentError) -> SupervisorAction {
86        if error.is_retryable() {
87            SupervisorAction::Restart
88        } else {
89            SupervisorAction::Stop
90        }
91    }
92}
93
94// ─── Agent entry (internal bookkeeping) ─────────────────────────────────────
95
96/// Registration info stored in the orchestrator for each agent.
97struct AgentEntry {
98    /// Factory that produces fresh `AgentOptions` for (re)spawning.
99    options_factory: OptionsFactoryArc,
100    /// Parent agent name, if this is a child.
101    parent: Option<String>,
102    /// Child agent names.
103    children: Vec<String>,
104    /// Max restarts allowed by the supervisor (per spawn cycle).
105    max_restarts: u32,
106}
107
108// ─── OrchestratedHandle ─────────────────────────────────────────────────────
109
110/// Handle to a spawned orchestrated agent.
111///
112/// Provides request/response messaging, status polling, and cancellation.
113/// Lifecycle methods (status, cancel, `is_done`) are delegated to a shared task core.
114pub struct OrchestratedHandle {
115    name: String,
116    request_tx: mpsc::Sender<AgentRequest>,
117    core: TaskCore,
118}
119
120impl OrchestratedHandle {
121    /// The name of the agent this handle refers to.
122    #[must_use]
123    pub fn name(&self) -> &str {
124        &self.name
125    }
126
127    /// Send a text message to the running agent and await its response.
128    pub async fn send_message(&self, text: impl Into<String>) -> Result<AgentResult, AgentError> {
129        let msg = AgentMessage::Llm(LlmMessage::User(UserMessage {
130            content: vec![ContentBlock::Text { text: text.into() }],
131            timestamp: now_timestamp(),
132            cache_hint: None,
133        }));
134        self.send_messages(vec![msg]).await
135    }
136
137    /// Send multiple messages to the agent and await its response.
138    pub async fn send_messages(
139        &self,
140        messages: Vec<AgentMessage>,
141    ) -> Result<AgentResult, AgentError> {
142        let (reply_tx, reply_rx) = oneshot::channel();
143        let request = AgentRequest {
144            messages,
145            reply: reply_tx,
146        };
147        self.request_tx.send(request).await.map_err(|_| {
148            AgentError::plugin(
149                "orchestrator",
150                std::io::Error::other("agent channel closed"),
151            )
152        })?;
153
154        reply_rx.await.map_err(|_| {
155            AgentError::plugin("orchestrator", std::io::Error::other("agent reply dropped"))
156        })?
157    }
158
159    /// Consume the handle and await the agent's final result.
160    ///
161    /// Drops the request channel so the agent shuts down after processing
162    /// any remaining requests.
163    pub async fn await_result(self) -> Result<AgentResult, AgentError> {
164        drop(self.request_tx);
165        self.core.result().await
166    }
167
168    /// Cancel the agent.
169    pub fn cancel(&self) {
170        self.core.cancel();
171    }
172
173    /// Current status of the agent.
174    pub fn status(&self) -> AgentStatus {
175        self.core.status()
176    }
177
178    /// Whether the agent has finished (completed, failed, or cancelled).
179    pub fn is_done(&self) -> bool {
180        self.core.is_done()
181    }
182}
183
184impl std::fmt::Debug for OrchestratedHandle {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.debug_struct("OrchestratedHandle")
187            .field("name", &self.name)
188            .field("status", &self.status())
189            .finish_non_exhaustive()
190    }
191}
192
193// ─── AgentOrchestrator ──────────────────────────────────────────────────────
194
195/// Manages a set of named agents with parent/child hierarchies and supervision.
196///
197/// # Usage
198///
199/// ```ignore
200/// let mut orchestrator = AgentOrchestrator::new();
201/// orchestrator.add_agent("planner", || planner_options());
202/// orchestrator.add_child("researcher", "planner", || researcher_options());
203///
204/// let handle = orchestrator.spawn("planner")?;
205/// let result = handle.send_message("Plan a trip to Paris").await?;
206/// ```
207pub struct AgentOrchestrator {
208    entries: HashMap<String, AgentEntry>,
209    supervisor: Option<Arc<dyn SupervisorPolicy>>,
210    /// Channel buffer size for request channels.
211    channel_buffer: usize,
212    /// Default max restarts for agents (used when supervisor is set).
213    default_max_restarts: u32,
214}
215
216impl AgentOrchestrator {
217    /// Create a new empty orchestrator.
218    #[must_use]
219    pub fn new() -> Self {
220        Self {
221            entries: HashMap::new(),
222            supervisor: None,
223            channel_buffer: 32,
224            default_max_restarts: 3,
225        }
226    }
227
228    /// Set a supervisor policy for error recovery.
229    #[must_use]
230    pub fn with_supervisor(mut self, policy: impl SupervisorPolicy + 'static) -> Self {
231        self.supervisor = Some(Arc::new(policy));
232        self
233    }
234
235    /// Set the request channel buffer size (default: 32).
236    #[must_use]
237    pub const fn with_channel_buffer(mut self, size: usize) -> Self {
238        self.channel_buffer = size;
239        self
240    }
241
242    /// Set the default max restarts for supervised agents (default: 3).
243    #[must_use]
244    pub const fn with_max_restarts(mut self, max: u32) -> Self {
245        self.default_max_restarts = max;
246        self
247    }
248
249    /// Register an agent by name with a factory that produces its options.
250    ///
251    /// The factory is called each time the agent is spawned or restarted.
252    ///
253    /// # Panics
254    ///
255    /// Panics if an agent with the same name has already been registered.
256    pub fn add_agent(
257        &mut self,
258        name: impl Into<String>,
259        options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
260    ) {
261        let name = name.into();
262        assert!(
263            !self.entries.contains_key(&name),
264            "agent '{name}' already registered"
265        );
266        self.entries.insert(
267            name,
268            AgentEntry {
269                options_factory: Arc::new(options_factory),
270                parent: None,
271                children: Vec::new(),
272                max_restarts: self.default_max_restarts,
273            },
274        );
275    }
276
277    /// Register a child agent under the given parent.
278    ///
279    /// # Panics
280    ///
281    /// Panics if the parent agent has not been registered or if the child name
282    /// is already registered.
283    pub fn add_child(
284        &mut self,
285        name: impl Into<String>,
286        parent: impl Into<String>,
287        options_factory: impl Fn() -> AgentOptions + Send + Sync + 'static,
288    ) {
289        let name = name.into();
290        let parent = parent.into();
291        assert!(
292            self.entries.contains_key(&parent),
293            "parent agent '{parent}' not registered"
294        );
295        assert!(
296            !self.entries.contains_key(&name),
297            "agent '{name}' already registered"
298        );
299
300        self.entries
301            .get_mut(&parent)
302            .expect("parent checked above")
303            .children
304            .push(name.clone());
305
306        self.entries.insert(
307            name,
308            AgentEntry {
309                options_factory: Arc::new(options_factory),
310                parent: Some(parent),
311                children: Vec::new(),
312                max_restarts: self.default_max_restarts,
313            },
314        );
315    }
316
317    /// Get the parent name for a registered agent.
318    #[must_use]
319    pub fn parent_of(&self, name: &str) -> Option<&str> {
320        self.entries.get(name).and_then(|e| e.parent.as_deref())
321    }
322
323    /// Get the child names for a registered agent.
324    #[must_use]
325    pub fn children_of(&self, name: &str) -> Option<&[String]> {
326        self.entries.get(name).map(|e| e.children.as_slice())
327    }
328
329    /// List all registered agent names.
330    #[must_use]
331    pub fn names(&self) -> Vec<&str> {
332        self.entries.keys().map(String::as_str).collect()
333    }
334
335    /// Whether an agent with this name is registered.
336    #[must_use]
337    pub fn contains(&self, name: &str) -> bool {
338        self.entries.contains_key(name)
339    }
340
341    /// Spawn a registered agent, returning a handle for interaction.
342    ///
343    /// The agent runs in a background tokio task listening for requests. Each
344    /// request triggers `prompt_async` and the result is sent via a one-shot
345    /// reply channel.
346    ///
347    /// If a [`SupervisorPolicy`] is set, the agent is automatically restarted
348    /// when the supervisor returns [`SupervisorAction::Restart`].
349    ///
350    /// # Errors
351    ///
352    /// Returns [`AgentError::Plugin`] if the agent name is not registered.
353    pub fn spawn(&self, name: &str) -> Result<OrchestratedHandle, AgentError> {
354        let entry = self.entries.get(name).ok_or_else(|| {
355            AgentError::plugin(
356                "orchestrator",
357                std::io::Error::other(format!("agent not registered: {name}")),
358            )
359        })?;
360
361        let factory = Arc::clone(&entry.options_factory);
362        let max_restarts = entry.max_restarts;
363        let agent_name = name.to_owned();
364        let supervisor = self.supervisor.clone();
365
366        let (request_tx, request_rx) = mpsc::channel::<AgentRequest>(self.channel_buffer);
367        let cancellation_token = CancellationToken::new();
368        let status = Arc::new(Mutex::new(AgentStatus::Running));
369
370        let status_clone = Arc::clone(&status);
371        let token_clone = cancellation_token.clone();
372
373        let join_handle = tokio::spawn(run_agent_loop(
374            agent_name,
375            factory,
376            request_rx,
377            token_clone,
378            status_clone,
379            supervisor,
380            max_restarts,
381        ));
382
383        Ok(OrchestratedHandle {
384            name: name.to_owned(),
385            request_tx,
386            core: TaskCore::new(join_handle, cancellation_token, status),
387        })
388    }
389}
390
391/// The core agent loop that runs inside a spawned tokio task.
392///
393/// Receives requests on the channel, processes them with the agent, and
394/// optionally restarts the agent on failure per the supervisor policy.
395async fn run_agent_loop(
396    agent_name: String,
397    factory: OptionsFactoryArc,
398    mut request_rx: mpsc::Receiver<AgentRequest>,
399    cancellation_token: CancellationToken,
400    status: Arc<Mutex<AgentStatus>>,
401    supervisor: Option<Arc<dyn SupervisorPolicy>>,
402    max_restarts: u32,
403) -> Result<AgentResult, AgentError> {
404    let mut agent = Agent::new(factory());
405    let mut restarts: u32 = 0;
406
407    let final_result = loop {
408        tokio::select! {
409            biased;
410
411            () = cancellation_token.cancelled() => {
412                agent.abort();
413                break Err(AgentError::Aborted);
414            }
415
416            maybe_req = request_rx.recv() => {
417                if let Some(req) = maybe_req {
418                    let result = tokio::select! {
419                        biased;
420                        () = cancellation_token.cancelled() => {
421                            agent.abort();
422                            let _ = req.reply.send(Err(AgentError::Aborted));
423                            break Err(AgentError::Aborted);
424                        }
425                        r = agent.prompt_async(req.messages) => r,
426                    };
427
428                    match result {
429                        Ok(r) => {
430                            let _ = req.reply.send(Ok(r));
431                            // Reset restart counter on success.
432                            restarts = 0;
433                        }
434                        Err(err) => {
435                            let action = supervisor
436                                .as_ref()
437                                .map_or(SupervisorAction::Escalate, |s| {
438                                    s.on_agent_error(&agent_name, &err)
439                                });
440
441                            match action {
442                                SupervisorAction::Restart if restarts < max_restarts => {
443                                    warn!(
444                                        agent = %agent_name,
445                                        restart = restarts + 1,
446                                        max = max_restarts,
447                                        "supervisor restarting agent"
448                                    );
449                                    restarts += 1;
450                                    let _ = req.reply.send(Err(err));
451                                    agent = Agent::new(factory());
452                                }
453                                SupervisorAction::Escalate => {
454                                    let _ = req.reply.send(Err(err));
455                                    // Agent stays alive.
456                                }
457                                _ => {
458                                    // Stop (or restart budget exhausted).
459                                    let _ = req.reply.send(Err(err));
460                                    break Err(AgentError::plugin(
461                                        "orchestrator",
462                                        std::io::Error::other(format!(
463                                            "agent '{agent_name}' stopped by supervisor"
464                                        )),
465                                    ));
466                                }
467                            }
468                        }
469                    }
470                } else {
471                    // Channel closed — clean shutdown.
472                    info!(agent = %agent_name, "request channel closed, shutting down");
473                    break Ok(AgentResult {
474                        messages: Vec::new(),
475                        stop_reason: crate::types::StopReason::Stop,
476                        usage: crate::types::Usage::default(),
477                        cost: crate::types::Cost::default(),
478                        error: None,
479                        transfer_signal: None,
480                    });
481                }
482            }
483        }
484    };
485
486    *status.lock().unwrap_or_else(PoisonError::into_inner) = resolve_status(&final_result);
487    final_result
488}
489
490impl Default for AgentOrchestrator {
491    fn default() -> Self {
492        Self::new()
493    }
494}
495
496impl std::fmt::Debug for AgentOrchestrator {
497    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
498        f.debug_struct("AgentOrchestrator")
499            .field("agents", &self.entries.keys().collect::<Vec<_>>())
500            .field(
501                "supervisor",
502                &if self.supervisor.is_some() {
503                    "Some"
504                } else {
505                    "None"
506                },
507            )
508            .field("channel_buffer", &self.channel_buffer)
509            .finish_non_exhaustive()
510    }
511}
512
513// ─── Tests ──────────────────────────────────────────────────────────────────
514
#[cfg(test)]
mod tests {
    use std::panic::{AssertUnwindSafe, catch_unwind};

    use super::*;

    /// Options factory for registrations that must never actually be spawned.
    fn never_called() -> AgentOptions {
        panic!("not called")
    }

    #[test]
    fn add_agent_and_names() {
        let mut orchestrator = AgentOrchestrator::new();
        for name in ["alpha", "beta"] {
            orchestrator.add_agent(name, never_called);
        }

        let mut registered = orchestrator.names();
        registered.sort_unstable();
        assert_eq!(registered, ["alpha", "beta"]);
    }

    #[test]
    fn contains_registered() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_agent("a", never_called);
        assert!(orchestrator.contains("a"));
        assert!(!orchestrator.contains("b"));
    }

    #[test]
    fn parent_child_hierarchy() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_agent("parent", never_called);
        for child in ["child1", "child2"] {
            orchestrator.add_child(child, "parent", never_called);
        }

        for child in ["child1", "child2"] {
            assert_eq!(orchestrator.parent_of(child), Some("parent"));
        }
        assert_eq!(orchestrator.parent_of("parent"), None);

        let children = orchestrator.children_of("parent").unwrap();
        assert_eq!(children, &["child1", "child2"]);
        assert!(orchestrator.children_of("child1").unwrap().is_empty());
    }

    #[test]
    #[should_panic(expected = "parent agent 'missing' not registered")]
    fn add_child_missing_parent_panics() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_child("child", "missing", never_called);
    }

    #[test]
    #[should_panic(expected = "agent 'alpha' already registered")]
    fn add_agent_duplicate_name_panics() {
        let mut orchestrator = AgentOrchestrator::new();
        for _ in 0..2 {
            orchestrator.add_agent("alpha", never_called);
        }
    }

    #[test]
    fn duplicate_child_registration_preserves_existing_hierarchy() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_agent("parent1", never_called);
        orchestrator.add_agent("parent2", never_called);
        orchestrator.add_child("child", "parent1", never_called);

        let outcome = catch_unwind(AssertUnwindSafe(|| {
            orchestrator.add_child("child", "parent2", never_called);
        }));

        assert!(outcome.is_err());
        assert_eq!(orchestrator.parent_of("child"), Some("parent1"));
        assert_eq!(orchestrator.children_of("parent1").unwrap(), &["child"]);
        assert!(orchestrator.children_of("parent2").unwrap().is_empty());
    }

    #[test]
    fn duplicate_top_level_registration_preserves_child_link() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_agent("parent", never_called);
        orchestrator.add_child("child", "parent", never_called);

        let outcome = catch_unwind(AssertUnwindSafe(|| {
            orchestrator.add_agent("child", never_called);
        }));

        assert!(outcome.is_err());
        assert_eq!(orchestrator.parent_of("child"), Some("parent"));
        assert_eq!(orchestrator.children_of("parent").unwrap(), &["child"]);
    }

    #[test]
    fn spawn_unregistered_agent_errors() {
        let orchestrator = AgentOrchestrator::new();
        let err = orchestrator
            .spawn("nonexistent")
            .expect_err("spawning an unregistered agent must fail");
        assert!(err.to_string().contains("orchestrator"));
    }

    #[test]
    fn default_supervisor_retryable_restarts() {
        let supervisor = DefaultSupervisor::default();
        assert_eq!(supervisor.max_restarts(), 3);

        let cases = [
            (AgentError::ModelThrottled, SupervisorAction::Restart),
            (AgentError::Aborted, SupervisorAction::Stop),
        ];
        for (error, expected) in &cases {
            assert_eq!(supervisor.on_agent_error("test", error), *expected);
        }
    }

    #[test]
    fn supervisor_action_variants() {
        let expected = [
            (SupervisorAction::Restart, "Restart"),
            (SupervisorAction::Stop, "Stop"),
            (SupervisorAction::Escalate, "Escalate"),
        ];
        for (action, label) in expected {
            assert_eq!(format!("{action:?}"), label);
        }
    }

    #[test]
    fn orchestrator_debug_format() {
        let rendered = format!("{:?}", AgentOrchestrator::new());
        assert!(rendered.contains("AgentOrchestrator"));
        assert!(rendered.contains("channel_buffer"));
    }

    #[test]
    fn with_supervisor_sets_policy() {
        let orchestrator = AgentOrchestrator::new().with_supervisor(DefaultSupervisor::default());
        assert!(orchestrator.supervisor.is_some());
    }

    #[test]
    fn with_channel_buffer_sets_size() {
        let orchestrator = AgentOrchestrator::new().with_channel_buffer(64);
        assert_eq!(orchestrator.channel_buffer, 64);
    }

    #[test]
    fn with_max_restarts_sets_default() {
        let mut orchestrator = AgentOrchestrator::new().with_max_restarts(5);
        orchestrator.add_agent("a", never_called);
        assert_eq!(orchestrator.entries["a"].max_restarts, 5);
    }

    #[test]
    fn default_impl() {
        let orchestrator = AgentOrchestrator::default();
        assert!(orchestrator.entries.is_empty());
        assert!(orchestrator.supervisor.is_none());
    }

    #[test]
    fn custom_supervisor_policy() {
        struct AlwaysEscalate;

        impl SupervisorPolicy for AlwaysEscalate {
            fn on_agent_error(&self, _name: &str, _error: &AgentError) -> SupervisorAction {
                SupervisorAction::Escalate
            }
        }

        let decision = AlwaysEscalate.on_agent_error("x", &AgentError::ModelThrottled);
        assert_eq!(decision, SupervisorAction::Escalate);
    }

    #[test]
    fn grandchild_hierarchy() {
        let mut orchestrator = AgentOrchestrator::new();
        orchestrator.add_agent("root", never_called);
        orchestrator.add_child("mid", "root", never_called);
        orchestrator.add_child("leaf", "mid", never_called);

        let expected_parents = [("leaf", Some("mid")), ("mid", Some("root")), ("root", None)];
        for (name, parent) in expected_parents {
            assert_eq!(orchestrator.parent_of(name), parent);
        }

        assert_eq!(orchestrator.children_of("root").unwrap(), &["mid"]);
        assert_eq!(orchestrator.children_of("mid").unwrap(), &["leaf"]);
        assert!(orchestrator.children_of("leaf").unwrap().is_empty());
    }
}