1use super::messaging::{AgentEnvelope, AgentPayload, MessageBus, MessagePriority};
9use super::routing::AgentRouter;
10use super::spawner::AgentSpawner;
11use async_trait::async_trait;
12use std::collections::HashMap;
13use uuid::Uuid;
14
15#[async_trait]
20pub trait TaskHandler: Send + Sync {
21 async fn handle_task(
22 &self,
23 description: &str,
24 args: &HashMap<String, String>,
25 ) -> Result<String, String>;
26}
27
28pub struct AgentOrchestrator {
31 spawner: AgentSpawner,
32 bus: MessageBus,
33 router: AgentRouter,
34 handlers: HashMap<Uuid, Box<dyn TaskHandler>>,
35 tool_call_counts: HashMap<Uuid, u32>,
36}
37
38impl AgentOrchestrator {
39 pub fn new(spawner: AgentSpawner, bus: MessageBus, router: AgentRouter) -> Self {
41 Self {
42 spawner,
43 bus,
44 router,
45 handlers: HashMap::new(),
46 tool_call_counts: HashMap::new(),
47 }
48 }
49
50 pub fn register_handler(&mut self, agent_id: Uuid, handler: Box<dyn TaskHandler>) {
52 self.handlers.insert(agent_id, handler);
53 }
54
55 pub fn spawner(&self) -> &AgentSpawner {
57 &self.spawner
58 }
59
60 pub fn spawner_mut(&mut self) -> &mut AgentSpawner {
62 &mut self.spawner
63 }
64
65 pub fn bus(&self) -> &MessageBus {
67 &self.bus
68 }
69
70 pub fn bus_mut(&mut self) -> &mut MessageBus {
72 &mut self.bus
73 }
74
75 pub fn router(&self) -> &AgentRouter {
77 &self.router
78 }
79
80 pub fn router_mut(&mut self) -> &mut AgentRouter {
82 &mut self.router
83 }
84
85 pub fn tool_call_count(&self, agent_id: &Uuid) -> u32 {
87 self.tool_call_counts.get(agent_id).copied().unwrap_or(0)
88 }
89
90 pub fn reset_tool_counts(&mut self, agent_id: &Uuid) {
92 self.tool_call_counts.insert(*agent_id, 0);
93 }
94
95 pub fn check_resource_limits(&self, agent_id: &Uuid) -> Result<(), String> {
99 let limits = self
100 .spawner
101 .get(agent_id)
102 .map(|ctx| &ctx.resource_limits)
103 .cloned()
104 .unwrap_or_default();
105
106 if let Some(max_calls) = limits.max_tool_calls {
108 let current = self.tool_call_count(agent_id);
109 if current >= max_calls {
110 return Err(format!(
111 "Agent {} exceeded max_tool_calls limit ({}/{})",
112 agent_id, current, max_calls
113 ));
114 }
115 }
116
117 Ok(())
118 }
119
120 pub async fn process_pending(&mut self) -> usize {
130 let agent_ids: Vec<Uuid> = self
132 .handlers
133 .keys()
134 .filter(|id| self.bus.pending_count(id) > 0)
135 .copied()
136 .collect();
137
138 let mut processed = 0;
139
140 for agent_id in agent_ids {
141 if let Err(reason) = self.check_resource_limits(&agent_id) {
143 if let Some(envelope) = self.bus.receive(&agent_id) {
145 let error_response = AgentEnvelope::new(
146 agent_id,
147 envelope.from,
148 AgentPayload::Error {
149 code: "RESOURCE_LIMIT".into(),
150 message: reason,
151 recoverable: false,
152 },
153 )
154 .with_priority(MessagePriority::High);
155 if let Some(corr) = envelope.correlation_id {
156 let error_response = error_response.with_correlation(corr);
157 let _ = self.bus.send(error_response);
158 } else {
159 let _ = self.bus.send(error_response);
160 }
161 processed += 1;
162 }
163 continue;
164 }
165
166 let envelope = match self.bus.receive(&agent_id) {
168 Some(e) => e,
169 None => continue,
170 };
171
172 match &envelope.payload {
173 AgentPayload::TaskRequest { description, args } => {
174 *self.tool_call_counts.entry(agent_id).or_insert(0) += 1;
176
177 let handler = match self.handlers.get(&agent_id) {
178 Some(h) => h,
179 None => continue,
180 };
181
182 let result = handler.handle_task(description, args).await;
183
184 let response_payload = match result {
185 Ok(output) => AgentPayload::TaskResult {
186 success: true,
187 output,
188 },
189 Err(err) => AgentPayload::TaskResult {
190 success: false,
191 output: err,
192 },
193 };
194
195 let mut response =
196 AgentEnvelope::new(agent_id, envelope.from, response_payload);
197 if let Some(corr) = envelope.correlation_id {
198 response = response.with_correlation(corr);
199 }
200 let _ = self.bus.send(response);
201 processed += 1;
202 }
203 AgentPayload::Shutdown => {
204 self.spawner.terminate(agent_id);
206 self.handlers.remove(&agent_id);
207 self.tool_call_counts.remove(&agent_id);
208 processed += 1;
209 }
210 AgentPayload::StatusQuery => {
211 let pending = self.bus.pending_count(&agent_id);
212 let agent_name = self
213 .spawner
214 .get(&agent_id)
215 .map(|ctx| ctx.name.clone())
216 .unwrap_or_else(|| "unknown".to_string());
217 let response = AgentEnvelope::new(
218 agent_id,
219 envelope.from,
220 AgentPayload::StatusResponse {
221 agent_name,
222 active: true,
223 pending_tasks: pending,
224 },
225 );
226 let _ = self.bus.send(response);
227 processed += 1;
228 }
229 _ => {
230 processed += 1;
232 }
233 }
234 }
235
236 processed
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi::spawner::SpawnerConfig;

    /// Handler that succeeds with `"echo: <description>"`.
    struct EchoHandler;

    #[async_trait]
    impl TaskHandler for EchoHandler {
        async fn handle_task(
            &self,
            description: &str,
            _args: &HashMap<String, String>,
        ) -> Result<String, String> {
            Ok(format!("echo: {}", description))
        }
    }

    /// Handler that always fails with `"task failed"`.
    struct FailHandler;

    #[async_trait]
    impl TaskHandler for FailHandler {
        async fn handle_task(
            &self,
            _description: &str,
            _args: &HashMap<String, String>,
        ) -> Result<String, String> {
            Err("task failed".to_string())
        }
    }

    /// Builds an orchestrator with one spawned, bus-registered agent ("test-agent")
    /// served by `EchoHandler`, and returns it together with that agent's id.
    fn setup_orchestrator() -> (AgentOrchestrator, Uuid) {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("test-agent").unwrap();

        let mut bus = MessageBus::new(100);
        bus.register(agent_id);

        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);
        orch.register_handler(agent_id, Box::new(EchoHandler));

        (orch, agent_id)
    }

    /// Spawns a handler-less "sender" agent inside `orch` and registers its mailbox.
    fn spawn_sender(orch: &mut AgentOrchestrator) -> Uuid {
        let sender_id = orch.spawner_mut().spawn("sender").unwrap();
        orch.bus_mut().register(sender_id);
        sender_id
    }

    /// Builds a `TaskRequest` from `from` to `to` with an empty argument map.
    fn task_request(from: Uuid, to: Uuid, description: &str) -> AgentEnvelope {
        AgentEnvelope::new(
            from,
            to,
            AgentPayload::TaskRequest {
                description: description.into(),
                args: HashMap::new(),
            },
        )
    }

    /// Panics unless `envelope` carries a `TaskResult`; returns its `(success, output)`.
    fn expect_task_result(envelope: &AgentEnvelope) -> (bool, String) {
        match &envelope.payload {
            AgentPayload::TaskResult { success, output } => (*success, output.clone()),
            other => panic!(
                "Expected TaskResult, got {:?}",
                std::mem::discriminant(other)
            ),
        }
    }

    #[tokio::test]
    async fn test_orchestrator_processes_task_request() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = spawn_sender(&mut orch);

        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "hello world"))
            .unwrap();

        let processed = orch.process_pending().await;
        assert_eq!(processed, 1);

        let response = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(
            expect_task_result(&response),
            (true, "echo: hello world".to_string())
        );
    }

    #[tokio::test]
    async fn test_orchestrator_handles_task_failure() {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("fail-agent").unwrap();
        let sender_id = spawner.spawn("sender").unwrap();

        let mut bus = MessageBus::new(100);
        bus.register(agent_id);
        bus.register(sender_id);

        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);
        orch.register_handler(agent_id, Box::new(FailHandler));

        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "will fail"))
            .unwrap();

        orch.process_pending().await;

        let response = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(
            expect_task_result(&response),
            (false, "task failed".to_string())
        );
        // A failed task is still charged as a tool call.
        assert_eq!(orch.tool_call_count(&agent_id), 1);
    }

    #[tokio::test]
    async fn test_orchestrator_correlation_id_preserved() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = spawn_sender(&mut orch);

        let corr_id = Uuid::new_v4();
        let task = task_request(sender_id, agent_id, "correlated").with_correlation(corr_id);
        orch.bus_mut().send(task).unwrap();

        orch.process_pending().await;

        let response = orch.bus_mut().receive(&sender_id).unwrap();
        assert_eq!(response.correlation_id, Some(corr_id));
    }

    #[tokio::test]
    async fn test_orchestrator_handles_shutdown() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = spawn_sender(&mut orch);

        let shutdown = AgentEnvelope::new(sender_id, agent_id, AgentPayload::Shutdown);
        orch.bus_mut().send(shutdown).unwrap();

        let processed = orch.process_pending().await;
        assert_eq!(processed, 1);

        assert!(orch.spawner().get(&agent_id).is_none());
    }

    #[tokio::test]
    async fn test_orchestrator_handles_status_query() {
        let (mut orch, agent_id) = setup_orchestrator();
        let sender_id = spawn_sender(&mut orch);

        let query = AgentEnvelope::new(sender_id, agent_id, AgentPayload::StatusQuery);
        orch.bus_mut().send(query).unwrap();

        orch.process_pending().await;

        let response = orch.bus_mut().receive(&sender_id).unwrap();
        match &response.payload {
            AgentPayload::StatusResponse {
                agent_name,
                active,
                pending_tasks,
            } => {
                assert_eq!(agent_name, "test-agent");
                assert!(active);
                assert_eq!(*pending_tasks, 0);
            }
            _ => panic!("Expected StatusResponse"),
        }
    }

    #[tokio::test]
    async fn test_orchestrator_respects_tool_call_limit() {
        let mut spawner = AgentSpawner::new(SpawnerConfig::default());
        let agent_id = spawner.spawn("limited-agent").unwrap();
        let sender_id = spawner.spawn("sender").unwrap();

        // Allow exactly two tool calls for the limited agent.
        if let Some(ctx) = spawner.get_mut(&agent_id) {
            ctx.resource_limits.max_tool_calls = Some(2);
        }

        let mut bus = MessageBus::new(100);
        bus.register(agent_id);
        bus.register(sender_id);

        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);
        orch.register_handler(agent_id, Box::new(EchoHandler));

        // The first two requests fit within the limit.
        for i in 0..2 {
            let description = format!("task-{}", i);
            orch.bus_mut()
                .send(task_request(sender_id, agent_id, &description))
                .unwrap();
            orch.process_pending().await;

            let response = orch.bus_mut().receive(&sender_id).unwrap();
            let (success, _) = expect_task_result(&response);
            assert!(success, "{} should succeed within the limit", description);
        }

        // The third request is rejected before it reaches the handler.
        orch.bus_mut()
            .send(task_request(sender_id, agent_id, "task-2"))
            .unwrap();
        orch.process_pending().await;

        let r3 = orch.bus_mut().receive(&sender_id).unwrap();
        match &r3.payload {
            AgentPayload::Error {
                code, recoverable, ..
            } => {
                assert_eq!(code, "RESOURCE_LIMIT");
                assert!(!recoverable);
            }
            other => panic!(
                "Expected Error for third task, got {:?}",
                std::mem::discriminant(other)
            ),
        }
    }

    #[test]
    fn test_tool_call_count_tracking() {
        let spawner = AgentSpawner::default();
        let bus = MessageBus::new(100);
        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);

        let agent_id = Uuid::new_v4();
        assert_eq!(orch.tool_call_count(&agent_id), 0);

        orch.tool_call_counts.insert(agent_id, 5);
        assert_eq!(orch.tool_call_count(&agent_id), 5);

        orch.reset_tool_counts(&agent_id);
        assert_eq!(orch.tool_call_count(&agent_id), 0);
    }

    #[tokio::test]
    async fn test_orchestrator_no_pending_returns_zero() {
        let (mut orch, _) = setup_orchestrator();
        let processed = orch.process_pending().await;
        assert_eq!(processed, 0);
    }

    #[tokio::test]
    async fn test_orchestrator_parent_delegates_to_child() {
        let mut spawner = AgentSpawner::default();
        let parent_id = spawner.spawn("parent").unwrap();
        let child_id = spawner.spawn_child("child", parent_id).unwrap();

        let mut bus = MessageBus::new(100);
        bus.register(parent_id);
        bus.register(child_id);

        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);
        orch.register_handler(child_id, Box::new(EchoHandler));

        orch.bus_mut()
            .send(task_request(parent_id, child_id, "delegated task"))
            .unwrap();

        orch.process_pending().await;

        let response = orch.bus_mut().receive(&parent_id).unwrap();
        assert_eq!(
            expect_task_result(&response),
            (true, "echo: delegated task".to_string())
        );
    }

    #[tokio::test]
    async fn test_orchestrator_multiple_agents() {
        let mut spawner = AgentSpawner::default();
        let agent_a = spawner.spawn("agent-a").unwrap();
        let agent_b = spawner.spawn("agent-b").unwrap();
        let coordinator = spawner.spawn("coordinator").unwrap();

        let mut bus = MessageBus::new(100);
        bus.register(agent_a);
        bus.register(agent_b);
        bus.register(coordinator);

        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);
        orch.register_handler(agent_a, Box::new(EchoHandler));
        orch.register_handler(agent_b, Box::new(EchoHandler));

        orch.bus_mut()
            .send(task_request(coordinator, agent_a, "task-for-a"))
            .unwrap();
        orch.bus_mut()
            .send(task_request(coordinator, agent_b, "task-for-b"))
            .unwrap();

        let processed = orch.process_pending().await;
        assert_eq!(processed, 2);

        let r1 = orch.bus_mut().receive(&coordinator).unwrap();
        let r2 = orch.bus_mut().receive(&coordinator).unwrap();

        // Reply order across agents is not guaranteed, so compare sorted outputs; every
        // reply must be a successful TaskResult.
        let mut outputs: Vec<String> = Vec::new();
        for response in [&r1, &r2] {
            let (success, output) = expect_task_result(response);
            assert!(success);
            outputs.push(output);
        }
        outputs.sort();
        assert_eq!(outputs, vec!["echo: task-for-a", "echo: task-for-b"]);
    }

    #[test]
    fn test_check_resource_limits_no_limits() {
        let mut spawner = AgentSpawner::default();
        let agent_id = spawner.spawn("no-limits").unwrap();
        let bus = MessageBus::new(100);
        let router = AgentRouter::new();
        let orch = AgentOrchestrator::new(spawner, bus, router);
        assert!(orch.check_resource_limits(&agent_id).is_ok());
    }

    #[test]
    fn test_check_resource_limits_exceeded() {
        let mut spawner = AgentSpawner::new(SpawnerConfig::default());
        let agent_id = spawner.spawn("limited").unwrap();
        if let Some(ctx) = spawner.get_mut(&agent_id) {
            ctx.resource_limits.max_tool_calls = Some(3);
        }

        let bus = MessageBus::new(100);
        let router = AgentRouter::new();
        let mut orch = AgentOrchestrator::new(spawner, bus, router);

        // Simulate an agent that has already used all of its tool calls.
        orch.tool_call_counts.insert(agent_id, 3);

        let result = orch.check_resource_limits(&agent_id);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max_tool_calls"));
    }
}
653}