1use std::collections::HashMap;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::{
64 atomic::{AtomicUsize, Ordering},
65 Arc,
66};
67
68use async_openai::types::{
69 ChatCompletionRequestMessage, ChatCompletionTool, CreateChatCompletionRequest,
70};
71use serde::{Deserialize, Serialize};
72use serde_json::Value;
73use tower::{BoxError, Layer, Service, ServiceExt};
74use tracing::{debug, error, info, instrument, trace, warn};
75
76use crate::core::{
77 AgentRun, AgentStopReason, AgentSvc, LoopState, StepOutcome, ToolInvocation, ToolOutput,
78};
79
80pub type AgentName = String;
81
82#[derive(Debug, Clone)]
83pub struct PickRequest {
84 pub messages: Vec<async_openai::types::ChatCompletionRequestMessage>,
85 pub last_stop: AgentStopReason,
86}
87
88pub trait AgentPicker: Service<PickRequest, Response = AgentName, Error = BoxError> {}
89impl<T> AgentPicker for T where T: Service<PickRequest, Response = AgentName, Error = BoxError> {}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct HandoffRequest {
98 pub target_agent: String,
100 pub context: Option<Value>,
102 pub reason: Option<String>,
104}
105
106#[derive(Debug, Clone)]
108pub struct HandoffResponse {
109 pub success: bool,
111 pub target_agent: String,
113 pub context: Option<Value>,
115}
116
117#[derive(Debug, Clone)]
119pub enum GroupOutcome {
120 Continue(AgentRun),
122 Handoff(HandoffRequest),
124 Done(AgentRun),
126}
127
128pub trait HandoffPolicy: Send + Sync + 'static {
134 fn handoff_tools(&self) -> Vec<ChatCompletionTool>;
137
138 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError>;
141
142 fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest>;
145
146 fn is_handoff_tool(&self, tool_name: &str) -> bool;
148
149 fn transform_on_handoff(
160 &self,
161 messages: Vec<ChatCompletionRequestMessage>,
162 _from_agent: &str,
163 _to_agent: &str,
164 _handoff: &HandoffRequest,
165 ) -> Pin<Box<dyn Future<Output = Result<Vec<ChatCompletionRequestMessage>, BoxError>> + Send>>
166 {
167 Box::pin(async move { Ok(messages) })
168 }
169}
170
171pub struct GroupBuilder<P = (), H = ()> {
172 agents: HashMap<AgentName, AgentSvc>,
173 picker: Option<P>,
174 handoff_policy: Option<H>,
175}
176
177impl GroupBuilder<(), ()> {
178 pub fn new() -> Self {
179 Self {
180 agents: HashMap::new(),
181 picker: None,
182 handoff_policy: None,
183 }
184 }
185}
186
187impl Default for GroupBuilder<(), ()> {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193impl<P, H> GroupBuilder<P, H> {
194 pub fn agent(mut self, name: impl Into<String>, svc: AgentSvc) -> Self {
195 self.agents.insert(name.into(), svc);
196 self
197 }
198
199 pub fn picker<NewP>(self, p: NewP) -> GroupBuilder<NewP, H> {
200 GroupBuilder {
201 agents: self.agents,
202 picker: Some(p),
203 handoff_policy: self.handoff_policy,
204 }
205 }
206
207 pub fn handoff_policy<NewH>(self, policy: NewH) -> GroupBuilder<P, NewH>
209 where
210 NewH: HandoffPolicy + Clone,
211 {
212 GroupBuilder {
213 agents: self.agents,
214 picker: self.picker,
215 handoff_policy: Some(policy),
216 }
217 }
218}
219
220impl<P> GroupBuilder<P, ()> {
221 pub fn build(self) -> GroupRouter<P>
223 where
224 P: AgentPicker + Clone + Send + 'static,
225 P::Future: Send + 'static,
226 {
227 GroupRouter {
228 agents: std::sync::Arc::new(tokio::sync::Mutex::new(self.agents)),
229 picker: self.picker.expect("picker"),
230 }
231 }
232}
233
234impl<P, H> GroupBuilder<P, H>
235where
236 H: HandoffPolicy + Clone + Send + 'static,
237{
238 pub fn build(self) -> HandoffCoordinator<P, H>
240 where
241 P: AgentPicker + Clone + Send + 'static,
242 P::Future: Send + 'static,
243 {
244 HandoffCoordinator::new(
245 self.agents,
246 self.picker.expect("picker"),
247 self.handoff_policy.expect("handoff_policy"),
248 )
249 }
250}
251
252pub struct GroupRouter<P> {
253 agents: std::sync::Arc<tokio::sync::Mutex<HashMap<AgentName, AgentSvc>>>,
254 picker: P,
255}
256
257impl<P> Service<CreateChatCompletionRequest> for GroupRouter<P>
258where
259 P: AgentPicker + Clone + Send + 'static,
260 P::Future: Send + 'static,
261{
262 type Response = AgentRun;
263 type Error = BoxError;
264 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
265
266 fn poll_ready(
267 &mut self,
268 _cx: &mut std::task::Context<'_>,
269 ) -> std::task::Poll<Result<(), Self::Error>> {
270 std::task::Poll::Ready(Ok(()))
271 }
272
273 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
274 let mut picker = self.picker.clone();
275 let agents = self.agents.clone();
276 Box::pin(async move {
277 let pick = ServiceExt::ready(&mut picker)
278 .await?
279 .call(PickRequest {
280 messages: req.messages.clone(),
281 last_stop: AgentStopReason::DoneNoToolCalls,
282 })
283 .await?;
284 let mut guard = agents.lock().await;
285 let agent = guard
286 .get_mut(&pick)
287 .ok_or_else(|| format!("unknown agent: {}", pick))?;
288 let run = ServiceExt::ready(agent).await?.call(req).await?;
289 Ok(run)
290 })
291 }
292}
293
294pub struct HandoffCoordinator<P, H> {
307 agents: Arc<tokio::sync::Mutex<HashMap<AgentName, AgentSvc>>>,
308 picker: P,
309 handoff_policy: H,
310 current_agent: Arc<tokio::sync::Mutex<Option<AgentName>>>,
311 conversation_context:
312 Arc<tokio::sync::Mutex<Vec<async_openai::types::ChatCompletionRequestMessage>>>,
313}
314
315impl<P, H> HandoffCoordinator<P, H>
316where
317 P: AgentPicker,
318 H: HandoffPolicy + Clone,
319{
320 pub fn new(agents: HashMap<AgentName, AgentSvc>, picker: P, handoff_policy: H) -> Self {
322 Self {
323 agents: Arc::new(tokio::sync::Mutex::new(agents)),
324 picker,
325 handoff_policy,
326 current_agent: Arc::new(tokio::sync::Mutex::new(None)),
327 conversation_context: Arc::new(tokio::sync::Mutex::new(Vec::new())),
328 }
329 }
330
331 pub fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
333 self.handoff_policy.handoff_tools()
334 }
335}
336
337impl<P, H> Service<CreateChatCompletionRequest> for HandoffCoordinator<P, H>
338where
339 P: AgentPicker + Clone + Send + 'static,
340 P::Future: Send + 'static,
341 H: HandoffPolicy + Clone + Send + 'static,
342{
343 type Response = AgentRun;
344 type Error = BoxError;
345 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
346
347 fn poll_ready(
348 &mut self,
349 _cx: &mut std::task::Context<'_>,
350 ) -> std::task::Poll<Result<(), Self::Error>> {
351 std::task::Poll::Ready(Ok(()))
352 }
353
354 #[instrument(skip_all, fields(request_id = %uuid::Uuid::new_v4()))]
355 fn call(&mut self, mut request: CreateChatCompletionRequest) -> Self::Future {
356 let agents = self.agents.clone();
357 let mut picker = self.picker.clone();
358 let handoff_policy = self.handoff_policy.clone();
359 let current_agent = self.current_agent.clone();
360 let conversation_context = self.conversation_context.clone();
361
362 Box::pin(async move {
363 const MAX_HANDOFFS: usize = 10;
364 let mut handoff_count = 0;
365 let mut all_messages = Vec::new();
366 let mut total_steps = 0;
367
368 info!("🚀 Starting handoff coordinator");
369 debug!("Initial request has {} messages", request.messages.len());
370
371 let original_messages = request.messages.clone();
373
374 debug!("Invoking picker to determine initial agent");
376 let initial_pick = ServiceExt::ready(&mut picker)
377 .await?
378 .call(PickRequest {
379 messages: request.messages.clone(),
380 last_stop: AgentStopReason::DoneNoToolCalls,
381 })
382 .await?;
383
384 info!("📍 Initial agent selected: {}", initial_pick);
385 let mut current_agent_name = initial_pick;
386
387 {
389 let mut current = current_agent.lock().await;
390 *current = Some(current_agent_name.clone());
391 }
392
393 loop {
395 if handoff_count >= MAX_HANDOFFS {
397 error!("❌ Maximum handoffs exceeded ({})", MAX_HANDOFFS);
398 return Err("Maximum handoffs exceeded".into());
399 }
400
401 info!(
402 "🤖 Executing agent: {} (handoff #{}/{})",
403 current_agent_name,
404 handoff_count + 1,
405 MAX_HANDOFFS
406 );
407
408 let mut agents_guard = agents.lock().await;
410 let agent = agents_guard.get_mut(¤t_agent_name).ok_or_else(|| {
411 error!("Agent not found: {}", current_agent_name);
412 format!("Unknown agent: {}", current_agent_name)
413 })?;
414
415 let handoff_tools = handoff_policy.handoff_tools();
417 if !handoff_tools.is_empty() {
418 debug!(
419 "Injecting {} handoff tools into request",
420 handoff_tools.len()
421 );
422 trace!(
423 "Handoff tools: {:?}",
424 handoff_tools
425 .iter()
426 .map(|t| &t.function.name)
427 .collect::<Vec<_>>()
428 );
429
430 if request.tools.is_none() {
432 request.tools = Some(handoff_tools);
433 } else {
434 request.tools.as_mut().unwrap().extend(handoff_tools);
436 }
437 }
438
439 debug!("Calling agent with {} messages", request.messages.len());
441 let delta_start = request.messages.len();
443 let agent_run = ServiceExt::ready(agent)
444 .await?
445 .call(request.clone())
446 .await?;
447
448 info!(
449 "✅ Agent {} completed: {} messages, {} steps, stop reason: {:?}",
450 current_agent_name,
451 agent_run.messages.len(),
452 agent_run.steps,
453 agent_run.stop
454 );
455
456 all_messages.extend(agent_run.messages.clone());
458 total_steps += agent_run.steps;
459
460 let mut handoff_requested = None;
462
463 debug!("Checking for handoff tool calls in agent response");
465 for message in agent_run.messages.iter().skip(delta_start) {
466 if let async_openai::types::ChatCompletionRequestMessage::Assistant(msg) =
467 message
468 {
469 if let Some(tool_calls) = &msg.tool_calls {
470 trace!("Found {} tool calls in message", tool_calls.len());
471 for tool_call in tool_calls {
472 if handoff_policy.is_handoff_tool(&tool_call.function.name) {
473 info!("🔄 Handoff tool detected: {}", tool_call.function.name);
474
475 let invocation = ToolInvocation {
477 id: tool_call.id.clone(),
478 name: tool_call.function.name.clone(),
479 arguments: serde_json::from_str(
480 &tool_call.function.arguments,
481 )
482 .unwrap_or_else(|e| {
483 warn!("Failed to parse handoff tool arguments: {}", e);
484 serde_json::json!({})
485 }),
486 };
487
488 match handoff_policy.handle_handoff_tool(&invocation) {
489 Ok(handoff_req) => {
490 info!(
491 "📋 Handoff request: {} → {} (reason: {:?})",
492 current_agent_name,
493 handoff_req.target_agent,
494 handoff_req.reason
495 );
496 handoff_requested = Some(handoff_req);
497 break;
498 }
499 Err(e) => {
500 warn!("Failed to handle handoff tool: {}", e);
501 }
502 }
503 }
504 }
505 }
506 }
507 if handoff_requested.is_some() {
508 break;
509 }
510 }
511
512 if handoff_requested.is_none() {
514 debug!(
515 "No explicit handoff tool called, checking policy for automatic handoff"
516 );
517
518 let loop_state = LoopState { steps: total_steps };
520
521 let step_outcome = if matches!(agent_run.stop, AgentStopReason::DoneNoToolCalls)
523 {
524 StepOutcome::Done {
525 messages: agent_run.messages.clone(),
526 aux: crate::core::StepAux::default(),
527 }
528 } else {
529 StepOutcome::Next {
530 messages: agent_run.messages.clone(),
531 aux: crate::core::StepAux::default(),
532 invoked_tools: vec![],
533 }
534 };
535
536 if let Some(handoff) = handoff_policy.should_handoff(&loop_state, &step_outcome)
537 {
538 info!(
539 "🔀 Automatic handoff triggered by policy: {} → {} (reason: {:?})",
540 current_agent_name, handoff.target_agent, handoff.reason
541 );
542 handoff_requested = Some(handoff);
543 } else {
544 debug!("No automatic handoff triggered");
545 }
546 }
547
548 if let Some(handoff) = handoff_requested {
550 info!(
551 "🚦 Processing handoff: {} → {}",
552 current_agent_name, handoff.target_agent
553 );
554
555 let previous_agent = current_agent_name.clone();
557 current_agent_name = handoff.target_agent.clone();
558 handoff_count += 1;
559
560 {
562 let mut current = current_agent.lock().await;
563 *current = Some(current_agent_name.clone());
564 }
565
566 {
568 let mut context = conversation_context.lock().await;
569 context.extend(agent_run.messages.clone());
570 debug!(
571 "Updated conversation context with {} messages",
572 agent_run.messages.len()
573 );
574 }
575
576 let mut messages_for_next = original_messages.clone();
578 messages_for_next.extend(all_messages.clone());
579
580 debug!(
582 "Applying handoff transformation from {} to {}",
583 previous_agent, current_agent_name
584 );
585 match handoff_policy
586 .transform_on_handoff(
587 messages_for_next,
588 &previous_agent,
589 ¤t_agent_name,
590 &handoff,
591 )
592 .await
593 {
594 Ok(transformed) => {
595 debug!(
596 "Handoff transformation complete: {} -> {} messages",
597 original_messages.len() + all_messages.len(),
598 transformed.len()
599 );
600 request.messages = transformed;
601 }
602 Err(e) => {
603 warn!(
604 "Handoff transformation failed, using original messages: {}",
605 e
606 );
607 request.messages = original_messages.clone();
608 request.messages.extend(all_messages.clone());
609 }
610 }
611
612 info!(
613 "🔗 Handoff complete: {} → {} (total handoffs: {})",
614 previous_agent, current_agent_name, handoff_count
615 );
616
617 continue;
619 }
620
621 info!(
623 "🎯 Workflow complete: {} total messages, {} steps, final agent: {}",
624 all_messages.len(),
625 total_steps,
626 current_agent_name
627 );
628
629 return Ok(AgentRun {
630 messages: all_messages,
631 steps: total_steps,
632 stop: agent_run.stop,
633 });
634 }
635 })
636 }
637}
638
/// Handoff policy exposing a single tool that hands the conversation to one
/// fixed agent.
///
/// The tool is named `handoff_to_<target_agent>` unless overridden with
/// [`ExplicitHandoffPolicy::with_tool_name`].
#[derive(Debug, Clone)]
pub struct ExplicitHandoffPolicy {
    target_agent: String,
    tool_name: Option<String>,
    description: Option<String>,
}

impl ExplicitHandoffPolicy {
    /// Creates a policy handing off to `target_agent` with the default tool
    /// name and description.
    pub fn new(target_agent: impl Into<String>) -> Self {
        let target_agent = target_agent.into();
        Self {
            target_agent,
            tool_name: None,
            description: None,
        }
    }

    /// Overrides the name of the generated handoff tool.
    pub fn with_tool_name(self, name: impl Into<String>) -> Self {
        Self {
            tool_name: Some(name.into()),
            ..self
        }
    }

    /// Overrides the description of the generated handoff tool.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Effective tool name: the custom name if set, else `handoff_to_<target>`.
    fn tool_name(&self) -> String {
        match &self.tool_name {
            Some(custom) => custom.clone(),
            None => format!("handoff_to_{}", self.target_agent),
        }
    }
}
680
681impl HandoffPolicy for ExplicitHandoffPolicy {
682 #[instrument(skip(self))]
683 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
684 let tool_name = self.tool_name();
685 let description = self
686 .description
687 .clone()
688 .unwrap_or_else(|| format!("Hand off the conversation to {}", self.target_agent));
689
690 debug!(
691 "ExplicitHandoffPolicy generating tool: {} → {}",
692 tool_name, self.target_agent
693 );
694
695 vec![ChatCompletionTool {
696 r#type: async_openai::types::ChatCompletionToolType::Function,
697 function: async_openai::types::FunctionObject {
698 name: tool_name,
699 description: Some(description),
700 parameters: Some(serde_json::json!({
701 "type": "object",
702 "properties": {
703 "reason": {
704 "type": "string",
705 "description": "Reason for the handoff"
706 },
707 "context": {
708 "type": "object",
709 "description": "Optional context to pass to the target agent"
710 }
711 }
712 })),
713 ..Default::default()
714 },
715 }]
716 }
717
718 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
719 if !self.is_handoff_tool(&invocation.name) {
720 return Err(format!("Not a handoff tool: {}", invocation.name).into());
721 }
722
723 let reason = invocation
724 .arguments
725 .get("reason")
726 .and_then(|v| v.as_str())
727 .map(|s| s.to_string());
728
729 let context = invocation.arguments.get("context").cloned();
730
731 Ok(HandoffRequest {
732 target_agent: self.target_agent.clone(),
733 context,
734 reason,
735 })
736 }
737
738 fn should_handoff(&self, _state: &LoopState, _outcome: &StepOutcome) -> Option<HandoffRequest> {
739 None
741 }
742
743 fn is_handoff_tool(&self, tool_name: &str) -> bool {
744 tool_name == self.tool_name()
745 }
746}
747
/// Handoff policy that walks a fixed list of agents in order, moving to the
/// next one each time the current agent finishes.
///
/// The position lives behind an `Arc`, so clones of the policy share (and
/// advance) the same position.
/// NOTE(review): the position is never reset in the visible code, so a policy
/// reused across requests does not restart the sequence — confirm intended.
#[derive(Debug, Clone)]
pub struct SequentialHandoffPolicy {
    agents: Vec<String>,
    current_index: Arc<AtomicUsize>,
}

impl SequentialHandoffPolicy {
    /// Creates a policy positioned at the first agent of `agents`.
    pub fn new(agents: Vec<String>) -> Self {
        let current_index = Arc::new(AtomicUsize::new(0));
        Self {
            agents,
            current_index,
        }
    }

    /// Advances the shared position and returns the agent after the previous
    /// position, or `None` once the list is exhausted.
    ///
    /// The position is incremented on every call, even after the end.
    fn next_agent(&self) -> Option<String> {
        let upcoming = self.current_index.fetch_add(1, Ordering::SeqCst) + 1;
        self.agents.get(upcoming).cloned()
    }
}
774
775impl HandoffPolicy for SequentialHandoffPolicy {
776 #[instrument(skip(self))]
777 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
778 debug!("SequentialHandoffPolicy: no tools (automatic handoffs only)");
780 vec![]
781 }
782
783 #[instrument(skip(self, invocation))]
784 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
785 warn!(
786 "Sequential policy received unexpected handoff tool call: {}",
787 invocation.name
788 );
789 Err(format!(
790 "Sequential policy has no handoff tools: {}",
791 invocation.name
792 )
793 .into())
794 }
795
796 #[instrument(skip(self, _state, outcome))]
797 fn should_handoff(&self, _state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
798 match outcome {
799 StepOutcome::Done { .. } => {
800 if let Some(target) = self.next_agent() {
802 let current_idx = self.current_index.load(Ordering::SeqCst);
803 info!(
804 "📈 Sequential handoff: step {}/{} → {}",
805 current_idx,
806 self.agents.len(),
807 target
808 );
809 Some(HandoffRequest {
810 target_agent: target,
811 context: None,
812 reason: Some("Sequential workflow step complete".to_string()),
813 })
814 } else {
815 debug!("Sequential workflow complete (all steps finished)");
816 None
817 }
818 }
819 _ => {
820 trace!("Sequential policy: no handoff (agent not done)");
821 None
822 }
823 }
824 }
825
826 fn is_handoff_tool(&self, _tool_name: &str) -> bool {
827 false }
829}
830
/// Handoff policy exposing several tools, each handing off to its own target.
///
/// `handoffs` maps tool name → target agent name.
#[derive(Debug, Clone)]
pub struct MultiExplicitHandoffPolicy {
    handoffs: HashMap<String, String>,
}

impl MultiExplicitHandoffPolicy {
    /// Creates a policy from a tool-name → target-agent map.
    pub fn new(handoffs: HashMap<String, String>) -> Self {
        Self { handoffs }
    }

    /// Registers (or replaces) the tool `tool_name` as a handoff to `target`.
    pub fn add_handoff(self, tool_name: impl Into<String>, target: impl Into<String>) -> Self {
        let mut handoffs = self.handoffs;
        handoffs.insert(tool_name.into(), target.into());
        Self { handoffs }
    }
}
850
851impl HandoffPolicy for MultiExplicitHandoffPolicy {
852 #[instrument(skip(self))]
853 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
854 debug!(
855 "MultiExplicitHandoffPolicy generating {} handoff tools",
856 self.handoffs.len()
857 );
858 self.handoffs
859 .iter()
860 .map(|(tool_name, target_agent)| {
861 trace!(" Tool: {} → {}", tool_name, target_agent);
862 ChatCompletionTool {
863 r#type: async_openai::types::ChatCompletionToolType::Function,
864 function: async_openai::types::FunctionObject {
865 name: tool_name.clone(),
866 description: Some(format!("Hand off the conversation to {}", target_agent)),
867 parameters: Some(serde_json::json!({
868 "type": "object",
869 "properties": {
870 "reason": {
871 "type": "string",
872 "description": "Reason for the handoff"
873 },
874 "context": {
875 "type": "object",
876 "description": "Optional context to pass to the target agent"
877 }
878 }
879 })),
880 ..Default::default()
881 },
882 }
883 })
884 .collect()
885 }
886
887 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
888 let target_agent = self
889 .handoffs
890 .get(&invocation.name)
891 .ok_or_else(|| format!("Not a handoff tool: {}", invocation.name))?;
892
893 let reason = invocation
894 .arguments
895 .get("reason")
896 .and_then(|v| v.as_str())
897 .map(|s| s.to_string());
898
899 let context = invocation.arguments.get("context").cloned();
900
901 Ok(HandoffRequest {
902 target_agent: target_agent.clone(),
903 context,
904 reason,
905 })
906 }
907
908 fn should_handoff(&self, _state: &LoopState, _outcome: &StepOutcome) -> Option<HandoffRequest> {
909 None
911 }
912
913 fn is_handoff_tool(&self, tool_name: &str) -> bool {
914 self.handoffs.contains_key(tool_name)
915 }
916}
917
918#[derive(Debug, Clone)]
920pub enum AnyHandoffPolicy {
921 Explicit(ExplicitHandoffPolicy),
922 MultiExplicit(MultiExplicitHandoffPolicy),
923 Sequential(SequentialHandoffPolicy),
924 Composite(CompositeHandoffPolicy),
925}
926
927impl From<ExplicitHandoffPolicy> for AnyHandoffPolicy {
928 fn from(policy: ExplicitHandoffPolicy) -> Self {
929 AnyHandoffPolicy::Explicit(policy)
930 }
931}
932
933impl From<SequentialHandoffPolicy> for AnyHandoffPolicy {
934 fn from(policy: SequentialHandoffPolicy) -> Self {
935 AnyHandoffPolicy::Sequential(policy)
936 }
937}
938
939impl From<MultiExplicitHandoffPolicy> for AnyHandoffPolicy {
940 fn from(policy: MultiExplicitHandoffPolicy) -> Self {
941 AnyHandoffPolicy::MultiExplicit(policy)
942 }
943}
944
945impl From<CompositeHandoffPolicy> for AnyHandoffPolicy {
946 fn from(policy: CompositeHandoffPolicy) -> Self {
947 AnyHandoffPolicy::Composite(policy)
948 }
949}
950
951impl HandoffPolicy for AnyHandoffPolicy {
952 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
953 match self {
954 AnyHandoffPolicy::Explicit(p) => p.handoff_tools(),
955 AnyHandoffPolicy::MultiExplicit(p) => p.handoff_tools(),
956 AnyHandoffPolicy::Sequential(p) => p.handoff_tools(),
957 AnyHandoffPolicy::Composite(p) => p.handoff_tools(),
958 }
959 }
960
961 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
962 match self {
963 AnyHandoffPolicy::Explicit(p) => p.handle_handoff_tool(invocation),
964 AnyHandoffPolicy::MultiExplicit(p) => p.handle_handoff_tool(invocation),
965 AnyHandoffPolicy::Sequential(p) => p.handle_handoff_tool(invocation),
966 AnyHandoffPolicy::Composite(p) => p.handle_handoff_tool(invocation),
967 }
968 }
969
970 fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
971 match self {
972 AnyHandoffPolicy::Explicit(p) => p.should_handoff(state, outcome),
973 AnyHandoffPolicy::MultiExplicit(p) => p.should_handoff(state, outcome),
974 AnyHandoffPolicy::Sequential(p) => p.should_handoff(state, outcome),
975 AnyHandoffPolicy::Composite(p) => p.should_handoff(state, outcome),
976 }
977 }
978
979 fn is_handoff_tool(&self, tool_name: &str) -> bool {
980 match self {
981 AnyHandoffPolicy::Explicit(p) => p.is_handoff_tool(tool_name),
982 AnyHandoffPolicy::MultiExplicit(p) => p.is_handoff_tool(tool_name),
983 AnyHandoffPolicy::Sequential(p) => p.is_handoff_tool(tool_name),
984 AnyHandoffPolicy::Composite(p) => p.is_handoff_tool(tool_name),
985 }
986 }
987}
988
989#[derive(Debug, Clone)]
991pub struct CompositeHandoffPolicy {
992 policies: Vec<AnyHandoffPolicy>,
993}
994
995impl CompositeHandoffPolicy {
996 pub fn new(policies: Vec<AnyHandoffPolicy>) -> Self {
998 Self { policies }
999 }
1000}
1001
1002impl HandoffPolicy for CompositeHandoffPolicy {
1003 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1004 self.policies
1005 .iter()
1006 .flat_map(|p| p.handoff_tools())
1007 .collect()
1008 }
1009
1010 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
1011 for policy in &self.policies {
1012 if policy.is_handoff_tool(&invocation.name) {
1013 return policy.handle_handoff_tool(invocation);
1014 }
1015 }
1016 Err(format!("No policy handles handoff tool: {}", invocation.name).into())
1017 }
1018
1019 fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
1020 for policy in &self.policies {
1022 if let Some(handoff) = policy.should_handoff(state, outcome) {
1023 return Some(handoff);
1024 }
1025 }
1026 None
1027 }
1028
1029 fn is_handoff_tool(&self, tool_name: &str) -> bool {
1030 self.policies.iter().any(|p| p.is_handoff_tool(tool_name))
1031 }
1032}
1033
1034#[derive(Clone)]
1044pub struct CompactingHandoffPolicy<P> {
1045 inner: P,
1046 compaction_policy: crate::auto_compaction::CompactionPolicy,
1047 provider: Arc<tokio::sync::Mutex<crate::provider::OpenAIProvider>>,
1048 token_counter: Arc<crate::auto_compaction::SimpleTokenCounter>,
1049}
1050
1051impl<P> CompactingHandoffPolicy<P> {
1052 pub fn new(
1054 inner: P,
1055 compaction_policy: crate::auto_compaction::CompactionPolicy,
1056 provider: Arc<tokio::sync::Mutex<crate::provider::OpenAIProvider>>,
1057 ) -> Self {
1058 Self {
1059 inner,
1060 compaction_policy,
1061 provider,
1062 token_counter: Arc::new(crate::auto_compaction::SimpleTokenCounter::new()),
1063 }
1064 }
1065}
1066
1067struct DummyStepService;
1069
1070impl Service<CreateChatCompletionRequest> for DummyStepService {
1071 type Response = crate::core::StepOutcome;
1072 type Error = BoxError;
1073 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1074
1075 fn poll_ready(
1076 &mut self,
1077 _: &mut std::task::Context<'_>,
1078 ) -> std::task::Poll<Result<(), Self::Error>> {
1079 std::task::Poll::Ready(Ok(()))
1080 }
1081
1082 fn call(&mut self, _: CreateChatCompletionRequest) -> Self::Future {
1083 Box::pin(async {
1084 Ok(crate::core::StepOutcome::Done {
1085 messages: vec![],
1086 aux: Default::default(),
1087 })
1088 })
1089 }
1090}
1091
1092impl<P> HandoffPolicy for CompactingHandoffPolicy<P>
1093where
1094 P: HandoffPolicy + Clone,
1095{
1096 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1097 self.inner.handoff_tools()
1098 }
1099
1100 fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
1101 self.inner.handle_handoff_tool(invocation)
1102 }
1103
1104 fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
1105 self.inner.should_handoff(state, outcome)
1106 }
1107
1108 fn is_handoff_tool(&self, tool_name: &str) -> bool {
1109 self.inner.is_handoff_tool(tool_name)
1110 }
1111
1112 fn transform_on_handoff(
1113 &self,
1114 messages: Vec<ChatCompletionRequestMessage>,
1115 from_agent: &str,
1116 to_agent: &str,
1117 handoff: &HandoffRequest,
1118 ) -> Pin<Box<dyn Future<Output = Result<Vec<ChatCompletionRequestMessage>, BoxError>> + Send>>
1119 {
1120 let compaction_policy = self.compaction_policy.clone();
1121 let provider = self.provider.clone();
1122 let token_counter = self.token_counter.clone();
1123 let inner_policy = self.inner.clone();
1124 let from_agent = from_agent.to_string();
1125 let to_agent = to_agent.to_string();
1126 let handoff = handoff.clone();
1127
1128 Box::pin(async move {
1129 let messages = inner_policy
1131 .transform_on_handoff(messages, &from_agent, &to_agent, &handoff)
1132 .await?;
1133
1134 tracing::debug!(
1136 "Applying compaction during handoff from {} to {}",
1137 from_agent,
1138 to_agent
1139 );
1140
1141 let compactor = crate::auto_compaction::AutoCompaction::<
1143 DummyStepService,
1144 crate::provider::OpenAIProvider,
1145 crate::auto_compaction::SimpleTokenCounter,
1146 > {
1147 inner: Arc::new(tokio::sync::Mutex::new(DummyStepService)),
1148 policy: compaction_policy,
1149 provider,
1150 token_counter,
1151 };
1152
1153 match compactor.compact_messages(messages.clone()).await {
1154 Ok(compacted) => {
1155 tracing::info!(
1156 "Successfully compacted messages during handoff: {} -> {} messages",
1157 messages.len(),
1158 compacted.len()
1159 );
1160 Ok(compacted)
1161 }
1162 Err(e) => {
1163 tracing::warn!(
1164 "Compaction failed during handoff, using original messages: {}",
1165 e
1166 );
1167 Ok(messages)
1168 }
1169 }
1170 })
1171 }
1172}
1173
1174pub fn explicit_handoff_to(target: impl Into<String>) -> ExplicitHandoffPolicy {
1186 ExplicitHandoffPolicy::new(target)
1187}
1188
1189pub fn sequential_handoff(agents: Vec<String>) -> SequentialHandoffPolicy {
1196 SequentialHandoffPolicy::new(agents)
1197}
1198
1199pub fn composite_handoff(policies: Vec<AnyHandoffPolicy>) -> CompositeHandoffPolicy {
1209 CompositeHandoffPolicy::new(policies)
1210}
1211
1212#[derive(Debug, Clone)]
1218pub enum ToolOutputResult {
1219 Tool(ToolOutput),
1221 Handoff(HandoffRequest),
1223}
1224
1225impl From<ToolOutput> for ToolOutputResult {
1226 fn from(output: ToolOutput) -> Self {
1227 ToolOutputResult::Tool(output)
1228 }
1229}
1230
1231#[derive(Debug, Clone)]
1239pub struct HandoffLayer<P> {
1240 handoff_policy: P,
1241}
1242
1243impl<P> HandoffLayer<P>
1244where
1245 P: HandoffPolicy,
1246{
1247 pub fn new(policy: P) -> Self {
1249 Self {
1250 handoff_policy: policy,
1251 }
1252 }
1253
1254 pub fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1256 self.handoff_policy.handoff_tools()
1257 }
1258}
1259
1260impl<S, P> Layer<S> for HandoffLayer<P>
1261where
1262 P: HandoffPolicy + Clone,
1263{
1264 type Service = HandoffService<S, P>;
1265
1266 fn layer(&self, inner: S) -> Self::Service {
1267 HandoffService::new(inner, self.handoff_policy.clone())
1268 }
1269}
1270
1271#[derive(Debug, Clone)]
1277pub struct HandoffService<S, P> {
1278 inner: S,
1279 handoff_policy: P,
1280}
1281
1282pub fn layer_tool_router_with_handoff<P>(
1284 router: crate::core::ToolRouter,
1285 policy: P,
1286) -> crate::core::ToolSvc
1287where
1288 P: HandoffPolicy + Clone,
1289{
1290 let layer = HandoffLayer::new(policy);
1291 let svc = layer.layer(router);
1292 tower::util::BoxCloneService::new(svc)
1293}
1294
1295impl<S, P> HandoffService<S, P>
1296where
1297 P: HandoffPolicy,
1298{
1299 pub fn new(inner: S, policy: P) -> Self {
1301 Self {
1302 inner,
1303 handoff_policy: policy,
1304 }
1305 }
1306}
1307
1308impl<S, P> Service<ToolInvocation> for HandoffService<S, P>
1309where
1310 S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Send + 'static,
1311 S::Future: Send + 'static,
1312 P: HandoffPolicy + Clone,
1313{
1314 type Response = ToolOutput;
1315 type Error = BoxError;
1316 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1317
1318 fn poll_ready(
1319 &mut self,
1320 cx: &mut std::task::Context<'_>,
1321 ) -> std::task::Poll<Result<(), Self::Error>> {
1322 self.inner.poll_ready(cx)
1323 }
1324
1325 fn call(&mut self, req: ToolInvocation) -> Self::Future {
1326 if self.handoff_policy.is_handoff_tool(&req.name) {
1328 let policy = self.handoff_policy.clone();
1330 Box::pin(async move {
1331 let handoff_request = policy.handle_handoff_tool(&req)?;
1332 let out = ToolOutput {
1333 id: req.id,
1334 result: serde_json::json!({
1335 "handoff": handoff_request.target_agent,
1336 "ack": true,
1337 "reason": handoff_request.reason,
1338 "context": handoff_request.context,
1339 }),
1340 };
1341 Ok(out)
1342 })
1343 } else {
1344 let future = self.inner.call(req);
1346 Box::pin(future)
1347 }
1348 }
1349}
1350
1351#[cfg(test)]
1352mod tests {
1353 use super::*;
1354 use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs};
1355 use tower::util::BoxService;
1356
1357 mod handoff_policy_tests {
1362 use super::*;
1363 use async_openai::types::ChatCompletionRequestUserMessageArgs;
1364
1365 #[test]
1366 fn explicit_handoff_policy_generates_correct_tools() {
1367 let policy =
1368 explicit_handoff_to("specialist").with_description("Escalate to specialist");
1369
1370 let tools = policy.handoff_tools();
1371 assert_eq!(tools.len(), 1);
1372
1373 let tool = &tools[0];
1374 assert_eq!(tool.function.name, "handoff_to_specialist");
1375 assert_eq!(
1376 tool.function.description,
1377 Some("Escalate to specialist".to_string())
1378 );
1379
1380 let params = tool.function.parameters.as_ref().unwrap();
1382 assert!(params.get("properties").is_some());
1383 assert!(params["properties"]["reason"].is_object());
1384 }
1385
1386 #[test]
1387 fn explicit_handoff_policy_handles_tool_calls() {
1388 let policy = explicit_handoff_to("specialist");
1389
1390 let invocation = ToolInvocation {
1391 id: "test_id".to_string(),
1392 name: "handoff_to_specialist".to_string(),
1393 arguments: serde_json::json!({
1394 "reason": "Complex technical issue",
1395 "context": {"priority": "high"}
1396 }),
1397 };
1398
1399 let result = policy.handle_handoff_tool(&invocation).unwrap();
1400 assert_eq!(result.target_agent, "specialist");
1401 assert_eq!(result.reason, Some("Complex technical issue".to_string()));
1402 assert!(result.context.is_some());
1403 }
1404
1405 #[test]
1406 fn explicit_handoff_policy_rejects_wrong_tools() {
1407 let policy = explicit_handoff_to("specialist");
1408
1409 let invocation = ToolInvocation {
1410 id: "test_id".to_string(),
1411 name: "some_other_tool".to_string(),
1412 arguments: serde_json::json!({}),
1413 };
1414
1415 let result = policy.handle_handoff_tool(&invocation);
1416 assert!(result.is_err());
1417 assert!(result
1418 .unwrap_err()
1419 .to_string()
1420 .contains("Not a handoff tool"));
1421 }
1422
1423 #[test]
1424 fn explicit_handoff_policy_no_automatic_handoffs() {
1425 let policy = explicit_handoff_to("specialist");
1426 let state = LoopState { steps: 1 };
1427 let outcome = StepOutcome::Done {
1428 messages: vec![],
1429 aux: crate::core::StepAux::default(),
1430 };
1431
1432 assert!(policy.should_handoff(&state, &outcome).is_none());
1434 }
1435
1436 #[tokio::test]
1437 async fn handoff_policy_default_transformation() {
1438 let policy = explicit_handoff_to("specialist");
1440
1441 let messages = vec![ChatCompletionRequestUserMessageArgs::default()
1442 .content("Test message")
1443 .build()
1444 .unwrap()
1445 .into()];
1446
1447 let handoff = HandoffRequest {
1448 target_agent: "specialist".to_string(),
1449 context: None,
1450 reason: Some("Test handoff".to_string()),
1451 };
1452
1453 let result = policy
1454 .transform_on_handoff(messages.clone(), "agent1", "specialist", &handoff)
1455 .await
1456 .unwrap();
1457
1458 assert_eq!(result.len(), messages.len());
1459 assert_eq!(result, messages);
1460 }
1461
1462 #[test]
1463 fn sequential_handoff_policy_advances_correctly() {
1464 let agents = vec!["a".to_string(), "b".to_string(), "c".to_string()];
1465 let policy = sequential_handoff(agents.clone());
1466
1467 let state = LoopState { steps: 1 };
1468 let outcome = StepOutcome::Done {
1469 messages: vec![],
1470 aux: crate::core::StepAux::default(),
1471 };
1472
1473 let handoff1 = policy.should_handoff(&state, &outcome).unwrap();
1475 assert_eq!(handoff1.target_agent, "b");
1476 assert!(handoff1.reason.is_some());
1477
1478 let handoff2 = policy.should_handoff(&state, &outcome).unwrap();
1480 assert_eq!(handoff2.target_agent, "c");
1481
1482 let handoff3 = policy.should_handoff(&state, &outcome);
1484 assert!(handoff3.is_none());
1485 }
1486
1487 #[test]
1488 fn sequential_handoff_policy_no_tools() {
1489 let policy = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1490
1491 assert!(policy.handoff_tools().is_empty());
1493 assert!(!policy.is_handoff_tool("any_tool"));
1494 }
1495
1496 #[test]
1497 fn composite_handoff_policy_combines_tools() {
1498 let explicit1 = explicit_handoff_to("specialist");
1499 let explicit2 = explicit_handoff_to("supervisor");
1500 let sequential = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1501
1502 let composite = composite_handoff(vec![
1503 AnyHandoffPolicy::Explicit(explicit1),
1504 AnyHandoffPolicy::Explicit(explicit2),
1505 AnyHandoffPolicy::Sequential(sequential),
1506 ]);
1507
1508 let tools = composite.handoff_tools();
1509 assert_eq!(tools.len(), 2);
1511
1512 let tool_names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect();
1513 assert!(tool_names.contains(&"handoff_to_specialist"));
1514 assert!(tool_names.contains(&"handoff_to_supervisor"));
1515 }
1516
1517 #[test]
1518 fn composite_handoff_policy_routes_to_correct_handler() {
1519 let explicit = explicit_handoff_to("specialist");
1520 let sequential = sequential_handoff(vec!["a".to_string()]);
1521
1522 let composite = composite_handoff(vec![
1523 AnyHandoffPolicy::Explicit(explicit),
1524 AnyHandoffPolicy::Sequential(sequential),
1525 ]);
1526
1527 let invocation = ToolInvocation {
1528 id: "test_id".to_string(),
1529 name: "handoff_to_specialist".to_string(),
1530 arguments: serde_json::json!({"reason": "test"}),
1531 };
1532
1533 let result = composite.handle_handoff_tool(&invocation).unwrap();
1534 assert_eq!(result.target_agent, "specialist");
1535 }
1536
1537 #[test]
1538 fn composite_handoff_policy_first_match_wins() {
1539 let explicit = explicit_handoff_to("specialist");
1540 let sequential = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1541
1542 let composite = composite_handoff(vec![
1543 AnyHandoffPolicy::Explicit(explicit),
1544 AnyHandoffPolicy::Sequential(sequential),
1545 ]);
1546
1547 let state = LoopState { steps: 1 };
1548 let outcome = StepOutcome::Done {
1549 messages: vec![],
1550 aux: crate::core::StepAux::default(),
1551 };
1552
1553 let result = composite.should_handoff(&state, &outcome).unwrap();
1556 assert_eq!(result.target_agent, "b"); }
1558
1559 #[test]
1560 fn any_handoff_policy_conversions_work() {
1561 let explicit = explicit_handoff_to("specialist");
1562 let sequential = sequential_handoff(vec!["a".to_string()]);
1563
1564 let _any1: AnyHandoffPolicy = explicit.into();
1566 let _any2: AnyHandoffPolicy = sequential.into();
1567
1568 }
1570
1571 #[test]
1572 fn handoff_request_serialization() {
1573 let request = HandoffRequest {
1574 target_agent: "specialist".to_string(),
1575 context: Some(serde_json::json!({"priority": "high"})),
1576 reason: Some("Complex issue".to_string()),
1577 };
1578
1579 let json = serde_json::to_string(&request).unwrap();
1581 let deserialized: HandoffRequest = serde_json::from_str(&json).unwrap();
1582
1583 assert_eq!(deserialized.target_agent, request.target_agent);
1584 assert_eq!(deserialized.reason, request.reason);
1585 assert_eq!(deserialized.context, request.context);
1586 }
1587 }
1588
1589 #[tokio::test]
1590 async fn per_agent_instructions_applied_to_each_agent_request() {
1591 use crate::core::{Agent, CompositePolicy};
1592 use crate::provider::ProviderResponse;
1593 use async_openai::config::OpenAIConfig;
1594 use async_openai::types::{ChatCompletionResponseMessage, Role as RespRole};
1595 use std::sync::Arc as StdArc;
1596
1597 #[derive(Clone)]
1599 struct CapturingProvider {
1600 captured: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
1601 }
1602 impl tower::Service<CreateChatCompletionRequest> for CapturingProvider {
1603 type Response = ProviderResponse;
1604 type Error = BoxError;
1605 type Future = std::pin::Pin<
1606 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
1607 >;
1608 fn poll_ready(
1609 &mut self,
1610 _cx: &mut std::task::Context<'_>,
1611 ) -> std::task::Poll<Result<(), Self::Error>> {
1612 std::task::Poll::Ready(Ok(()))
1613 }
1614 fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
1615 let captured = self.captured.clone();
1616 Box::pin(async move {
1617 *captured.lock().await = Some(req);
1618 #[allow(deprecated)]
1619 let assistant = ChatCompletionResponseMessage {
1620 content: Some("ok".into()),
1621 role: RespRole::Assistant,
1622 tool_calls: None,
1623 function_call: None,
1624 refusal: None,
1625 audio: None,
1626 };
1627 Ok(ProviderResponse {
1628 assistant,
1629 prompt_tokens: 1,
1630 completion_tokens: 1,
1631 })
1632 })
1633 }
1634 }
1635
1636 let client = StdArc::new(async_openai::Client::<OpenAIConfig>::new());
1637
1638 let cap_a: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
1640 StdArc::new(tokio::sync::Mutex::new(None));
1641 let provider_a = CapturingProvider {
1642 captured: cap_a.clone(),
1643 };
1644 let agent_a = Agent::builder(client.clone())
1645 .with_provider(provider_a)
1646 .model("gpt-4o")
1647 .instructions("A")
1648 .policy(CompositePolicy::new(vec![
1649 crate::core::policies::until_no_tool_calls(),
1650 crate::core::policies::max_steps(1),
1651 ]))
1652 .build();
1653
1654 let cap_b: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
1656 StdArc::new(tokio::sync::Mutex::new(None));
1657 let provider_b = CapturingProvider {
1658 captured: cap_b.clone(),
1659 };
1660 let agent_b = Agent::builder(client.clone())
1661 .with_provider(provider_b)
1662 .model("gpt-4o")
1663 .instructions("B")
1664 .policy(CompositePolicy::new(vec![
1665 crate::core::policies::until_no_tool_calls(),
1666 crate::core::policies::max_steps(1),
1667 ]))
1668 .build();
1669
1670 let mut agents = std::collections::HashMap::new();
1672 agents.insert("a".to_string(), agent_a);
1673 agents.insert("b".to_string(), agent_b);
1674 let picker =
1675 tower::service_fn(|_pr: PickRequest| async move { Ok::<_, BoxError>("a".to_string()) });
1676 let policy = SequentialHandoffPolicy::new(vec!["a".into(), "b".into()]);
1677 let mut coord = HandoffCoordinator::new(agents, picker, policy);
1678
1679 let req = CreateChatCompletionRequestArgs::default()
1681 .model("gpt-4o")
1682 .messages(vec![
1683 async_openai::types::ChatCompletionRequestUserMessageArgs::default()
1684 .content("hi")
1685 .build()
1686 .unwrap()
1687 .into(),
1688 ])
1689 .build()
1690 .unwrap();
1691
1692 let _run = tower::ServiceExt::ready(&mut coord)
1693 .await
1694 .unwrap()
1695 .call(req)
1696 .await
1697 .unwrap();
1698
1699 let got_a = cap_a.lock().await.clone().expect("captured request a");
1701 match &got_a.messages[0] {
1702 ChatCompletionRequestMessage::System(s) => match &s.content {
1703 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1704 assert_eq!(t, "A");
1705 }
1706 _ => panic!("expected text content"),
1707 },
1708 _ => panic!("expected first message to be system"),
1709 }
1710 let got_b = cap_b.lock().await.clone().expect("captured request b");
1711 match &got_b.messages[0] {
1712 ChatCompletionRequestMessage::System(s) => match &s.content {
1713 async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1714 assert_eq!(t, "B");
1715 }
1716 _ => panic!("expected text content"),
1717 },
1718 _ => panic!("expected first message to be system"),
1719 }
1720 }
1721
1722 mod handoff_layer_tests {
1727 use super::*;
1728 use tower::{service_fn, ServiceExt};
1729
1730 fn mock_tool_service() -> tower::util::BoxService<ToolInvocation, ToolOutput, BoxError> {
1732 tower::util::BoxService::new(service_fn(|req: ToolInvocation| async move {
1733 Ok(ToolOutput {
1734 id: req.id,
1735 result: serde_json::json!({"status": "success", "tool": req.name}),
1736 })
1737 }))
1738 }
1739
1740 #[tokio::test]
1741 async fn handoff_layer_passes_through_regular_tools() {
1742 let policy = explicit_handoff_to("specialist");
1743 let layer = HandoffLayer::new(policy);
1744 let mut service = layer.layer(mock_tool_service());
1745
1746 let invocation = ToolInvocation {
1747 id: "test_id".to_string(),
1748 name: "regular_tool".to_string(),
1749 arguments: serde_json::json!({"param": "value"}),
1750 };
1751
1752 let result = ServiceExt::ready(&mut service)
1753 .await
1754 .unwrap()
1755 .call(invocation)
1756 .await
1757 .unwrap();
1758 assert_eq!(result.id, "test_id");
1759 assert_eq!(result.result["tool"], "regular_tool");
1760 }
1761
1762 #[tokio::test]
1763 async fn handoff_layer_intercepts_handoff_tools() {
1764 let policy = explicit_handoff_to("specialist");
1765 let layer = HandoffLayer::new(policy);
1766 let mut service = layer.layer(mock_tool_service());
1767
1768 let invocation = ToolInvocation {
1769 id: "test_id".to_string(),
1770 name: "handoff_to_specialist".to_string(),
1771 arguments: serde_json::json!({"reason": "Complex issue"}),
1772 };
1773
1774 let result = ServiceExt::ready(&mut service)
1775 .await
1776 .unwrap()
1777 .call(invocation)
1778 .await
1779 .unwrap();
1780 assert_eq!(result.id, "test_id");
1781 assert_eq!(result.result["handoff"], "specialist");
1782 }
1783
1784 #[tokio::test]
1785 async fn handoff_layer_exposes_policy_tools() {
1786 let policy =
1787 explicit_handoff_to("specialist").with_description("Escalate to specialist");
1788 let layer = HandoffLayer::new(policy);
1789
1790 let tools = layer.handoff_tools();
1791 assert_eq!(tools.len(), 1);
1792 assert_eq!(tools[0].function.name, "handoff_to_specialist");
1793 assert_eq!(
1794 tools[0].function.description,
1795 Some("Escalate to specialist".to_string())
1796 );
1797 }
1798
1799 #[tokio::test]
1800 async fn handoff_layer_with_composite_policy() {
1801 let composite = composite_handoff(vec![
1802 AnyHandoffPolicy::Explicit(explicit_handoff_to("specialist")),
1803 AnyHandoffPolicy::Explicit(explicit_handoff_to("supervisor")),
1804 ]);
1805
1806 let layer = HandoffLayer::new(composite.clone());
1807 let mut service = layer.layer(mock_tool_service());
1808
1809 let invocation1 = ToolInvocation {
1811 id: "test_id1".to_string(),
1812 name: "handoff_to_specialist".to_string(),
1813 arguments: serde_json::json!({"reason": "Technical issue"}),
1814 };
1815
1816 let result1 = ServiceExt::ready(&mut service)
1817 .await
1818 .unwrap()
1819 .call(invocation1)
1820 .await
1821 .unwrap();
1822 assert_eq!(result1.result["handoff"], "specialist");
1823
1824 let invocation2 = ToolInvocation {
1826 id: "test_id2".to_string(),
1827 name: "handoff_to_supervisor".to_string(),
1828 arguments: serde_json::json!({"reason": "Escalation needed"}),
1829 };
1830
1831 let result2 = ServiceExt::ready(&mut service)
1832 .await
1833 .unwrap()
1834 .call(invocation2)
1835 .await
1836 .unwrap();
1837 assert_eq!(result2.result["handoff"], "supervisor");
1838
1839 let tools = layer.handoff_tools();
1841 assert_eq!(tools.len(), 2);
1842 let tool_names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect();
1843 assert!(tool_names.contains(&"handoff_to_specialist"));
1844 assert!(tool_names.contains(&"handoff_to_supervisor"));
1845 }
1846
1847 #[tokio::test]
1848 async fn handoff_layer_error_handling() {
1849 let policy = explicit_handoff_to("specialist");
1850 let layer = HandoffLayer::new(policy);
1851 let mut service = layer.layer(mock_tool_service());
1852
1853 let invocation = ToolInvocation {
1855 id: "test_id".to_string(),
1856 name: "handoff_to_specialist".to_string(),
1857 arguments: serde_json::json!({}), };
1859
1860 let result = ServiceExt::ready(&mut service)
1862 .await
1863 .unwrap()
1864 .call(invocation)
1865 .await
1866 .unwrap();
1867 assert_eq!(result.result["handoff"], "specialist");
1868 }
1869 }
1870
1871 mod handoff_coordinator_tests {
1876 #![allow(deprecated)]
1877 use super::*;
1878 use crate::provider::{FixedProvider, ProviderResponse};
1879 use crate::validation::{gen, validate_conversation, ValidationPolicy};
1880 use crate::Agent;
1881 use async_openai::{config::OpenAIConfig, Client};
1882 use proptest::prelude::*;
1883 use std::sync::Arc;
1884 use tower::{service_fn, ServiceExt};
1885
1886 fn mock_agent(name: &'static str, response: &'static str) -> AgentSvc {
1888 let name = name.to_string();
1889 let response = response.to_string();
1890 tower::util::BoxService::new(service_fn(move |_req: CreateChatCompletionRequest| {
1891 let name = name.clone();
1892 let response = response.clone();
1893 async move {
1894 use async_openai::types::{
1895 ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
1896 };
1897
1898 let message = ChatCompletionRequestAssistantMessageArgs::default()
1899 .content(format!("[{}]: {}", name, response))
1900 .build()?;
1901
1902 Ok::<AgentRun, BoxError>(AgentRun {
1903 messages: vec![ChatCompletionRequestMessage::Assistant(message)],
1904 steps: 1,
1905 stop: AgentStopReason::DoneNoToolCalls,
1906 })
1907 }
1908 }))
1909 }
1910
1911 #[derive(Clone)]
1913 struct MockPicker;
1914
1915 impl Service<PickRequest> for MockPicker {
1916 type Response = String;
1917 type Error = BoxError;
1918 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1919
1920 fn poll_ready(
1921 &mut self,
1922 _: &mut std::task::Context<'_>,
1923 ) -> std::task::Poll<Result<(), Self::Error>> {
1924 std::task::Poll::Ready(Ok(()))
1925 }
1926
1927 fn call(&mut self, req: PickRequest) -> Self::Future {
1928 Box::pin(async move {
1929 let content = req.messages.first()
1930 .and_then(|msg| match msg {
1931 async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
1932 match &user_msg.content {
1933 async_openai::types::ChatCompletionRequestUserMessageContent::Text(text) => {
1934 Some(text.as_str())
1935 }
1936 _ => None,
1937 }
1938 }
1939 _ => None,
1940 })
1941 .unwrap_or("");
1942
1943 let agent = if content.contains("billing") {
1944 "billing_agent"
1945 } else if content.contains("technical") {
1946 "tech_agent"
1947 } else {
1948 "triage_agent"
1949 };
1950
1951 Ok::<String, BoxError>(agent.to_string())
1952 })
1953 }
1954 }
1955
1956 fn mock_picker() -> MockPicker {
1957 MockPicker
1958 }
1959
1960 #[tokio::test]
1961 async fn handoff_coordinator_basic_operation() -> Result<(), BoxError> {
1962 let coordinator = GroupBuilder::new()
1963 .agent(
1964 "triage_agent",
1965 mock_agent("triage", "I'll handle your request"),
1966 )
1967 .agent(
1968 "billing_agent",
1969 mock_agent("billing", "Billing issue resolved"),
1970 )
1971 .picker(mock_picker())
1972 .handoff_policy(explicit_handoff_to("billing_agent"))
1973 .build();
1974
1975 let mut service = coordinator;
1976
1977 use async_openai::types::{
1978 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
1979 };
1980 let user_message = ChatCompletionRequestUserMessageArgs::default()
1981 .content("I have a general question")
1982 .build()?;
1983
1984 let request = CreateChatCompletionRequest {
1985 messages: vec![ChatCompletionRequestMessage::User(user_message)],
1986 model: "gpt-4o".to_string(),
1987 ..Default::default()
1988 };
1989
1990 let result = ServiceExt::ready(&mut service).await?.call(request).await?;
1991
1992 assert_eq!(result.messages.len(), 1);
1994 assert!(
1995 format!("{:?}", result.messages[0]).contains("[triage]: I'll handle your request")
1996 );
1997 assert_eq!(result.steps, 1);
1998 let policy = ValidationPolicy {
1999 allow_repeated_roles: true,
2000 require_user_first: false,
2001 require_user_present: false,
2002 ..Default::default()
2003 };
2004 assert!(validate_conversation(&result.messages, &policy).is_none());
2005
2006 Ok(())
2007 }
2008
2009 #[tokio::test]
2010 async fn handoff_coordinator_sequential_workflow() -> Result<(), BoxError> {
2011 let sequential_policy = sequential_handoff(vec![
2012 "researcher".to_string(),
2013 "writer".to_string(),
2014 "reviewer".to_string(),
2015 ]);
2016
2017 #[derive(Clone)]
2018 struct ResearchPicker;
2019 impl Service<PickRequest> for ResearchPicker {
2020 type Response = String;
2021 type Error = BoxError;
2022 type Future =
2023 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2024
2025 fn poll_ready(
2026 &mut self,
2027 _: &mut std::task::Context<'_>,
2028 ) -> std::task::Poll<Result<(), Self::Error>> {
2029 std::task::Poll::Ready(Ok(()))
2030 }
2031
2032 fn call(&mut self, _req: PickRequest) -> Self::Future {
2033 Box::pin(async move { Ok::<String, BoxError>("researcher".to_string()) })
2034 }
2035 }
2036
2037 let coordinator = GroupBuilder::new()
2038 .agent("researcher", mock_agent("researcher", "Research complete"))
2039 .agent("writer", mock_agent("writer", "Article written"))
2040 .agent("reviewer", mock_agent("reviewer", "Review complete"))
2041 .picker(ResearchPicker)
2042 .handoff_policy(sequential_policy)
2043 .build();
2044
2045 let mut service = coordinator;
2046
2047 use async_openai::types::{
2048 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2049 };
2050 let user_message = ChatCompletionRequestUserMessageArgs::default()
2051 .content("Write an article about AI")
2052 .build()?;
2053
2054 let request = CreateChatCompletionRequest {
2055 messages: vec![ChatCompletionRequestMessage::User(user_message)],
2056 model: "gpt-4o".to_string(),
2057 ..Default::default()
2058 };
2059
2060 let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2061
2062 assert!(!result.messages.is_empty());
2064 assert!(format!("{:?}", result.messages[0]).contains("[researcher]: Research complete"));
2067 let policy = ValidationPolicy {
2068 allow_repeated_roles: true,
2069 require_user_first: false,
2070 require_user_present: false,
2071 ..Default::default()
2072 };
2073 assert!(validate_conversation(&result.messages, &policy).is_none());
2074
2075 Ok(())
2076 }
2077
2078 #[tokio::test]
2079 async fn handoff_coordinator_picker_routing() -> Result<(), BoxError> {
2080 let coordinator = GroupBuilder::new()
2081 .agent("triage_agent", mock_agent("triage", "General help"))
2082 .agent("billing_agent", mock_agent("billing", "Billing help"))
2083 .agent("tech_agent", mock_agent("tech", "Technical help"))
2084 .picker(mock_picker())
2085 .handoff_policy(explicit_handoff_to("tech_agent"))
2086 .build();
2087
2088 let mut service = coordinator;
2089
2090 use async_openai::types::{
2092 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2093 };
2094 let billing_message = ChatCompletionRequestUserMessageArgs::default()
2095 .content("I have a billing question")
2096 .build()?;
2097
2098 let billing_request = CreateChatCompletionRequest {
2099 messages: vec![ChatCompletionRequestMessage::User(billing_message)],
2100 model: "gpt-4o".to_string(),
2101 ..Default::default()
2102 };
2103
2104 let result = ServiceExt::ready(&mut service)
2105 .await?
2106 .call(billing_request)
2107 .await?;
2108
2109 assert!(format!("{:?}", result.messages[0]).contains("billing"));
2111 let policy = ValidationPolicy {
2112 allow_repeated_roles: true,
2113 require_user_first: false,
2114 require_user_present: false,
2115 ..Default::default()
2116 };
2117 assert!(validate_conversation(&result.messages, &policy).is_none());
2118
2119 let tech_message = ChatCompletionRequestUserMessageArgs::default()
2121 .content("I have a technical issue")
2122 .build()?;
2123
2124 let tech_request = CreateChatCompletionRequest {
2125 messages: vec![ChatCompletionRequestMessage::User(tech_message)],
2126 model: "gpt-4o".to_string(),
2127 ..Default::default()
2128 };
2129
2130 let result2 = ServiceExt::ready(&mut service)
2131 .await?
2132 .call(tech_request)
2133 .await?;
2134
2135 assert!(format!("{:?}", result2.messages[0]).contains("tech"));
2137 let policy = ValidationPolicy {
2138 allow_repeated_roles: true,
2139 require_user_first: false,
2140 require_user_present: false,
2141 ..Default::default()
2142 };
2143 assert!(validate_conversation(&result2.messages, &policy).is_none());
2144
2145 Ok(())
2146 }
2147
2148 #[tokio::test]
2149 async fn handoff_coordinator_composite_policy() -> Result<(), BoxError> {
2150 let composite_policy = composite_handoff(vec![
2151 AnyHandoffPolicy::Explicit(explicit_handoff_to("specialist")),
2152 AnyHandoffPolicy::Sequential(sequential_handoff(vec![
2153 "a".to_string(),
2154 "b".to_string(),
2155 ])),
2156 ]);
2157
2158 #[derive(Clone)]
2159 struct TriagePicker;
2160 impl Service<PickRequest> for TriagePicker {
2161 type Response = String;
2162 type Error = BoxError;
2163 type Future =
2164 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2165
2166 fn poll_ready(
2167 &mut self,
2168 _: &mut std::task::Context<'_>,
2169 ) -> std::task::Poll<Result<(), Self::Error>> {
2170 std::task::Poll::Ready(Ok(()))
2171 }
2172
2173 fn call(&mut self, _req: PickRequest) -> Self::Future {
2174 Box::pin(async move { Ok::<String, BoxError>("triage".to_string()) })
2175 }
2176 }
2177
2178 let coordinator = GroupBuilder::new()
2179 .agent("triage", mock_agent("triage", "Triage response"))
2180 .agent(
2181 "specialist",
2182 mock_agent("specialist", "Specialist response"),
2183 )
2184 .agent("a", mock_agent("a", "Agent A response"))
2185 .agent("b", mock_agent("b", "Agent B response"))
2186 .picker(TriagePicker)
2187 .handoff_policy(composite_policy)
2188 .build();
2189
2190 let tools = coordinator.handoff_tools();
2192 assert_eq!(tools.len(), 1); assert_eq!(tools[0].function.name, "handoff_to_specialist");
2194
2195 let mut service = coordinator;
2196
2197 use async_openai::types::{
2198 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2199 };
2200 let user_message = ChatCompletionRequestUserMessageArgs::default()
2201 .content("Help me")
2202 .build()?;
2203
2204 let request = CreateChatCompletionRequest {
2205 messages: vec![ChatCompletionRequestMessage::User(user_message)],
2206 model: "gpt-4o".to_string(),
2207 ..Default::default()
2208 };
2209
2210 let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2211
2212 assert!(!result.messages.is_empty());
2214 assert!(format!("{:?}", result.messages[0]).contains("[triage]: Triage response"));
2215 let policy = ValidationPolicy {
2216 allow_repeated_roles: true,
2217 require_user_first: false,
2218 require_user_present: false,
2219 ..Default::default()
2220 };
2221 assert!(validate_conversation(&result.messages, &policy).is_none());
2222
2223 Ok(())
2224 }
2225
2226 #[tokio::test]
2227 async fn handoff_coordinator_error_handling() -> Result<(), BoxError> {
2228 #[derive(Clone)]
2229 struct AgentAPicker;
2230 impl Service<PickRequest> for AgentAPicker {
2231 type Response = String;
2232 type Error = BoxError;
2233 type Future =
2234 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2235
2236 fn poll_ready(
2237 &mut self,
2238 _: &mut std::task::Context<'_>,
2239 ) -> std::task::Poll<Result<(), Self::Error>> {
2240 std::task::Poll::Ready(Ok(()))
2241 }
2242
2243 fn call(&mut self, _req: PickRequest) -> Self::Future {
2244 Box::pin(async move { Ok::<String, BoxError>("agent_a".to_string()) })
2245 }
2246 }
2247
2248 let coordinator = GroupBuilder::new()
2249 .agent("agent_a", mock_agent("a", "Response A"))
2250 .picker(AgentAPicker)
2251 .handoff_policy(explicit_handoff_to("nonexistent_agent"))
2252 .build();
2253
2254 let mut service = coordinator;
2255
2256 use async_openai::types::{
2257 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2258 };
2259 let user_message = ChatCompletionRequestUserMessageArgs::default()
2260 .content("Test message")
2261 .build()?;
2262
2263 let request = CreateChatCompletionRequest {
2264 messages: vec![ChatCompletionRequestMessage::User(user_message)],
2265 model: "gpt-4o".to_string(),
2266 ..Default::default()
2267 };
2268
2269 let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2271
2272 assert_eq!(result.messages.len(), 1);
2273 assert!(format!("{:?}", result.messages[0]).contains("[a]: Response A"));
2274
2275 Ok(())
2276 }
2277
2278 #[tokio::test]
2279 async fn handoff_coordinator_max_handoffs_protection() -> Result<(), BoxError> {
2280 struct InfiniteHandoffPolicy;
2282
2283 impl HandoffPolicy for InfiniteHandoffPolicy {
2284 fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
2285 vec![]
2286 }
2287 fn handle_handoff_tool(
2288 &self,
2289 _: &ToolInvocation,
2290 ) -> Result<HandoffRequest, BoxError> {
2291 Err("No tools".into())
2292 }
2293 fn should_handoff(&self, _: &LoopState, _: &StepOutcome) -> Option<HandoffRequest> {
2294 Some(HandoffRequest {
2295 target_agent: "agent_b".to_string(),
2296 context: None,
2297 reason: Some("Infinite handoff test".to_string()),
2298 })
2299 }
2300 fn is_handoff_tool(&self, _: &str) -> bool {
2301 false
2302 }
2303 }
2304
2305 impl Clone for InfiniteHandoffPolicy {
2306 fn clone(&self) -> Self {
2307 InfiniteHandoffPolicy
2308 }
2309 }
2310
2311 #[derive(Clone)]
2312 struct StartPicker;
2313 impl Service<PickRequest> for StartPicker {
2314 type Response = String;
2315 type Error = BoxError;
2316 type Future =
2317 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2318
2319 fn poll_ready(
2320 &mut self,
2321 _: &mut std::task::Context<'_>,
2322 ) -> std::task::Poll<Result<(), Self::Error>> {
2323 std::task::Poll::Ready(Ok(()))
2324 }
2325
2326 fn call(&mut self, _req: PickRequest) -> Self::Future {
2327 Box::pin(async move { Ok::<String, BoxError>("agent_a".to_string()) })
2328 }
2329 }
2330
2331 let coordinator = GroupBuilder::new()
2332 .agent("agent_a", mock_agent("a", "Response A"))
2333 .agent("agent_b", mock_agent("b", "Response B"))
2334 .picker(StartPicker)
2335 .handoff_policy(InfiniteHandoffPolicy)
2336 .build();
2337
2338 let mut service = coordinator;
2339
2340 use async_openai::types::{
2341 ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2342 };
2343 let user_message = ChatCompletionRequestUserMessageArgs::default()
2344 .content("Test infinite handoff protection")
2345 .build()?;
2346
2347 let request = CreateChatCompletionRequest {
2348 messages: vec![ChatCompletionRequestMessage::User(user_message)],
2349 model: "gpt-4o".to_string(),
2350 ..Default::default()
2351 };
2352
2353 let result = ServiceExt::ready(&mut service).await?.call(request).await;
2354
2355 assert!(result.is_err());
2357 assert!(result
2358 .unwrap_err()
2359 .to_string()
2360 .contains("Maximum handoffs exceeded"));
2361
2362 Ok(())
2363 }
2364
2365 #[tokio::test]
2366 async fn handoff_coordinator_explicit_tool_transition() -> Result<(), BoxError> {
2367 let tool_name = "handoff_to_specialist".to_string();
2370 let tc = async_openai::types::ChatCompletionMessageToolCall {
2371 id: "call_1".to_string(),
2372 r#type: async_openai::types::ChatCompletionToolType::Function,
2373 function: async_openai::types::FunctionCall {
2374 name: tool_name.clone(),
2375 arguments: "{\"reason\":\"escalate\"}".to_string(),
2376 },
2377 };
2378 let assistant_triage = async_openai::types::ChatCompletionResponseMessage {
2379 content: None,
2380 role: async_openai::types::Role::Assistant,
2381 tool_calls: Some(vec![tc]),
2382 function_call: None,
2383 refusal: None,
2384 audio: None,
2385 };
2386 let triage_provider = FixedProvider::new(ProviderResponse {
2387 assistant: assistant_triage,
2388 prompt_tokens: 1,
2389 completion_tokens: 1,
2390 });
2391
2392 let assistant_specialist = async_openai::types::ChatCompletionResponseMessage {
2394 content: Some("[specialist]: done".to_string()),
2395 role: async_openai::types::Role::Assistant,
2396 tool_calls: None,
2397 function_call: None,
2398 refusal: None,
2399 audio: None,
2400 };
2401 let specialist_provider = FixedProvider::new(ProviderResponse {
2402 assistant: assistant_specialist,
2403 prompt_tokens: 1,
2404 completion_tokens: 1,
2405 });
2406
2407 let client = Arc::new(Client::<OpenAIConfig>::new());
2408 let triage_agent = Agent::builder(client.clone())
2409 .model("gpt-4o")
2410 .handoff_policy(explicit_handoff_to("specialist").into())
2411 .with_provider(triage_provider)
2412 .policy(crate::CompositePolicy::new(vec![
2413 crate::core::policies::max_steps(2),
2414 ]))
2415 .build();
2416 let specialist_agent = Agent::builder(client.clone())
2417 .model("gpt-4o")
2418 .with_provider(specialist_provider)
2419 .policy(crate::CompositePolicy::new(vec![
2420 crate::core::policies::max_steps(1),
2421 ]))
2422 .build();
2423
2424 #[derive(Clone)]
2425 struct Picker;
2426 impl Service<PickRequest> for Picker {
2427 type Response = String;
2428 type Error = BoxError;
2429 type Future =
2430 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2431 fn poll_ready(
2432 &mut self,
2433 _cx: &mut std::task::Context<'_>,
2434 ) -> std::task::Poll<Result<(), Self::Error>> {
2435 std::task::Poll::Ready(Ok(()))
2436 }
2437 fn call(&mut self, _req: PickRequest) -> Self::Future {
2438 Box::pin(async move { Ok::<_, BoxError>("triage".to_string()) })
2439 }
2440 }
2441
2442 let coordinator = GroupBuilder::new()
2443 .agent("triage", triage_agent)
2444 .agent("specialist", specialist_agent)
2445 .picker(Picker)
2446 .handoff_policy(explicit_handoff_to("specialist"))
2447 .build();
2448
2449 let mut svc = coordinator;
2450 let user = async_openai::types::ChatCompletionRequestUserMessageArgs::default()
2451 .content("start")
2452 .build()?;
2453 let req = CreateChatCompletionRequest {
2454 messages: vec![async_openai::types::ChatCompletionRequestMessage::User(
2455 user,
2456 )],
2457 model: "gpt-4o".to_string(),
2458 ..Default::default()
2459 };
2460
2461 let out = ServiceExt::ready(&mut svc).await?.call(req).await?;
2462 assert!(format!("{:?}", out.messages).contains("[specialist]: done"));
2464 let policy = ValidationPolicy {
2465 allow_repeated_roles: true,
2466 ..Default::default()
2467 };
2468 assert!(validate_conversation(&out.messages, &policy).is_none());
2469 Ok(())
2470 }
2471
2472 proptest! {
2473 #[test]
2474 fn handoff_preserves_validity_for_valid_inputs(msgs in gen::valid_conversation(gen::GeneratorConfig::default())) {
2475 use async_openai::types::CreateChatCompletionRequest;
2476 let tool_name = "handoff_to_specialist".to_string();
2478 let tc = async_openai::types::ChatCompletionMessageToolCall {
2479 id: "call_prop".to_string(),
2480 r#type: async_openai::types::ChatCompletionToolType::Function,
2481 function: async_openai::types::FunctionCall {
2482 name: tool_name.clone(),
2483 arguments: "{\"reason\":\"pbt\"}".to_string(),
2484 },
2485 };
2486 let assistant_triage = async_openai::types::ChatCompletionResponseMessage {
2487 content: None,
2488 role: async_openai::types::Role::Assistant,
2489 tool_calls: Some(vec![tc]),
2490 function_call: None,
2491 refusal: None,
2492 audio: None,
2493 };
2494 let triage_provider = FixedProvider::new(ProviderResponse {
2495 assistant: assistant_triage,
2496 prompt_tokens: 1,
2497 completion_tokens: 1,
2498 });
2499
2500 let assistant_specialist = async_openai::types::ChatCompletionResponseMessage {
2501 content: Some("[specialist]: ok".to_string()),
2502 role: async_openai::types::Role::Assistant,
2503 tool_calls: None,
2504 function_call: None,
2505 refusal: None,
2506 audio: None,
2507 };
2508 let specialist_provider = FixedProvider::new(ProviderResponse {
2509 assistant: assistant_specialist,
2510 prompt_tokens: 1,
2511 completion_tokens: 1,
2512 });
2513
2514 let client = Arc::new(Client::<OpenAIConfig>::new());
2515 let triage_agent = Agent::builder(client.clone())
2516 .model("gpt-4o")
2517 .handoff_policy(explicit_handoff_to("specialist").into())
2518 .with_provider(triage_provider)
2519 .policy(crate::CompositePolicy::new(vec![crate::core::policies::max_steps(2)]))
2520 .build();
2521 let specialist_agent = Agent::builder(client.clone())
2522 .model("gpt-4o")
2523 .with_provider(specialist_provider)
2524 .policy(crate::CompositePolicy::new(vec![crate::core::policies::max_steps(1)]))
2525 .build();
2526
2527 #[derive(Clone)]
2528 struct Picker;
2529 impl Service<PickRequest> for Picker {
2530 type Response = String;
2531 type Error = BoxError;
2532 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2533 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { std::task::Poll::Ready(Ok(())) }
2534 fn call(&mut self, _req: PickRequest) -> Self::Future { Box::pin(async move { Ok::<_, BoxError>("triage".to_string()) }) }
2535 }
2536
2537 let coordinator = GroupBuilder::new()
2538 .agent("triage", triage_agent)
2539 .agent("specialist", specialist_agent)
2540 .picker(Picker)
2541 .handoff_policy(explicit_handoff_to("specialist"))
2542 .build();
2543
2544 let mut svc = coordinator;
2545 let req = CreateChatCompletionRequest { messages: msgs, model: "gpt-4o".to_string(), ..Default::default() };
2546 let result = futures::executor::block_on(async move { ServiceExt::ready(&mut svc).await.unwrap().call(req).await }).unwrap();
2547 let policy = ValidationPolicy { allow_repeated_roles: true, allow_system_anywhere: true, require_user_first: false, require_user_present: false, ..Default::default() };
2548 prop_assert!(validate_conversation(&result.messages, &policy).is_none());
2549 }
2550 }
2551 }
2552
2553 #[tokio::test]
2558 async fn routes_to_named_agent() {
2559 let a: AgentSvc = BoxService::new(tower::service_fn(
2560 |_r: CreateChatCompletionRequest| async move {
2561 Ok::<_, BoxError>(AgentRun {
2562 messages: vec![],
2563 steps: 1,
2564 stop: AgentStopReason::DoneNoToolCalls,
2565 })
2566 },
2567 ));
2568 let picker =
2569 tower::service_fn(|_pr: PickRequest| async move { Ok::<_, BoxError>("a".to_string()) });
2570 let router = GroupBuilder::new().agent("a", a).picker(picker).build();
2571 let mut svc = router;
2572 let req = CreateChatCompletionRequestArgs::default()
2573 .model("gpt-4o")
2574 .messages(vec![])
2575 .build()
2576 .unwrap();
2577 let run = ServiceExt::ready(&mut svc)
2578 .await
2579 .unwrap()
2580 .call(req)
2581 .await
2582 .unwrap();
2583 assert_eq!(run.steps, 1);
2584 }
2585}