Skip to main content

rustant_core/multi/
orchestrator.rs

//! Agent orchestrator — connects the message bus, spawner, routing, and task
//! handlers into a cohesive multi-agent execution engine.
//!
//! Each call to `AgentOrchestrator::process_pending` receives up to one pending
//! message per agent from the `MessageBus`, dispatches task requests to the
//! registered `TaskHandler` implementations, and returns results via the bus.
//! It enforces the `max_tool_calls` limit from each agent's `ResourceLimits`.
7
8use super::messaging::{AgentEnvelope, AgentPayload, MessageBus, MessagePriority};
9use super::routing::AgentRouter;
10use super::spawner::AgentSpawner;
11use async_trait::async_trait;
12use std::collections::HashMap;
13use uuid::Uuid;
14
15/// Trait for handling tasks dispatched by the orchestrator.
16///
17/// Implementations receive a task description and arguments, execute the work,
18/// and return a string result or an error message.
19#[async_trait]
20pub trait TaskHandler: Send + Sync {
21    async fn handle_task(
22        &self,
23        description: &str,
24        args: &HashMap<String, String>,
25    ) -> Result<String, String>;
26}
27
28/// The agent orchestrator ties together the spawner, message bus, router,
29/// and task handlers into a cohesive execution engine.
30pub struct AgentOrchestrator {
31    spawner: AgentSpawner,
32    bus: MessageBus,
33    router: AgentRouter,
34    handlers: HashMap<Uuid, Box<dyn TaskHandler>>,
35    tool_call_counts: HashMap<Uuid, u32>,
36}
37
38impl AgentOrchestrator {
39    /// Create a new orchestrator with the given components.
40    pub fn new(spawner: AgentSpawner, bus: MessageBus, router: AgentRouter) -> Self {
41        Self {
42            spawner,
43            bus,
44            router,
45            handlers: HashMap::new(),
46            tool_call_counts: HashMap::new(),
47        }
48    }
49
50    /// Register a task handler for a specific agent.
51    pub fn register_handler(&mut self, agent_id: Uuid, handler: Box<dyn TaskHandler>) {
52        self.handlers.insert(agent_id, handler);
53    }
54
55    /// Access the spawner.
56    pub fn spawner(&self) -> &AgentSpawner {
57        &self.spawner
58    }
59
60    /// Mutably access the spawner.
61    pub fn spawner_mut(&mut self) -> &mut AgentSpawner {
62        &mut self.spawner
63    }
64
65    /// Access the message bus.
66    pub fn bus(&self) -> &MessageBus {
67        &self.bus
68    }
69
70    /// Mutably access the message bus.
71    pub fn bus_mut(&mut self) -> &mut MessageBus {
72        &mut self.bus
73    }
74
75    /// Access the router.
76    pub fn router(&self) -> &AgentRouter {
77        &self.router
78    }
79
80    /// Mutably access the router.
81    pub fn router_mut(&mut self) -> &mut AgentRouter {
82        &mut self.router
83    }
84
85    /// Get the current tool call count for an agent.
86    pub fn tool_call_count(&self, agent_id: &Uuid) -> u32 {
87        self.tool_call_counts.get(agent_id).copied().unwrap_or(0)
88    }
89
90    /// Reset tool call counts for an agent (e.g., at the start of a new turn).
91    pub fn reset_tool_counts(&mut self, agent_id: &Uuid) {
92        self.tool_call_counts.insert(*agent_id, 0);
93    }
94
95    /// Check whether processing a task would violate the agent's resource limits.
96    ///
97    /// Returns `Ok(())` if within limits, or `Err(reason)` if a limit would be exceeded.
98    pub fn check_resource_limits(&self, agent_id: &Uuid) -> Result<(), String> {
99        let limits = self
100            .spawner
101            .get(agent_id)
102            .map(|ctx| &ctx.resource_limits)
103            .cloned()
104            .unwrap_or_default();
105
106        // Check tool call limit
107        if let Some(max_calls) = limits.max_tool_calls {
108            let current = self.tool_call_count(agent_id);
109            if current >= max_calls {
110                return Err(format!(
111                    "Agent {} exceeded max_tool_calls limit ({}/{})",
112                    agent_id, current, max_calls
113                ));
114            }
115        }
116
117        Ok(())
118    }
119
120    /// Process all pending messages for all registered agents.
121    ///
122    /// For each agent with pending messages:
123    /// 1. Check resource limits
124    /// 2. Receive the message
125    /// 3. Dispatch to the registered handler (for TaskRequest)
126    /// 4. Send the result back via the bus
127    ///
128    /// Returns the number of messages processed.
129    pub async fn process_pending(&mut self) -> usize {
130        // Collect agent IDs that have pending messages and handlers
131        let agent_ids: Vec<Uuid> = self
132            .handlers
133            .keys()
134            .filter(|id| self.bus.pending_count(id) > 0)
135            .copied()
136            .collect();
137
138        let mut processed = 0;
139
140        for agent_id in agent_ids {
141            // Check resource limits before processing
142            if let Err(reason) = self.check_resource_limits(&agent_id) {
143                // Send an error back if there's a pending message
144                if let Some(envelope) = self.bus.receive(&agent_id) {
145                    let error_response = AgentEnvelope::new(
146                        agent_id,
147                        envelope.from,
148                        AgentPayload::Error {
149                            code: "RESOURCE_LIMIT".into(),
150                            message: reason,
151                            recoverable: false,
152                        },
153                    )
154                    .with_priority(MessagePriority::High);
155                    if let Some(corr) = envelope.correlation_id {
156                        let error_response = error_response.with_correlation(corr);
157                        let _ = self.bus.send(error_response);
158                    } else {
159                        let _ = self.bus.send(error_response);
160                    }
161                    processed += 1;
162                }
163                continue;
164            }
165
166            // Receive the next message
167            let envelope = match self.bus.receive(&agent_id) {
168                Some(e) => e,
169                None => continue,
170            };
171
172            match &envelope.payload {
173                AgentPayload::TaskRequest { description, args } => {
174                    // Increment tool call count
175                    *self.tool_call_counts.entry(agent_id).or_insert(0) += 1;
176
177                    let handler = match self.handlers.get(&agent_id) {
178                        Some(h) => h,
179                        None => continue,
180                    };
181
182                    let result = handler.handle_task(description, args).await;
183
184                    let response_payload = match result {
185                        Ok(output) => AgentPayload::TaskResult {
186                            success: true,
187                            output,
188                        },
189                        Err(err) => AgentPayload::TaskResult {
190                            success: false,
191                            output: err,
192                        },
193                    };
194
195                    let mut response =
196                        AgentEnvelope::new(agent_id, envelope.from, response_payload);
197                    if let Some(corr) = envelope.correlation_id {
198                        response = response.with_correlation(corr);
199                    }
200                    let _ = self.bus.send(response);
201                    processed += 1;
202                }
203                AgentPayload::Shutdown => {
204                    // Terminate the agent and its children
205                    self.spawner.terminate(agent_id);
206                    self.handlers.remove(&agent_id);
207                    self.tool_call_counts.remove(&agent_id);
208                    processed += 1;
209                }
210                AgentPayload::StatusQuery => {
211                    let pending = self.bus.pending_count(&agent_id);
212                    let agent_name = self
213                        .spawner
214                        .get(&agent_id)
215                        .map(|ctx| ctx.name.clone())
216                        .unwrap_or_else(|| "unknown".to_string());
217                    let response = AgentEnvelope::new(
218                        agent_id,
219                        envelope.from,
220                        AgentPayload::StatusResponse {
221                            agent_name,
222                            active: true,
223                            pending_tasks: pending,
224                        },
225                    );
226                    let _ = self.bus.send(response);
227                    processed += 1;
228                }
229                _ => {
230                    // Other payload types are forwarded as-is (no special handling)
231                    processed += 1;
232                }
233            }
234        }
235
236        processed
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi::spawner::SpawnerConfig;

    /// Handler that succeeds and echoes the task description back.
    struct EchoHandler;

    #[async_trait]
    impl TaskHandler for EchoHandler {
        async fn handle_task(
            &self,
            description: &str,
            _args: &HashMap<String, String>,
        ) -> Result<String, String> {
            Ok(format!("echo: {}", description))
        }
    }

    /// Handler that always fails with the message "task failed".
    struct FailHandler;

    #[async_trait]
    impl TaskHandler for FailHandler {
        async fn handle_task(
            &self,
            _description: &str,
            _args: &HashMap<String, String>,
        ) -> Result<String, String> {
            Err("task failed".to_string())
        }
    }

    /// Build an orchestrator over `spawner` whose bus (capacity 100) has every id
    /// in `registered` registered, with a fresh router and no handlers.
    fn build_orchestrator(spawner: AgentSpawner, registered: &[Uuid]) -> AgentOrchestrator {
        let mut bus = MessageBus::new(100);
        for &id in registered {
            bus.register(id);
        }
        AgentOrchestrator::new(spawner, bus, AgentRouter::new())
    }

    /// Orchestrator with a single "test-agent" that echoes tasks back.
    fn setup_orchestrator() -> (AgentOrchestrator, Uuid) {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("test-agent").unwrap();
        let mut orch = build_orchestrator(spawner, &[agent_id]);
        orch.register_handler(agent_id, Box::new(EchoHandler));
        (orch, agent_id)
    }

    /// Spawn a "sender" agent and register it on the bus so it can receive replies.
    fn add_sender(orch: &mut AgentOrchestrator) -> Uuid {
        let sender_id = orch.spawner_mut().spawn("sender").unwrap();
        orch.bus_mut().register(sender_id);
        sender_id
    }

    /// A `TaskRequest` envelope from `from` to `to` with no arguments.
    fn task_request(from: Uuid, to: Uuid, description: &str) -> AgentEnvelope {
        AgentEnvelope::new(
            from,
            to,
            AgentPayload::TaskRequest {
                description: description.to_string(),
                args: HashMap::new(),
            },
        )
    }

    /// Extract `(success, output)` from a `TaskResult`, panicking on any other payload.
    fn task_result(envelope: &AgentEnvelope) -> (bool, &str) {
        match &envelope.payload {
            AgentPayload::TaskResult { success, output } => (*success, output.as_str()),
            other => panic!(
                "Expected TaskResult, got {:?}",
                std::mem::discriminant(other)
            ),
        }
    }

    #[tokio::test]
    async fn test_orchestrator_processes_task_request() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = add_sender(&mut orch);

        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "hello world"))
            .unwrap();
        assert_eq!(orch.process_pending().await, 1);

        // The echoed result comes back to the sender.
        let reply = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(task_result(&reply), (true, "echo: hello world"));
    }

    #[tokio::test]
    async fn test_orchestrator_handles_task_failure() {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("fail-agent").unwrap();
        let sender_id = spawner.spawn("sender").unwrap();
        let mut orch = build_orchestrator(spawner, &[agent_id, sender_id]);
        orch.register_handler(agent_id, Box::new(FailHandler));

        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "will fail"))
            .unwrap();
        orch.process_pending().await;

        let reply = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(task_result(&reply), (false, "task failed"));
    }

    #[tokio::test]
    async fn test_orchestrator_correlation_id_preserved() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = add_sender(&mut orch);

        let corr_id = Uuid::new_v4();
        let request = task_request(sender_id, agent_id, "correlated").with_correlation(corr_id);
        orch.bus_mut().send(request).unwrap();
        orch.process_pending().await;

        let reply = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(reply.correlation_id, Some(corr_id));
    }

    #[tokio::test]
    async fn test_orchestrator_handles_shutdown() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = add_sender(&mut orch);

        orch.bus_mut()
            .send(AgentEnvelope::new(sender_id, agent_id, AgentPayload::Shutdown))
            .unwrap();
        assert_eq!(orch.process_pending().await, 1);

        // The spawner no longer knows the agent.
        assert!(orch.spawner().get(&agent_id).is_none());
    }

    #[tokio::test]
    async fn test_orchestrator_handles_status_query() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = add_sender(&mut orch);

        orch.bus_mut()
            .send(AgentEnvelope::new(sender_id, agent_id, AgentPayload::StatusQuery))
            .unwrap();
        orch.process_pending().await;

        let reply = orch.bus_mut().receive(&sender_id).unwrap();
        if let AgentPayload::StatusResponse {
            agent_name,
            active,
            pending_tasks,
        } = &reply.payload
        {
            assert_eq!(agent_name, "test-agent");
            assert!(active);
            assert_eq!(*pending_tasks, 0);
        } else {
            panic!("Expected StatusResponse");
        }
    }

    #[tokio::test]
    async fn test_orchestrator_respects_tool_call_limit() {
        let mut spawner = AgentSpawner::new(SpawnerConfig::default());
        let agent_id = spawner.spawn("limited-agent").unwrap();
        let sender_id = spawner.spawn("sender").unwrap();

        // Allow the agent exactly two tool calls.
        if let Some(ctx) = spawner.get_mut(&agent_id) {
            ctx.resource_limits.max_tool_calls = Some(2);
        }

        let mut orch = build_orchestrator(spawner, &[agent_id, sender_id]);
        orch.register_handler(agent_id, Box::new(EchoHandler));

        // Tasks are sent and processed one at a time; the first two fit the budget.
        for n in 0..2 {
            let description = format!("task-{}", n);
            orch.bus_mut()
                .send(task_request(sender_id, agent_id, &description))
                .unwrap();
            orch.process_pending().await;

            let reply = orch.bus_mut().receive(&sender_id).unwrap();
            let (success, _) = task_result(&reply);
            assert!(success);
        }

        // The third task hits the limit (count = 2, max = 2).
        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "task-2"))
            .unwrap();
        orch.process_pending().await;

        let reply = orch.bus_mut().receive(&sender_id).unwrap();
        match &reply.payload {
            AgentPayload::Error {
                code, recoverable, ..
            } => {
                assert_eq!(code, "RESOURCE_LIMIT");
                assert!(!recoverable);
            }
            other => panic!(
                "Expected Error for third task, got {:?}",
                std::mem::discriminant(other)
            ),
        }
    }

    #[test]
    fn test_tool_call_count_tracking() {
        let mut orch = build_orchestrator(AgentSpawner::default(), &[]);
        let agent_id = Uuid::new_v4();

        // Unknown agents start at zero.
        assert_eq!(orch.tool_call_count(&agent_id), 0);

        orch.tool_call_counts.insert(agent_id, 5);
        assert_eq!(orch.tool_call_count(&agent_id), 5);

        orch.reset_tool_counts(&agent_id);
        assert_eq!(orch.tool_call_count(&agent_id), 0);
    }

    #[tokio::test]
    async fn test_orchestrator_no_pending_returns_zero() {
        let (mut orch, _) = setup_orchestrator();
        assert_eq!(orch.process_pending().await, 0);
    }

    #[tokio::test]
    async fn test_orchestrator_parent_delegates_to_child() {
        let mut spawner = AgentSpawner::default();
        let parent_id = spawner.spawn("parent").unwrap();
        let child_id = spawner.spawn_child("child", parent_id).unwrap();

        let mut orch = build_orchestrator(spawner, &[parent_id, child_id]);
        orch.register_handler(child_id, Box::new(EchoHandler));

        // The parent delegates a task to its child...
        orch.bus_mut()
            .send(task_request(parent_id, child_id, "delegated task"))
            .unwrap();
        orch.process_pending().await;

        // ...and receives the child's result.
        let reply = orch.bus_mut().receive(&parent_id).unwrap();
        assert_eq!(task_result(&reply), (true, "echo: delegated task"));
    }

    #[tokio::test]
    async fn test_orchestrator_multiple_agents() {
        let mut spawner = AgentSpawner::default();
        let agent_a = spawner.spawn("agent-a").unwrap();
        let agent_b = spawner.spawn("agent-b").unwrap();
        let coordinator = spawner.spawn("coordinator").unwrap();

        let mut orch = build_orchestrator(spawner, &[agent_a, agent_b, coordinator]);
        orch.register_handler(agent_a, Box::new(EchoHandler));
        orch.register_handler(agent_b, Box::new(EchoHandler));

        // The coordinator fans one task out to each agent.
        orch.bus_mut()
            .send(task_request(coordinator, agent_a, "task-for-a"))
            .unwrap();
        orch.bus_mut()
            .send(task_request(coordinator, agent_b, "task-for-b"))
            .unwrap();

        assert_eq!(orch.process_pending().await, 2);

        // Both replies arrive, in no guaranteed order.
        let mut outputs: Vec<String> = Vec::new();
        for _ in 0..2 {
            let reply = orch.bus_mut().receive(&coordinator).unwrap();
            if let AgentPayload::TaskResult { output, .. } = &reply.payload {
                outputs.push(output.clone());
            }
        }
        outputs.sort();
        assert_eq!(outputs, vec!["echo: task-for-a", "echo: task-for-b"]);
    }

    #[test]
    fn test_check_resource_limits_no_limits() {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("no-limits").unwrap();
        let orch = build_orchestrator(spawner, &[]);
        assert!(orch.check_resource_limits(&agent_id).is_ok());
    }

    #[test]
    fn test_check_resource_limits_exceeded() {
        let mut spawner = AgentSpawner::new(SpawnerConfig::default());
        let agent_id = spawner.spawn("limited").unwrap();
        if let Some(ctx) = spawner.get_mut(&agent_id) {
            ctx.resource_limits.max_tool_calls = Some(3);
        }

        let mut orch = build_orchestrator(spawner, &[]);

        // Pretend the agent has already made its three allowed tool calls.
        orch.tool_call_counts.insert(agent_id, 3);

        let result = orch.check_resource_limits(&agent_id);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max_tool_calls"));
    }
}