tower_llm/groups/
mod.rs

1//! Multi-agent orchestration and handoffs
2//!
3//! This module provides two key abstractions for agent coordination:
4//!
5//! ## AgentPicker - "WHO starts the conversation?"
6//! - Routes initial messages to appropriate agents
7//! - One-time decision at conversation start
8//! - Based on message content, user context, etc.
9//!
10//! ## HandoffPolicy - "HOW do agents collaborate?"  
11//! - Defines handoff tools and triggers
12//! - Runtime decisions during conversation
13//! - Supports explicit tools, sequential workflows, conditional logic
14//!
15//! ## Key Distinction
16//! - **Picker**: Choose starting agent based on conversation context
17//! - **Policy**: Define collaboration patterns between agents
18//!
19//! These work together but serve different purposes:
20//! ```rust,ignore
21//! let group = GroupBuilder::new()
22//!     .picker(route_by_topic())           // WHO: Route by message topic
23//!     .handoff_policy(explicit_handoffs()) // HOW: Agents use handoff tools
24//!     .build();
25//! ```
26//!
27//! What this module provides (spec)
28//! - A Tower-native router between multiple agent services with explicit handoff events
29//!
30//! Exports
31//! - Models
32//!   - `AgentName` newtype
33//!   - `PickRequest { messages, last_stop: AgentStopReason }`
34//!   - `HandoffRequest` - request for agent handoff
35//!   - `HandoffResponse` - result of handoff attempt
36//! - Services
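//!   - `GroupRouter: Service<CreateChatCompletionRequest, Response = AgentRun>` — single-agent routing
//!   - `HandoffCoordinator: Service<CreateChatCompletionRequest, Response = AgentRun>` — routing plus handoffs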
38//!   - `AgentPicker: Service<PickRequest, Response=AgentName>`
39//! - Layers
40//!   - `HandoffLayer` that annotates runs with AgentStart/AgentEnd/Handoff events
41//! - Traits
42//!   - `HandoffPolicy` - defines handoff tools and runtime behavior
43//! - Utils
44//!   - `GroupBuilder` to assemble named `AgentSvc`s and a picker strategy
45//!
46//! Implementation strategy
47//! - Use `tower::steer` or a small name→index map, routing to boxed `AgentSvc`s
48//! - `AgentPicker` decides next agent based on the current transcript and stop reason
49//! - `HandoffLayer` wraps the router to emit handoff events into the run
50//!
51//! Composition
52//! - `GroupBuilder::new().agent("triage", a).agent("specialist", b).picker(p).build()`
53//! - Can be wrapped by resilience/observability layers as needed
54//!
55//! Testing strategy
56//! - Build two fake agents that return deterministic responses
57//! - A picker that selects based on a message predicate
58//! - Assert the handoff events sequence and final run aggregation
59
60use std::collections::HashMap;
61use std::future::Future;
62use std::pin::Pin;
63use std::sync::{
64    atomic::{AtomicUsize, Ordering},
65    Arc,
66};
67
68use async_openai::types::{
69    ChatCompletionRequestMessage, ChatCompletionTool, CreateChatCompletionRequest,
70};
71use serde::{Deserialize, Serialize};
72use serde_json::Value;
73use tower::{BoxError, Layer, Service, ServiceExt};
74use tracing::{debug, error, info, instrument, trace, warn};
75
76use crate::core::{
77    AgentRun, AgentStopReason, AgentSvc, LoopState, StepOutcome, ToolInvocation, ToolOutput,
78};
79
/// Name under which an agent is registered in a group (the key of the group's agent map).
pub type AgentName = String;
81
/// Input to an [`AgentPicker`]: the conversation so far and why the previous agent stopped.
///
/// For the initial pick, [`GroupRouter`] and [`HandoffCoordinator`] set `last_stop` to
/// `AgentStopReason::DoneNoToolCalls` before any agent has run.
#[derive(Debug, Clone)]
pub struct PickRequest {
    /// Conversation messages the picker may inspect.
    pub messages: Vec<async_openai::types::ChatCompletionRequestMessage>,
    /// Stop reason reported for the most recent agent run.
    pub last_stop: AgentStopReason,
}
87
/// A service that chooses which agent (by [`AgentName`]) should handle a conversation.
///
/// This is a trait alias: every `Service<PickRequest, Response = AgentName, Error = BoxError>`
/// implements it automatically through the blanket impl that follows.
pub trait AgentPicker: Service<PickRequest, Response = AgentName, Error = BoxError> {}
impl<T> AgentPicker for T where T: Service<PickRequest, Response = AgentName, Error = BoxError> {}
90
91// ================================================================================================
92// Handoff Types
93// ================================================================================================
94
/// Request to handoff conversation to another agent.
///
/// Produced by a [`HandoffPolicy`], either from a handoff tool call
/// (`handle_handoff_tool`) or from an automatic decision (`should_handoff`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandoffRequest {
    /// Target agent to handoff to (must match a registered agent name)
    pub target_agent: String,
    /// Optional context data to pass to target agent
    pub context: Option<Value>,
    /// Optional reason for the handoff
    pub reason: Option<String>,
}
105
/// Response from a handoff attempt.
///
/// Pairs the outcome of a handoff with the agent it was addressed to.
#[derive(Debug, Clone)]
pub struct HandoffResponse {
    /// Whether the handoff was successful
    pub success: bool,
    /// The target agent (for confirmation)
    pub target_agent: String,
    /// Any context returned from the handoff
    pub context: Option<Value>,
}
116
/// Outcome of group coordination - either continue, handoff, or finish.
///
/// `Continue` and `Done` carry the agent run produced so far; `Handoff` carries the
/// request describing which agent should take over.
#[derive(Debug, Clone)]
pub enum GroupOutcome {
    /// Continue with current agent
    Continue(AgentRun),
    /// Handoff to another agent
    Handoff(HandoffRequest),
    /// Conversation is complete
    Done(AgentRun),
}
127
/// Trait defining handoff policies - how agents collaborate during execution.
///
/// This trait separates handoff behavior from initial agent routing:
/// - `AgentPicker`: WHO starts the conversation (initial routing)
/// - `HandoffPolicy`: HOW agents collaborate during conversation (runtime handoffs)
///
/// Implementations must be `Send + Sync + 'static`. [`HandoffCoordinator`] additionally
/// requires `Clone` and clones the policy into the future of every request it handles.
pub trait HandoffPolicy: Send + Sync + 'static {
    /// Generate handoff tools that the LLM can call.
    /// These tools will be injected into the agent's available tools.
    ///
    /// The coordinator asks for these before every agent invocation; an empty vector
    /// means nothing is injected.
    fn handoff_tools(&self) -> Vec<ChatCompletionTool>;

    /// Handle a handoff tool call by converting it to a HandoffRequest.
    /// This is called when the LLM invokes one of the handoff tools.
    ///
    /// # Errors
    /// Return `Err` if the invocation cannot be turned into a handoff. The coordinator
    /// logs the error and ignores that tool call.
    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError>;

    /// Make runtime handoff decisions based on agent state and step outcome.
    /// This allows for automatic handoffs based on conditions (e.g., no tools called).
    ///
    /// The coordinator only consults this when no handoff tool call produced a request.
    /// It passes `StepOutcome::Done` when the run stopped with `DoneNoToolCalls`, and
    /// `StepOutcome::Next` otherwise.
    fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest>;

    /// Check if a tool call is a handoff tool managed by this policy.
    fn is_handoff_tool(&self, tool_name: &str) -> bool;

    /// Transform messages during handoff.
    ///
    /// This method is called after a handoff is confirmed but before the messages
    /// are passed to the next agent. It allows for message transformation such as:
    /// - Compaction to manage context length
    /// - Privacy redaction when crossing trust boundaries
    /// - Context enrichment with relevant information
    /// - Translation between agents that operate in different languages
    ///
    /// The returned boxed future has no lifetime parameter (it is implicitly `'static`).
    /// It therefore cannot borrow `self` or the `&str` / `&HandoffRequest` arguments;
    /// clone anything it needs. If it resolves to `Err`, the coordinator logs a warning
    /// and falls back to the untransformed messages.
    ///
    /// The default implementation returns the messages unchanged.
    fn transform_on_handoff(
        &self,
        messages: Vec<ChatCompletionRequestMessage>,
        _from_agent: &str,
        _to_agent: &str,
        _handoff: &HandoffRequest,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<ChatCompletionRequestMessage>, BoxError>> + Send>>
    {
        Box::pin(async move { Ok(messages) })
    }
}
170
/// Builder for a [`GroupRouter`] or, once a handoff policy is set, a [`HandoffCoordinator`].
///
/// `P` and `H` start as `()`. They become the concrete picker and handoff-policy types
/// through [`GroupBuilder::picker`] and [`GroupBuilder::handoff_policy`]. Which `build`
/// method is available depends on whether `H` is still `()`.
pub struct GroupBuilder<P = (), H = ()> {
    // Agents keyed by name; registering the same name twice keeps the later agent.
    agents: HashMap<AgentName, AgentSvc>,
    // Filled in by `picker`.
    picker: Option<P>,
    // Filled in by `handoff_policy`.
    handoff_policy: Option<H>,
}
176
177impl GroupBuilder<(), ()> {
178    pub fn new() -> Self {
179        Self {
180            agents: HashMap::new(),
181            picker: None,
182            handoff_policy: None,
183        }
184    }
185}
186
187impl Default for GroupBuilder<(), ()> {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193impl<P, H> GroupBuilder<P, H> {
194    pub fn agent(mut self, name: impl Into<String>, svc: AgentSvc) -> Self {
195        self.agents.insert(name.into(), svc);
196        self
197    }
198
199    pub fn picker<NewP>(self, p: NewP) -> GroupBuilder<NewP, H> {
200        GroupBuilder {
201            agents: self.agents,
202            picker: Some(p),
203            handoff_policy: self.handoff_policy,
204        }
205    }
206
207    /// Add handoff policy to enable agent coordination
208    pub fn handoff_policy<NewH>(self, policy: NewH) -> GroupBuilder<P, NewH>
209    where
210        NewH: HandoffPolicy + Clone,
211    {
212        GroupBuilder {
213            agents: self.agents,
214            picker: self.picker,
215            handoff_policy: Some(policy),
216        }
217    }
218}
219
220impl<P> GroupBuilder<P, ()> {
221    /// Build a basic group without handoff coordination
222    pub fn build(self) -> GroupRouter<P>
223    where
224        P: AgentPicker + Clone + Send + 'static,
225        P::Future: Send + 'static,
226    {
227        GroupRouter {
228            agents: std::sync::Arc::new(tokio::sync::Mutex::new(self.agents)),
229            picker: self.picker.expect("picker"),
230        }
231    }
232}
233
234impl<P, H> GroupBuilder<P, H>
235where
236    H: HandoffPolicy + Clone + Send + 'static,
237{
238    /// Build a handoff-enabled group coordinator
239    pub fn build(self) -> HandoffCoordinator<P, H>
240    where
241        P: AgentPicker + Clone + Send + 'static,
242        P::Future: Send + 'static,
243    {
244        HandoffCoordinator::new(
245            self.agents,
246            self.picker.expect("picker"),
247            self.handoff_policy.expect("handoff_policy"),
248        )
249    }
250}
251
/// Routes each request to a single agent chosen by an [`AgentPicker`]; no handoffs.
///
/// All agents sit behind one async mutex because a call needs `&mut` access to the
/// selected `AgentSvc` (see the `Service` impl for how long the lock is held).
pub struct GroupRouter<P> {
    agents: std::sync::Arc<tokio::sync::Mutex<HashMap<AgentName, AgentSvc>>>,
    picker: P,
}
256
257impl<P> Service<CreateChatCompletionRequest> for GroupRouter<P>
258where
259    P: AgentPicker + Clone + Send + 'static,
260    P::Future: Send + 'static,
261{
262    type Response = AgentRun;
263    type Error = BoxError;
264    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
265
266    fn poll_ready(
267        &mut self,
268        _cx: &mut std::task::Context<'_>,
269    ) -> std::task::Poll<Result<(), Self::Error>> {
270        std::task::Poll::Ready(Ok(()))
271    }
272
273    fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
274        let mut picker = self.picker.clone();
275        let agents = self.agents.clone();
276        Box::pin(async move {
277            let pick = ServiceExt::ready(&mut picker)
278                .await?
279                .call(PickRequest {
280                    messages: req.messages.clone(),
281                    last_stop: AgentStopReason::DoneNoToolCalls,
282                })
283                .await?;
284            let mut guard = agents.lock().await;
285            let agent = guard
286                .get_mut(&pick)
287                .ok_or_else(|| format!("unknown agent: {}", pick))?;
288            let run = ServiceExt::ready(agent).await?.call(req).await?;
289            Ok(run)
290        })
291    }
292}
293
294// ================================================================================================
295// Handoff Coordinator - Enhanced GroupRouter
296// ================================================================================================
297
/// Enhanced group coordinator that manages handoffs between agents.
///
/// This coordinator:
/// 1. Uses AgentPicker for initial agent selection
/// 2. Injects the policy's handoff tools into each agent request and detects calls to them
/// 3. Orchestrates seamless agent transitions
/// 4. Maintains conversation context across handoffs
/// 5. Supports both explicit and automatic handoff triggers
pub struct HandoffCoordinator<P, H> {
    // Registered agents; locked while an agent is being called.
    agents: Arc<tokio::sync::Mutex<HashMap<AgentName, AgentSvc>>>,
    // Chooses the first agent for each request.
    picker: P,
    // Decides when and to whom control passes between agents.
    handoff_policy: H,
    // Agent most recently selected. Shared by every call on this coordinator, so
    // concurrent requests overwrite each other's value.
    current_agent: Arc<tokio::sync::Mutex<Option<AgentName>>>,
    // Messages recorded at each handoff. NOTE(review): shared by every call and not
    // cleared anywhere in this file's visible code, so it grows for the coordinator's
    // lifetime — confirm this is intended.
    conversation_context:
        Arc<tokio::sync::Mutex<Vec<async_openai::types::ChatCompletionRequestMessage>>>,
}
314
315impl<P, H> HandoffCoordinator<P, H>
316where
317    P: AgentPicker,
318    H: HandoffPolicy + Clone,
319{
320    /// Create new handoff coordinator.
321    pub fn new(agents: HashMap<AgentName, AgentSvc>, picker: P, handoff_policy: H) -> Self {
322        Self {
323            agents: Arc::new(tokio::sync::Mutex::new(agents)),
324            picker,
325            handoff_policy,
326            current_agent: Arc::new(tokio::sync::Mutex::new(None)),
327            conversation_context: Arc::new(tokio::sync::Mutex::new(Vec::new())),
328        }
329    }
330
331    /// Get the handoff tools that will be available to agents.
332    pub fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
333        self.handoff_policy.handoff_tools()
334    }
335}
336
337impl<P, H> Service<CreateChatCompletionRequest> for HandoffCoordinator<P, H>
338where
339    P: AgentPicker + Clone + Send + 'static,
340    P::Future: Send + 'static,
341    H: HandoffPolicy + Clone + Send + 'static,
342{
343    type Response = AgentRun;
344    type Error = BoxError;
345    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
346
347    fn poll_ready(
348        &mut self,
349        _cx: &mut std::task::Context<'_>,
350    ) -> std::task::Poll<Result<(), Self::Error>> {
351        std::task::Poll::Ready(Ok(()))
352    }
353
354    #[instrument(skip_all, fields(request_id = %uuid::Uuid::new_v4()))]
355    fn call(&mut self, mut request: CreateChatCompletionRequest) -> Self::Future {
356        let agents = self.agents.clone();
357        let mut picker = self.picker.clone();
358        let handoff_policy = self.handoff_policy.clone();
359        let current_agent = self.current_agent.clone();
360        let conversation_context = self.conversation_context.clone();
361
362        Box::pin(async move {
363            const MAX_HANDOFFS: usize = 10;
364            let mut handoff_count = 0;
365            let mut all_messages = Vec::new();
366            let mut total_steps = 0;
367
368            info!("🚀 Starting handoff coordinator");
369            debug!("Initial request has {} messages", request.messages.len());
370
371            // Store the original request messages
372            let original_messages = request.messages.clone();
373
374            // Pick the initial agent
375            debug!("Invoking picker to determine initial agent");
376            let initial_pick = ServiceExt::ready(&mut picker)
377                .await?
378                .call(PickRequest {
379                    messages: request.messages.clone(),
380                    last_stop: AgentStopReason::DoneNoToolCalls,
381                })
382                .await?;
383
384            info!("📍 Initial agent selected: {}", initial_pick);
385            let mut current_agent_name = initial_pick;
386
387            // Update current agent tracking
388            {
389                let mut current = current_agent.lock().await;
390                *current = Some(current_agent_name.clone());
391            }
392
393            // Main handoff loop
394            loop {
395                // Check handoff limit
396                if handoff_count >= MAX_HANDOFFS {
397                    error!("❌ Maximum handoffs exceeded ({})", MAX_HANDOFFS);
398                    return Err("Maximum handoffs exceeded".into());
399                }
400
401                info!(
402                    "🤖 Executing agent: {} (handoff #{}/{})",
403                    current_agent_name,
404                    handoff_count + 1,
405                    MAX_HANDOFFS
406                );
407
408                // Get the current agent
409                let mut agents_guard = agents.lock().await;
410                let agent = agents_guard.get_mut(&current_agent_name).ok_or_else(|| {
411                    error!("Agent not found: {}", current_agent_name);
412                    format!("Unknown agent: {}", current_agent_name)
413                })?;
414
415                // Inject handoff tools into the request
416                let handoff_tools = handoff_policy.handoff_tools();
417                if !handoff_tools.is_empty() {
418                    debug!(
419                        "Injecting {} handoff tools into request",
420                        handoff_tools.len()
421                    );
422                    trace!(
423                        "Handoff tools: {:?}",
424                        handoff_tools
425                            .iter()
426                            .map(|t| &t.function.name)
427                            .collect::<Vec<_>>()
428                    );
429
430                    // Add handoff tools to the request if not already present
431                    if request.tools.is_none() {
432                        request.tools = Some(handoff_tools);
433                    } else {
434                        // Append handoff tools to existing tools
435                        request.tools.as_mut().unwrap().extend(handoff_tools);
436                    }
437                }
438
439                // Execute the current agent
440                debug!("Calling agent with {} messages", request.messages.len());
441                // Capture prefix length to examine only new messages later
442                let delta_start = request.messages.len();
443                let agent_run = ServiceExt::ready(agent)
444                    .await?
445                    .call(request.clone())
446                    .await?;
447
448                info!(
449                    "✅ Agent {} completed: {} messages, {} steps, stop reason: {:?}",
450                    current_agent_name,
451                    agent_run.messages.len(),
452                    agent_run.steps,
453                    agent_run.stop
454                );
455
456                // Add messages from this run
457                all_messages.extend(agent_run.messages.clone());
458                total_steps += agent_run.steps;
459
460                // Check for handoff in the agent's response
461                let mut handoff_requested = None;
462
463                // Look for handoff tool calls in the response messages
464                debug!("Checking for handoff tool calls in agent response");
465                for message in agent_run.messages.iter().skip(delta_start) {
466                    if let async_openai::types::ChatCompletionRequestMessage::Assistant(msg) =
467                        message
468                    {
469                        if let Some(tool_calls) = &msg.tool_calls {
470                            trace!("Found {} tool calls in message", tool_calls.len());
471                            for tool_call in tool_calls {
472                                if handoff_policy.is_handoff_tool(&tool_call.function.name) {
473                                    info!("🔄 Handoff tool detected: {}", tool_call.function.name);
474
475                                    // Parse the handoff request from the tool call
476                                    let invocation = ToolInvocation {
477                                        id: tool_call.id.clone(),
478                                        name: tool_call.function.name.clone(),
479                                        arguments: serde_json::from_str(
480                                            &tool_call.function.arguments,
481                                        )
482                                        .unwrap_or_else(|e| {
483                                            warn!("Failed to parse handoff tool arguments: {}", e);
484                                            serde_json::json!({})
485                                        }),
486                                    };
487
488                                    match handoff_policy.handle_handoff_tool(&invocation) {
489                                        Ok(handoff_req) => {
490                                            info!(
491                                                "📋 Handoff request: {} → {} (reason: {:?})",
492                                                current_agent_name,
493                                                handoff_req.target_agent,
494                                                handoff_req.reason
495                                            );
496                                            handoff_requested = Some(handoff_req);
497                                            break;
498                                        }
499                                        Err(e) => {
500                                            warn!("Failed to handle handoff tool: {}", e);
501                                        }
502                                    }
503                                }
504                            }
505                        }
506                    }
507                    if handoff_requested.is_some() {
508                        break;
509                    }
510                }
511
512                // If no explicit handoff via tool, check policy for automatic handoff
513                if handoff_requested.is_none() {
514                    debug!(
515                        "No explicit handoff tool called, checking policy for automatic handoff"
516                    );
517
518                    // Create a LoopState for the policy check
519                    let loop_state = LoopState { steps: total_steps };
520
521                    // Convert AgentRun to StepOutcome for policy check
522                    let step_outcome = if matches!(agent_run.stop, AgentStopReason::DoneNoToolCalls)
523                    {
524                        StepOutcome::Done {
525                            messages: agent_run.messages.clone(),
526                            aux: crate::core::StepAux::default(),
527                        }
528                    } else {
529                        StepOutcome::Next {
530                            messages: agent_run.messages.clone(),
531                            aux: crate::core::StepAux::default(),
532                            invoked_tools: vec![],
533                        }
534                    };
535
536                    if let Some(handoff) = handoff_policy.should_handoff(&loop_state, &step_outcome)
537                    {
538                        info!(
539                            "🔀 Automatic handoff triggered by policy: {} → {} (reason: {:?})",
540                            current_agent_name, handoff.target_agent, handoff.reason
541                        );
542                        handoff_requested = Some(handoff);
543                    } else {
544                        debug!("No automatic handoff triggered");
545                    }
546                }
547
548                // Handle handoff if requested
549                if let Some(handoff) = handoff_requested {
550                    info!(
551                        "🚦 Processing handoff: {} → {}",
552                        current_agent_name, handoff.target_agent
553                    );
554
555                    // Update the current agent
556                    let previous_agent = current_agent_name.clone();
557                    current_agent_name = handoff.target_agent.clone();
558                    handoff_count += 1;
559
560                    // Update current agent tracking
561                    {
562                        let mut current = current_agent.lock().await;
563                        *current = Some(current_agent_name.clone());
564                    }
565
566                    // Update conversation context
567                    {
568                        let mut context = conversation_context.lock().await;
569                        context.extend(agent_run.messages.clone());
570                        debug!(
571                            "Updated conversation context with {} messages",
572                            agent_run.messages.len()
573                        );
574                    }
575
576                    // Prepare messages for next agent with accumulated context
577                    let mut messages_for_next = original_messages.clone();
578                    messages_for_next.extend(all_messages.clone());
579
580                    // Apply transformation if the policy provides one
581                    debug!(
582                        "Applying handoff transformation from {} to {}",
583                        previous_agent, current_agent_name
584                    );
585                    match handoff_policy
586                        .transform_on_handoff(
587                            messages_for_next,
588                            &previous_agent,
589                            &current_agent_name,
590                            &handoff,
591                        )
592                        .await
593                    {
594                        Ok(transformed) => {
595                            debug!(
596                                "Handoff transformation complete: {} -> {} messages",
597                                original_messages.len() + all_messages.len(),
598                                transformed.len()
599                            );
600                            request.messages = transformed;
601                        }
602                        Err(e) => {
603                            warn!(
604                                "Handoff transformation failed, using original messages: {}",
605                                e
606                            );
607                            request.messages = original_messages.clone();
608                            request.messages.extend(all_messages.clone());
609                        }
610                    }
611
612                    info!(
613                        "🔗 Handoff complete: {} → {} (total handoffs: {})",
614                        previous_agent, current_agent_name, handoff_count
615                    );
616
617                    // Continue to next iteration with new agent
618                    continue;
619                }
620
621                // No handoff, we're done
622                info!(
623                    "🎯 Workflow complete: {} total messages, {} steps, final agent: {}",
624                    all_messages.len(),
625                    total_steps,
626                    current_agent_name
627                );
628
629                return Ok(AgentRun {
630                    messages: all_messages,
631                    steps: total_steps,
632                    stop: agent_run.stop,
633                });
634            }
635        })
636    }
637}
638
639// ================================================================================================
640// Handoff Policy Implementations
641// ================================================================================================
642
/// Explicit handoff policy - generates a handoff tool for specific target agent.
/// The LLM can call this tool to explicitly handoff to the target agent.
#[derive(Debug, Clone)]
pub struct ExplicitHandoffPolicy {
    // Agent that receives control when the tool is called.
    target_agent: String,
    // Custom tool name; `None` means `handoff_to_{target_agent}`.
    tool_name: Option<String>,
    // Custom tool description; `None` means a generated default.
    description: Option<String>,
}

impl ExplicitHandoffPolicy {
    /// Create new explicit handoff policy for target agent.
    pub fn new(target_agent: impl Into<String>) -> Self {
        let target_agent: String = target_agent.into();
        Self {
            tool_name: None,
            description: None,
            target_agent,
        }
    }

    /// Set custom tool name (default: "handoff_to_{target}")
    pub fn with_tool_name(self, name: impl Into<String>) -> Self {
        Self {
            tool_name: Some(name.into()),
            ..self
        }
    }

    /// Set custom tool description
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }

    /// Effective tool name: the custom name if set, else `handoff_to_{target_agent}`.
    fn tool_name(&self) -> String {
        match &self.tool_name {
            Some(custom) => custom.clone(),
            None => format!("handoff_to_{}", self.target_agent),
        }
    }
}
680
681impl HandoffPolicy for ExplicitHandoffPolicy {
682    #[instrument(skip(self))]
683    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
684        let tool_name = self.tool_name();
685        let description = self
686            .description
687            .clone()
688            .unwrap_or_else(|| format!("Hand off the conversation to {}", self.target_agent));
689
690        debug!(
691            "ExplicitHandoffPolicy generating tool: {} → {}",
692            tool_name, self.target_agent
693        );
694
695        vec![ChatCompletionTool {
696            r#type: async_openai::types::ChatCompletionToolType::Function,
697            function: async_openai::types::FunctionObject {
698                name: tool_name,
699                description: Some(description),
700                parameters: Some(serde_json::json!({
701                    "type": "object",
702                    "properties": {
703                        "reason": {
704                            "type": "string",
705                            "description": "Reason for the handoff"
706                        },
707                        "context": {
708                            "type": "object",
709                            "description": "Optional context to pass to the target agent"
710                        }
711                    }
712                })),
713                ..Default::default()
714            },
715        }]
716    }
717
718    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
719        if !self.is_handoff_tool(&invocation.name) {
720            return Err(format!("Not a handoff tool: {}", invocation.name).into());
721        }
722
723        let reason = invocation
724            .arguments
725            .get("reason")
726            .and_then(|v| v.as_str())
727            .map(|s| s.to_string());
728
729        let context = invocation.arguments.get("context").cloned();
730
731        Ok(HandoffRequest {
732            target_agent: self.target_agent.clone(),
733            context,
734            reason,
735        })
736    }
737
738    fn should_handoff(&self, _state: &LoopState, _outcome: &StepOutcome) -> Option<HandoffRequest> {
739        // Explicit handoffs only trigger via tool calls, not automatically
740        None
741    }
742
743    fn is_handoff_tool(&self, tool_name: &str) -> bool {
744        tool_name == self.tool_name()
745    }
746}
747
/// Sequential handoff policy - automatically hands off to the next agent in sequence
/// when the current agent completes without tool calls.
///
/// The first handoff goes to `agents[1]`, so the picker is expected to start the
/// conversation with `agents[0]`. The position counter is behind an `Arc`, so clones
/// share it, and nothing in this type resets it.
#[derive(Debug, Clone)]
pub struct SequentialHandoffPolicy {
    agents: Vec<String>,
    current_index: Arc<AtomicUsize>,
}

impl SequentialHandoffPolicy {
    /// Create new sequential handoff policy with agent sequence.
    pub fn new(agents: Vec<String>) -> Self {
        let start = AtomicUsize::new(0);
        Self {
            agents,
            current_index: Arc::new(start),
        }
    }

    /// Advance the shared position and return the agent after the one just finished,
    /// or `None` once the sequence is exhausted (the counter keeps advancing regardless).
    fn next_agent(&self) -> Option<String> {
        let finished = self.current_index.fetch_add(1, Ordering::SeqCst);
        self.agents.get(finished + 1).cloned()
    }
}
774
775impl HandoffPolicy for SequentialHandoffPolicy {
776    #[instrument(skip(self))]
777    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
778        // Sequential handoffs are automatic, no tools needed
779        debug!("SequentialHandoffPolicy: no tools (automatic handoffs only)");
780        vec![]
781    }
782
783    #[instrument(skip(self, invocation))]
784    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
785        warn!(
786            "Sequential policy received unexpected handoff tool call: {}",
787            invocation.name
788        );
789        Err(format!(
790            "Sequential policy has no handoff tools: {}",
791            invocation.name
792        )
793        .into())
794    }
795
796    #[instrument(skip(self, _state, outcome))]
797    fn should_handoff(&self, _state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
798        match outcome {
799            StepOutcome::Done { .. } => {
800                // When agent completes without tool calls, move to next agent
801                if let Some(target) = self.next_agent() {
802                    let current_idx = self.current_index.load(Ordering::SeqCst);
803                    info!(
804                        "📈 Sequential handoff: step {}/{} → {}",
805                        current_idx,
806                        self.agents.len(),
807                        target
808                    );
809                    Some(HandoffRequest {
810                        target_agent: target,
811                        context: None,
812                        reason: Some("Sequential workflow step complete".to_string()),
813                    })
814                } else {
815                    debug!("Sequential workflow complete (all steps finished)");
816                    None
817                }
818            }
819            _ => {
820                trace!("Sequential policy: no handoff (agent not done)");
821                None
822            }
823        }
824    }
825
826    fn is_handoff_tool(&self, _tool_name: &str) -> bool {
827        false // No handoff tools for sequential policy
828    }
829}
830
/// Multi-target explicit handoff policy - supports multiple handoff targets.
/// Each tool name maps to a specific target agent.
#[derive(Debug, Clone)]
pub struct MultiExplicitHandoffPolicy {
    /// Tool name -> target agent name.
    handoffs: HashMap<String, String>,
}

impl MultiExplicitHandoffPolicy {
    /// Build a policy from an existing tool-name -> target-agent map.
    pub fn new(handoffs: HashMap<String, String>) -> Self {
        MultiExplicitHandoffPolicy { handoffs }
    }

    /// Register (or overwrite) the target agent for `tool_name` and return the updated policy.
    pub fn add_handoff(self, tool_name: impl Into<String>, target: impl Into<String>) -> Self {
        let mut handoffs = self.handoffs;
        handoffs.insert(tool_name.into(), target.into());
        MultiExplicitHandoffPolicy { handoffs }
    }
}
850
851impl HandoffPolicy for MultiExplicitHandoffPolicy {
852    #[instrument(skip(self))]
853    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
854        debug!(
855            "MultiExplicitHandoffPolicy generating {} handoff tools",
856            self.handoffs.len()
857        );
858        self.handoffs
859            .iter()
860            .map(|(tool_name, target_agent)| {
861                trace!("  Tool: {} → {}", tool_name, target_agent);
862                ChatCompletionTool {
863                    r#type: async_openai::types::ChatCompletionToolType::Function,
864                    function: async_openai::types::FunctionObject {
865                        name: tool_name.clone(),
866                        description: Some(format!("Hand off the conversation to {}", target_agent)),
867                        parameters: Some(serde_json::json!({
868                            "type": "object",
869                            "properties": {
870                                "reason": {
871                                    "type": "string",
872                                    "description": "Reason for the handoff"
873                                },
874                                "context": {
875                                    "type": "object",
876                                    "description": "Optional context to pass to the target agent"
877                                }
878                            }
879                        })),
880                        ..Default::default()
881                    },
882                }
883            })
884            .collect()
885    }
886
887    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
888        let target_agent = self
889            .handoffs
890            .get(&invocation.name)
891            .ok_or_else(|| format!("Not a handoff tool: {}", invocation.name))?;
892
893        let reason = invocation
894            .arguments
895            .get("reason")
896            .and_then(|v| v.as_str())
897            .map(|s| s.to_string());
898
899        let context = invocation.arguments.get("context").cloned();
900
901        Ok(HandoffRequest {
902            target_agent: target_agent.clone(),
903            context,
904            reason,
905        })
906    }
907
908    fn should_handoff(&self, _state: &LoopState, _outcome: &StepOutcome) -> Option<HandoffRequest> {
909        // Multi-explicit handoffs only trigger via tool calls, not automatically
910        None
911    }
912
913    fn is_handoff_tool(&self, tool_name: &str) -> bool {
914        self.handoffs.contains_key(tool_name)
915    }
916}
917
918/// Enum for composing different handoff policies.
919#[derive(Debug, Clone)]
920pub enum AnyHandoffPolicy {
921    Explicit(ExplicitHandoffPolicy),
922    MultiExplicit(MultiExplicitHandoffPolicy),
923    Sequential(SequentialHandoffPolicy),
924    Composite(CompositeHandoffPolicy),
925}
926
927impl From<ExplicitHandoffPolicy> for AnyHandoffPolicy {
928    fn from(policy: ExplicitHandoffPolicy) -> Self {
929        AnyHandoffPolicy::Explicit(policy)
930    }
931}
932
933impl From<SequentialHandoffPolicy> for AnyHandoffPolicy {
934    fn from(policy: SequentialHandoffPolicy) -> Self {
935        AnyHandoffPolicy::Sequential(policy)
936    }
937}
938
939impl From<MultiExplicitHandoffPolicy> for AnyHandoffPolicy {
940    fn from(policy: MultiExplicitHandoffPolicy) -> Self {
941        AnyHandoffPolicy::MultiExplicit(policy)
942    }
943}
944
945impl From<CompositeHandoffPolicy> for AnyHandoffPolicy {
946    fn from(policy: CompositeHandoffPolicy) -> Self {
947        AnyHandoffPolicy::Composite(policy)
948    }
949}
950
951impl HandoffPolicy for AnyHandoffPolicy {
952    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
953        match self {
954            AnyHandoffPolicy::Explicit(p) => p.handoff_tools(),
955            AnyHandoffPolicy::MultiExplicit(p) => p.handoff_tools(),
956            AnyHandoffPolicy::Sequential(p) => p.handoff_tools(),
957            AnyHandoffPolicy::Composite(p) => p.handoff_tools(),
958        }
959    }
960
961    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
962        match self {
963            AnyHandoffPolicy::Explicit(p) => p.handle_handoff_tool(invocation),
964            AnyHandoffPolicy::MultiExplicit(p) => p.handle_handoff_tool(invocation),
965            AnyHandoffPolicy::Sequential(p) => p.handle_handoff_tool(invocation),
966            AnyHandoffPolicy::Composite(p) => p.handle_handoff_tool(invocation),
967        }
968    }
969
970    fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
971        match self {
972            AnyHandoffPolicy::Explicit(p) => p.should_handoff(state, outcome),
973            AnyHandoffPolicy::MultiExplicit(p) => p.should_handoff(state, outcome),
974            AnyHandoffPolicy::Sequential(p) => p.should_handoff(state, outcome),
975            AnyHandoffPolicy::Composite(p) => p.should_handoff(state, outcome),
976        }
977    }
978
979    fn is_handoff_tool(&self, tool_name: &str) -> bool {
980        match self {
981            AnyHandoffPolicy::Explicit(p) => p.is_handoff_tool(tool_name),
982            AnyHandoffPolicy::MultiExplicit(p) => p.is_handoff_tool(tool_name),
983            AnyHandoffPolicy::Sequential(p) => p.is_handoff_tool(tool_name),
984            AnyHandoffPolicy::Composite(p) => p.is_handoff_tool(tool_name),
985        }
986    }
987}
988
989/// Composite handoff policy - combines multiple handoff policies.
990#[derive(Debug, Clone)]
991pub struct CompositeHandoffPolicy {
992    policies: Vec<AnyHandoffPolicy>,
993}
994
995impl CompositeHandoffPolicy {
996    /// Create new composite policy from a list of policies.
997    pub fn new(policies: Vec<AnyHandoffPolicy>) -> Self {
998        Self { policies }
999    }
1000}
1001
1002impl HandoffPolicy for CompositeHandoffPolicy {
1003    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1004        self.policies
1005            .iter()
1006            .flat_map(|p| p.handoff_tools())
1007            .collect()
1008    }
1009
1010    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
1011        for policy in &self.policies {
1012            if policy.is_handoff_tool(&invocation.name) {
1013                return policy.handle_handoff_tool(invocation);
1014            }
1015        }
1016        Err(format!("No policy handles handoff tool: {}", invocation.name).into())
1017    }
1018
1019    fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
1020        // Return first handoff decision from any policy
1021        for policy in &self.policies {
1022            if let Some(handoff) = policy.should_handoff(state, outcome) {
1023                return Some(handoff);
1024            }
1025        }
1026        None
1027    }
1028
1029    fn is_handoff_tool(&self, tool_name: &str) -> bool {
1030        self.policies.iter().any(|p| p.is_handoff_tool(tool_name))
1031    }
1032}
1033
1034// ================================================================================================
1035// CompactingHandoffPolicy - Wrapper that adds compaction on handoff
1036// ================================================================================================
1037
1038/// A handoff policy wrapper that applies conversation compaction during handoffs.
1039///
1040/// This policy wraps another handoff policy and adds automatic conversation compaction
1041/// when messages are transferred between agents. This helps manage context length
1042/// and can improve performance with long conversations.
1043#[derive(Clone)]
1044pub struct CompactingHandoffPolicy<P> {
1045    inner: P,
1046    compaction_policy: crate::auto_compaction::CompactionPolicy,
1047    provider: Arc<tokio::sync::Mutex<crate::provider::OpenAIProvider>>,
1048    token_counter: Arc<crate::auto_compaction::SimpleTokenCounter>,
1049}
1050
1051impl<P> CompactingHandoffPolicy<P> {
1052    /// Create a new compacting handoff policy wrapping the given policy.
1053    pub fn new(
1054        inner: P,
1055        compaction_policy: crate::auto_compaction::CompactionPolicy,
1056        provider: Arc<tokio::sync::Mutex<crate::provider::OpenAIProvider>>,
1057    ) -> Self {
1058        Self {
1059            inner,
1060            compaction_policy,
1061            provider,
1062            token_counter: Arc::new(crate::auto_compaction::SimpleTokenCounter::new()),
1063        }
1064    }
1065}
1066
1067// Dummy service for the compactor (we only use the compact_messages method)
1068struct DummyStepService;
1069
1070impl Service<CreateChatCompletionRequest> for DummyStepService {
1071    type Response = crate::core::StepOutcome;
1072    type Error = BoxError;
1073    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1074
1075    fn poll_ready(
1076        &mut self,
1077        _: &mut std::task::Context<'_>,
1078    ) -> std::task::Poll<Result<(), Self::Error>> {
1079        std::task::Poll::Ready(Ok(()))
1080    }
1081
1082    fn call(&mut self, _: CreateChatCompletionRequest) -> Self::Future {
1083        Box::pin(async {
1084            Ok(crate::core::StepOutcome::Done {
1085                messages: vec![],
1086                aux: Default::default(),
1087            })
1088        })
1089    }
1090}
1091
1092impl<P> HandoffPolicy for CompactingHandoffPolicy<P>
1093where
1094    P: HandoffPolicy + Clone,
1095{
1096    fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1097        self.inner.handoff_tools()
1098    }
1099
1100    fn handle_handoff_tool(&self, invocation: &ToolInvocation) -> Result<HandoffRequest, BoxError> {
1101        self.inner.handle_handoff_tool(invocation)
1102    }
1103
1104    fn should_handoff(&self, state: &LoopState, outcome: &StepOutcome) -> Option<HandoffRequest> {
1105        self.inner.should_handoff(state, outcome)
1106    }
1107
1108    fn is_handoff_tool(&self, tool_name: &str) -> bool {
1109        self.inner.is_handoff_tool(tool_name)
1110    }
1111
1112    fn transform_on_handoff(
1113        &self,
1114        messages: Vec<ChatCompletionRequestMessage>,
1115        from_agent: &str,
1116        to_agent: &str,
1117        handoff: &HandoffRequest,
1118    ) -> Pin<Box<dyn Future<Output = Result<Vec<ChatCompletionRequestMessage>, BoxError>> + Send>>
1119    {
1120        let compaction_policy = self.compaction_policy.clone();
1121        let provider = self.provider.clone();
1122        let token_counter = self.token_counter.clone();
1123        let inner_policy = self.inner.clone();
1124        let from_agent = from_agent.to_string();
1125        let to_agent = to_agent.to_string();
1126        let handoff = handoff.clone();
1127
1128        Box::pin(async move {
1129            // First apply the inner policy's transformation (if any)
1130            let messages = inner_policy
1131                .transform_on_handoff(messages, &from_agent, &to_agent, &handoff)
1132                .await?;
1133
1134            // Then apply compaction
1135            tracing::debug!(
1136                "Applying compaction during handoff from {} to {}",
1137                from_agent,
1138                to_agent
1139            );
1140
1141            // Create a temporary compactor
1142            let compactor = crate::auto_compaction::AutoCompaction::<
1143                DummyStepService,
1144                crate::provider::OpenAIProvider,
1145                crate::auto_compaction::SimpleTokenCounter,
1146            > {
1147                inner: Arc::new(tokio::sync::Mutex::new(DummyStepService)),
1148                policy: compaction_policy,
1149                provider,
1150                token_counter,
1151            };
1152
1153            match compactor.compact_messages(messages.clone()).await {
1154                Ok(compacted) => {
1155                    tracing::info!(
1156                        "Successfully compacted messages during handoff: {} -> {} messages",
1157                        messages.len(),
1158                        compacted.len()
1159                    );
1160                    Ok(compacted)
1161                }
1162                Err(e) => {
1163                    tracing::warn!(
1164                        "Compaction failed during handoff, using original messages: {}",
1165                        e
1166                    );
1167                    Ok(messages)
1168                }
1169            }
1170        })
1171    }
1172}
1173
1174// ================================================================================================
1175// Convenience Constructors
1176// ================================================================================================
1177
1178/// Create an explicit handoff policy for a target agent.
1179///
1180/// Example:
1181/// ```rust,ignore
1182/// let policy = explicit_handoff_to("specialist")
1183///     .with_description("Escalate complex issues to specialist");
1184/// ```
1185pub fn explicit_handoff_to(target: impl Into<String>) -> ExplicitHandoffPolicy {
1186    ExplicitHandoffPolicy::new(target)
1187}
1188
1189/// Create a sequential handoff policy with agent sequence.
1190///
1191/// Example:
1192/// ```rust,ignore  
1193/// let policy = sequential_handoff(vec!["researcher", "writer", "reviewer"]);
1194/// ```
1195pub fn sequential_handoff(agents: Vec<String>) -> SequentialHandoffPolicy {
1196    SequentialHandoffPolicy::new(agents)
1197}
1198
1199/// Create a composite handoff policy combining multiple policies.
1200///
1201/// Example:
1202/// ```rust,ignore
1203/// let policy = composite_handoff(vec![
1204///     AnyHandoffPolicy::Explicit(explicit_handoff_to("specialist")),
1205///     AnyHandoffPolicy::Sequential(sequential_handoff(vec!["a".to_string(), "b".to_string()])),
1206/// ]);
1207/// ```
1208pub fn composite_handoff(policies: Vec<AnyHandoffPolicy>) -> CompositeHandoffPolicy {
1209    CompositeHandoffPolicy::new(policies)
1210}
1211
1212// ================================================================================================
1213// Handoff Layer - Tower Integration
1214// ================================================================================================
1215
1216/// Enhanced ToolOutput that can signal handoff requests.
1217#[derive(Debug, Clone)]
1218pub enum ToolOutputResult {
1219    /// Regular tool output
1220    Tool(ToolOutput),
1221    /// Handoff request from a handoff tool
1222    Handoff(HandoffRequest),
1223}
1224
1225impl From<ToolOutput> for ToolOutputResult {
1226    fn from(output: ToolOutput) -> Self {
1227        ToolOutputResult::Tool(output)
1228    }
1229}
1230
1231/// Layer that adds handoff capabilities to tool services.
1232///
1233/// This layer:
1234/// 1. Wraps an existing tool service
1235/// 2. Adds handoff tools from the policy to the available tools
1236/// 3. Intercepts handoff tool calls and converts them to HandoffRequest
1237/// 4. Passes through regular tool calls unchanged
1238#[derive(Debug, Clone)]
1239pub struct HandoffLayer<P> {
1240    handoff_policy: P,
1241}
1242
1243impl<P> HandoffLayer<P>
1244where
1245    P: HandoffPolicy,
1246{
1247    /// Create a new HandoffLayer with the given handoff policy.
1248    pub fn new(policy: P) -> Self {
1249        Self {
1250            handoff_policy: policy,
1251        }
1252    }
1253
1254    /// Get the handoff tools that this layer will inject.
1255    pub fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
1256        self.handoff_policy.handoff_tools()
1257    }
1258}
1259
1260impl<S, P> Layer<S> for HandoffLayer<P>
1261where
1262    P: HandoffPolicy + Clone,
1263{
1264    type Service = HandoffService<S, P>;
1265
1266    fn layer(&self, inner: S) -> Self::Service {
1267        HandoffService::new(inner, self.handoff_policy.clone())
1268    }
1269}
1270
/// Service that wraps a tool service and intercepts handoff tool calls.
///
/// Both kinds of call resolve to a plain `ToolOutput`:
/// - Regular tools are forwarded to the inner service, and its output is returned as-is.
/// - Handoff tools are resolved through the policy's `handle_handoff_tool`. The result is a
///   `ToolOutput` whose JSON payload is
///   `{"handoff": <target>, "ack": true, "reason": ..., "context": ...}`. A policy error
///   becomes the call's error.
#[derive(Debug, Clone)]
pub struct HandoffService<S, P> {
    /// Tool service that receives every non-handoff call.
    inner: S,
    /// Policy that recognizes and resolves handoff tool calls.
    handoff_policy: P,
}
1281
1282/// Helper to wrap a `ToolRouter` with handoff interception and return a boxed tool service.
1283pub fn layer_tool_router_with_handoff<P>(
1284    router: crate::core::ToolRouter,
1285    policy: P,
1286) -> crate::core::ToolSvc
1287where
1288    P: HandoffPolicy + Clone,
1289{
1290    let layer = HandoffLayer::new(policy);
1291    let svc = layer.layer(router);
1292    tower::util::BoxCloneService::new(svc)
1293}
1294
1295impl<S, P> HandoffService<S, P>
1296where
1297    P: HandoffPolicy,
1298{
1299    /// Create a new HandoffService wrapping the inner service.
1300    pub fn new(inner: S, policy: P) -> Self {
1301        Self {
1302            inner,
1303            handoff_policy: policy,
1304        }
1305    }
1306}
1307
1308impl<S, P> Service<ToolInvocation> for HandoffService<S, P>
1309where
1310    S: Service<ToolInvocation, Response = ToolOutput, Error = BoxError> + Send + 'static,
1311    S::Future: Send + 'static,
1312    P: HandoffPolicy + Clone,
1313{
1314    type Response = ToolOutput;
1315    type Error = BoxError;
1316    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1317
1318    fn poll_ready(
1319        &mut self,
1320        cx: &mut std::task::Context<'_>,
1321    ) -> std::task::Poll<Result<(), Self::Error>> {
1322        self.inner.poll_ready(cx)
1323    }
1324
1325    fn call(&mut self, req: ToolInvocation) -> Self::Future {
1326        // Check if this is a handoff tool call
1327        if self.handoff_policy.is_handoff_tool(&req.name) {
1328            // Handle handoff tool call
1329            let policy = self.handoff_policy.clone();
1330            Box::pin(async move {
1331                let handoff_request = policy.handle_handoff_tool(&req)?;
1332                let out = ToolOutput {
1333                    id: req.id,
1334                    result: serde_json::json!({
1335                        "handoff": handoff_request.target_agent,
1336                        "ack": true,
1337                        "reason": handoff_request.reason,
1338                        "context": handoff_request.context,
1339                    }),
1340                };
1341                Ok(out)
1342            })
1343        } else {
1344            // Regular tool call - pass through to inner service
1345            let future = self.inner.call(req);
1346            Box::pin(future)
1347        }
1348    }
1349}
1350
1351#[cfg(test)]
1352mod tests {
1353    use super::*;
1354    use async_openai::types::{CreateChatCompletionRequest, CreateChatCompletionRequestArgs};
1355    use tower::util::BoxService;
1356
1357    // ================================================================================================
1358    // Handoff Policy Tests
1359    // ================================================================================================
1360
1361    mod handoff_policy_tests {
1362        use super::*;
1363        use async_openai::types::ChatCompletionRequestUserMessageArgs;
1364
1365        #[test]
1366        fn explicit_handoff_policy_generates_correct_tools() {
1367            let policy =
1368                explicit_handoff_to("specialist").with_description("Escalate to specialist");
1369
1370            let tools = policy.handoff_tools();
1371            assert_eq!(tools.len(), 1);
1372
1373            let tool = &tools[0];
1374            assert_eq!(tool.function.name, "handoff_to_specialist");
1375            assert_eq!(
1376                tool.function.description,
1377                Some("Escalate to specialist".to_string())
1378            );
1379
1380            // Verify tool parameters schema
1381            let params = tool.function.parameters.as_ref().unwrap();
1382            assert!(params.get("properties").is_some());
1383            assert!(params["properties"]["reason"].is_object());
1384        }
1385
1386        #[test]
1387        fn explicit_handoff_policy_handles_tool_calls() {
1388            let policy = explicit_handoff_to("specialist");
1389
1390            let invocation = ToolInvocation {
1391                id: "test_id".to_string(),
1392                name: "handoff_to_specialist".to_string(),
1393                arguments: serde_json::json!({
1394                    "reason": "Complex technical issue",
1395                    "context": {"priority": "high"}
1396                }),
1397            };
1398
1399            let result = policy.handle_handoff_tool(&invocation).unwrap();
1400            assert_eq!(result.target_agent, "specialist");
1401            assert_eq!(result.reason, Some("Complex technical issue".to_string()));
1402            assert!(result.context.is_some());
1403        }
1404
1405        #[test]
1406        fn explicit_handoff_policy_rejects_wrong_tools() {
1407            let policy = explicit_handoff_to("specialist");
1408
1409            let invocation = ToolInvocation {
1410                id: "test_id".to_string(),
1411                name: "some_other_tool".to_string(),
1412                arguments: serde_json::json!({}),
1413            };
1414
1415            let result = policy.handle_handoff_tool(&invocation);
1416            assert!(result.is_err());
1417            assert!(result
1418                .unwrap_err()
1419                .to_string()
1420                .contains("Not a handoff tool"));
1421        }
1422
1423        #[test]
1424        fn explicit_handoff_policy_no_automatic_handoffs() {
1425            let policy = explicit_handoff_to("specialist");
1426            let state = LoopState { steps: 1 };
1427            let outcome = StepOutcome::Done {
1428                messages: vec![],
1429                aux: crate::core::StepAux::default(),
1430            };
1431
1432            // Explicit policies don't trigger automatic handoffs
1433            assert!(policy.should_handoff(&state, &outcome).is_none());
1434        }
1435
1436        #[tokio::test]
1437        async fn handoff_policy_default_transformation() {
1438            // Test that the default transform_on_handoff returns messages unchanged
1439            let policy = explicit_handoff_to("specialist");
1440
1441            let messages = vec![ChatCompletionRequestUserMessageArgs::default()
1442                .content("Test message")
1443                .build()
1444                .unwrap()
1445                .into()];
1446
1447            let handoff = HandoffRequest {
1448                target_agent: "specialist".to_string(),
1449                context: None,
1450                reason: Some("Test handoff".to_string()),
1451            };
1452
1453            let result = policy
1454                .transform_on_handoff(messages.clone(), "agent1", "specialist", &handoff)
1455                .await
1456                .unwrap();
1457
1458            assert_eq!(result.len(), messages.len());
1459            assert_eq!(result, messages);
1460        }
1461
1462        #[test]
1463        fn sequential_handoff_policy_advances_correctly() {
1464            let agents = vec!["a".to_string(), "b".to_string(), "c".to_string()];
1465            let policy = sequential_handoff(agents.clone());
1466
1467            let state = LoopState { steps: 1 };
1468            let outcome = StepOutcome::Done {
1469                messages: vec![],
1470                aux: crate::core::StepAux::default(),
1471            };
1472
1473            // First handoff: a -> b
1474            let handoff1 = policy.should_handoff(&state, &outcome).unwrap();
1475            assert_eq!(handoff1.target_agent, "b");
1476            assert!(handoff1.reason.is_some());
1477
1478            // Second handoff: b -> c
1479            let handoff2 = policy.should_handoff(&state, &outcome).unwrap();
1480            assert_eq!(handoff2.target_agent, "c");
1481
1482            // Third call: no more agents
1483            let handoff3 = policy.should_handoff(&state, &outcome);
1484            assert!(handoff3.is_none());
1485        }
1486
1487        #[test]
1488        fn sequential_handoff_policy_no_tools() {
1489            let policy = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1490
1491            // Sequential policies don't generate tools
1492            assert!(policy.handoff_tools().is_empty());
1493            assert!(!policy.is_handoff_tool("any_tool"));
1494        }
1495
1496        #[test]
1497        fn composite_handoff_policy_combines_tools() {
1498            let explicit1 = explicit_handoff_to("specialist");
1499            let explicit2 = explicit_handoff_to("supervisor");
1500            let sequential = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1501
1502            let composite = composite_handoff(vec![
1503                AnyHandoffPolicy::Explicit(explicit1),
1504                AnyHandoffPolicy::Explicit(explicit2),
1505                AnyHandoffPolicy::Sequential(sequential),
1506            ]);
1507
1508            let tools = composite.handoff_tools();
1509            // Should have 2 tools (from explicit policies), sequential has none
1510            assert_eq!(tools.len(), 2);
1511
1512            let tool_names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect();
1513            assert!(tool_names.contains(&"handoff_to_specialist"));
1514            assert!(tool_names.contains(&"handoff_to_supervisor"));
1515        }
1516
1517        #[test]
1518        fn composite_handoff_policy_routes_to_correct_handler() {
1519            let explicit = explicit_handoff_to("specialist");
1520            let sequential = sequential_handoff(vec!["a".to_string()]);
1521
1522            let composite = composite_handoff(vec![
1523                AnyHandoffPolicy::Explicit(explicit),
1524                AnyHandoffPolicy::Sequential(sequential),
1525            ]);
1526
1527            let invocation = ToolInvocation {
1528                id: "test_id".to_string(),
1529                name: "handoff_to_specialist".to_string(),
1530                arguments: serde_json::json!({"reason": "test"}),
1531            };
1532
1533            let result = composite.handle_handoff_tool(&invocation).unwrap();
1534            assert_eq!(result.target_agent, "specialist");
1535        }
1536
1537        #[test]
1538        fn composite_handoff_policy_first_match_wins() {
1539            let explicit = explicit_handoff_to("specialist");
1540            let sequential = sequential_handoff(vec!["a".to_string(), "b".to_string()]);
1541
1542            let composite = composite_handoff(vec![
1543                AnyHandoffPolicy::Explicit(explicit),
1544                AnyHandoffPolicy::Sequential(sequential),
1545            ]);
1546
1547            let state = LoopState { steps: 1 };
1548            let outcome = StepOutcome::Done {
1549                messages: vec![],
1550                aux: crate::core::StepAux::default(),
1551            };
1552
1553            // Explicit policy returns None, sequential returns Some
1554            // Should get sequential result since explicit is first but returns None
1555            let result = composite.should_handoff(&state, &outcome).unwrap();
1556            assert_eq!(result.target_agent, "b"); // First handoff in sequential
1557        }
1558
1559        #[test]
1560        fn any_handoff_policy_conversions_work() {
1561            let explicit = explicit_handoff_to("specialist");
1562            let sequential = sequential_handoff(vec!["a".to_string()]);
1563
1564            // Test From impls
1565            let _any1: AnyHandoffPolicy = explicit.into();
1566            let _any2: AnyHandoffPolicy = sequential.into();
1567
1568            // Should compile and work: conversions successful
1569        }
1570
1571        #[test]
1572        fn handoff_request_serialization() {
1573            let request = HandoffRequest {
1574                target_agent: "specialist".to_string(),
1575                context: Some(serde_json::json!({"priority": "high"})),
1576                reason: Some("Complex issue".to_string()),
1577            };
1578
1579            // Should serialize/deserialize correctly
1580            let json = serde_json::to_string(&request).unwrap();
1581            let deserialized: HandoffRequest = serde_json::from_str(&json).unwrap();
1582
1583            assert_eq!(deserialized.target_agent, request.target_agent);
1584            assert_eq!(deserialized.reason, request.reason);
1585            assert_eq!(deserialized.context, request.context);
1586        }
1587    }
1588
1589    #[tokio::test]
1590    async fn per_agent_instructions_applied_to_each_agent_request() {
1591        use crate::core::{Agent, CompositePolicy};
1592        use crate::provider::ProviderResponse;
1593        use async_openai::config::OpenAIConfig;
1594        use async_openai::types::{ChatCompletionResponseMessage, Role as RespRole};
1595        use std::sync::Arc as StdArc;
1596
1597        // Capturing provider that records the last request it received
1598        #[derive(Clone)]
1599        struct CapturingProvider {
1600            captured: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>>,
1601        }
1602        impl tower::Service<CreateChatCompletionRequest> for CapturingProvider {
1603            type Response = ProviderResponse;
1604            type Error = BoxError;
1605            type Future = std::pin::Pin<
1606                Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
1607            >;
1608            fn poll_ready(
1609                &mut self,
1610                _cx: &mut std::task::Context<'_>,
1611            ) -> std::task::Poll<Result<(), Self::Error>> {
1612                std::task::Poll::Ready(Ok(()))
1613            }
1614            fn call(&mut self, req: CreateChatCompletionRequest) -> Self::Future {
1615                let captured = self.captured.clone();
1616                Box::pin(async move {
1617                    *captured.lock().await = Some(req);
1618                    #[allow(deprecated)]
1619                    let assistant = ChatCompletionResponseMessage {
1620                        content: Some("ok".into()),
1621                        role: RespRole::Assistant,
1622                        tool_calls: None,
1623                        function_call: None,
1624                        refusal: None,
1625                        audio: None,
1626                    };
1627                    Ok(ProviderResponse {
1628                        assistant,
1629                        prompt_tokens: 1,
1630                        completion_tokens: 1,
1631                    })
1632                })
1633            }
1634        }
1635
1636        let client = StdArc::new(async_openai::Client::<OpenAIConfig>::new());
1637
1638        // Agent A with instructions "A"
1639        let cap_a: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
1640            StdArc::new(tokio::sync::Mutex::new(None));
1641        let provider_a = CapturingProvider {
1642            captured: cap_a.clone(),
1643        };
1644        let agent_a = Agent::builder(client.clone())
1645            .with_provider(provider_a)
1646            .model("gpt-4o")
1647            .instructions("A")
1648            .policy(CompositePolicy::new(vec![
1649                crate::core::policies::until_no_tool_calls(),
1650                crate::core::policies::max_steps(1),
1651            ]))
1652            .build();
1653
1654        // Agent B with instructions "B"
1655        let cap_b: StdArc<tokio::sync::Mutex<Option<CreateChatCompletionRequest>>> =
1656            StdArc::new(tokio::sync::Mutex::new(None));
1657        let provider_b = CapturingProvider {
1658            captured: cap_b.clone(),
1659        };
1660        let agent_b = Agent::builder(client.clone())
1661            .with_provider(provider_b)
1662            .model("gpt-4o")
1663            .instructions("B")
1664            .policy(CompositePolicy::new(vec![
1665                crate::core::policies::until_no_tool_calls(),
1666                crate::core::policies::max_steps(1),
1667            ]))
1668            .build();
1669
1670        // Coordinator with sequential handoff A -> B
1671        let mut agents = std::collections::HashMap::new();
1672        agents.insert("a".to_string(), agent_a);
1673        agents.insert("b".to_string(), agent_b);
1674        let picker =
1675            tower::service_fn(|_pr: PickRequest| async move { Ok::<_, BoxError>("a".to_string()) });
1676        let policy = SequentialHandoffPolicy::new(vec!["a".into(), "b".into()]);
1677        let mut coord = HandoffCoordinator::new(agents, picker, policy);
1678
1679        // User-only request
1680        let req = CreateChatCompletionRequestArgs::default()
1681            .model("gpt-4o")
1682            .messages(vec![
1683                async_openai::types::ChatCompletionRequestUserMessageArgs::default()
1684                    .content("hi")
1685                    .build()
1686                    .unwrap()
1687                    .into(),
1688            ])
1689            .build()
1690            .unwrap();
1691
1692        let _run = tower::ServiceExt::ready(&mut coord)
1693            .await
1694            .unwrap()
1695            .call(req)
1696            .await
1697            .unwrap();
1698
1699        // Assert captured requests for each agent have their own instructions as the first system
1700        let got_a = cap_a.lock().await.clone().expect("captured request a");
1701        match &got_a.messages[0] {
1702            ChatCompletionRequestMessage::System(s) => match &s.content {
1703                async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1704                    assert_eq!(t, "A");
1705                }
1706                _ => panic!("expected text content"),
1707            },
1708            _ => panic!("expected first message to be system"),
1709        }
1710        let got_b = cap_b.lock().await.clone().expect("captured request b");
1711        match &got_b.messages[0] {
1712            ChatCompletionRequestMessage::System(s) => match &s.content {
1713                async_openai::types::ChatCompletionRequestSystemMessageContent::Text(t) => {
1714                    assert_eq!(t, "B");
1715                }
1716                _ => panic!("expected text content"),
1717            },
1718            _ => panic!("expected first message to be system"),
1719        }
1720    }
1721
1722    // ================================================================================================
1723    // Handoff Layer Integration Tests
1724    // ================================================================================================
1725
1726    mod handoff_layer_tests {
1727        use super::*;
1728        use tower::{service_fn, ServiceExt};
1729
1730        // Mock tool service for testing
1731        fn mock_tool_service() -> tower::util::BoxService<ToolInvocation, ToolOutput, BoxError> {
1732            tower::util::BoxService::new(service_fn(|req: ToolInvocation| async move {
1733                Ok(ToolOutput {
1734                    id: req.id,
1735                    result: serde_json::json!({"status": "success", "tool": req.name}),
1736                })
1737            }))
1738        }
1739
1740        #[tokio::test]
1741        async fn handoff_layer_passes_through_regular_tools() {
1742            let policy = explicit_handoff_to("specialist");
1743            let layer = HandoffLayer::new(policy);
1744            let mut service = layer.layer(mock_tool_service());
1745
1746            let invocation = ToolInvocation {
1747                id: "test_id".to_string(),
1748                name: "regular_tool".to_string(),
1749                arguments: serde_json::json!({"param": "value"}),
1750            };
1751
1752            let result = ServiceExt::ready(&mut service)
1753                .await
1754                .unwrap()
1755                .call(invocation)
1756                .await
1757                .unwrap();
1758            assert_eq!(result.id, "test_id");
1759            assert_eq!(result.result["tool"], "regular_tool");
1760        }
1761
1762        #[tokio::test]
1763        async fn handoff_layer_intercepts_handoff_tools() {
1764            let policy = explicit_handoff_to("specialist");
1765            let layer = HandoffLayer::new(policy);
1766            let mut service = layer.layer(mock_tool_service());
1767
1768            let invocation = ToolInvocation {
1769                id: "test_id".to_string(),
1770                name: "handoff_to_specialist".to_string(),
1771                arguments: serde_json::json!({"reason": "Complex issue"}),
1772            };
1773
1774            let result = ServiceExt::ready(&mut service)
1775                .await
1776                .unwrap()
1777                .call(invocation)
1778                .await
1779                .unwrap();
1780            assert_eq!(result.id, "test_id");
1781            assert_eq!(result.result["handoff"], "specialist");
1782        }
1783
1784        #[tokio::test]
1785        async fn handoff_layer_exposes_policy_tools() {
1786            let policy =
1787                explicit_handoff_to("specialist").with_description("Escalate to specialist");
1788            let layer = HandoffLayer::new(policy);
1789
1790            let tools = layer.handoff_tools();
1791            assert_eq!(tools.len(), 1);
1792            assert_eq!(tools[0].function.name, "handoff_to_specialist");
1793            assert_eq!(
1794                tools[0].function.description,
1795                Some("Escalate to specialist".to_string())
1796            );
1797        }
1798
1799        #[tokio::test]
1800        async fn handoff_layer_with_composite_policy() {
1801            let composite = composite_handoff(vec![
1802                AnyHandoffPolicy::Explicit(explicit_handoff_to("specialist")),
1803                AnyHandoffPolicy::Explicit(explicit_handoff_to("supervisor")),
1804            ]);
1805
1806            let layer = HandoffLayer::new(composite.clone());
1807            let mut service = layer.layer(mock_tool_service());
1808
1809            // Test first handoff tool
1810            let invocation1 = ToolInvocation {
1811                id: "test_id1".to_string(),
1812                name: "handoff_to_specialist".to_string(),
1813                arguments: serde_json::json!({"reason": "Technical issue"}),
1814            };
1815
1816            let result1 = ServiceExt::ready(&mut service)
1817                .await
1818                .unwrap()
1819                .call(invocation1)
1820                .await
1821                .unwrap();
1822            assert_eq!(result1.result["handoff"], "specialist");
1823
1824            // Test second handoff tool
1825            let invocation2 = ToolInvocation {
1826                id: "test_id2".to_string(),
1827                name: "handoff_to_supervisor".to_string(),
1828                arguments: serde_json::json!({"reason": "Escalation needed"}),
1829            };
1830
1831            let result2 = ServiceExt::ready(&mut service)
1832                .await
1833                .unwrap()
1834                .call(invocation2)
1835                .await
1836                .unwrap();
1837            assert_eq!(result2.result["handoff"], "supervisor");
1838
1839            // Verify layer exposes both tools
1840            let tools = layer.handoff_tools();
1841            assert_eq!(tools.len(), 2);
1842            let tool_names: Vec<&str> = tools.iter().map(|t| t.function.name.as_str()).collect();
1843            assert!(tool_names.contains(&"handoff_to_specialist"));
1844            assert!(tool_names.contains(&"handoff_to_supervisor"));
1845        }
1846
1847        #[tokio::test]
1848        async fn handoff_layer_error_handling() {
1849            let policy = explicit_handoff_to("specialist");
1850            let layer = HandoffLayer::new(policy);
1851            let mut service = layer.layer(mock_tool_service());
1852
1853            // Try to call handoff tool with invalid arguments
1854            let invocation = ToolInvocation {
1855                id: "test_id".to_string(),
1856                name: "handoff_to_specialist".to_string(),
1857                arguments: serde_json::json!({}), // Missing required fields - should still work
1858            };
1859
1860            // Should succeed even with minimal arguments
1861            let result = ServiceExt::ready(&mut service)
1862                .await
1863                .unwrap()
1864                .call(invocation)
1865                .await
1866                .unwrap();
1867            assert_eq!(result.result["handoff"], "specialist");
1868        }
1869    }
1870
1871    // ================================================================================================
1872    // End-to-End Handoff Coordinator Tests
1873    // ================================================================================================
1874
1875    mod handoff_coordinator_tests {
1876        #![allow(deprecated)]
1877        use super::*;
1878        use crate::provider::{FixedProvider, ProviderResponse};
1879        use crate::validation::{gen, validate_conversation, ValidationPolicy};
1880        use crate::Agent;
1881        use async_openai::{config::OpenAIConfig, Client};
1882        use proptest::prelude::*;
1883        use std::sync::Arc;
1884        use tower::{service_fn, ServiceExt};
1885
1886        // Mock agents for testing
1887        fn mock_agent(name: &'static str, response: &'static str) -> AgentSvc {
1888            let name = name.to_string();
1889            let response = response.to_string();
1890            tower::util::BoxService::new(service_fn(move |_req: CreateChatCompletionRequest| {
1891                let name = name.clone();
1892                let response = response.clone();
1893                async move {
1894                    use async_openai::types::{
1895                        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
1896                    };
1897
1898                    let message = ChatCompletionRequestAssistantMessageArgs::default()
1899                        .content(format!("[{}]: {}", name, response))
1900                        .build()?;
1901
1902                    Ok::<AgentRun, BoxError>(AgentRun {
1903                        messages: vec![ChatCompletionRequestMessage::Assistant(message)],
1904                        steps: 1,
1905                        stop: AgentStopReason::DoneNoToolCalls,
1906                    })
1907                }
1908            }))
1909        }
1910
1911        // Mock picker that selects based on message content
1912        #[derive(Clone)]
1913        struct MockPicker;
1914
1915        impl Service<PickRequest> for MockPicker {
1916            type Response = String;
1917            type Error = BoxError;
1918            type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1919
1920            fn poll_ready(
1921                &mut self,
1922                _: &mut std::task::Context<'_>,
1923            ) -> std::task::Poll<Result<(), Self::Error>> {
1924                std::task::Poll::Ready(Ok(()))
1925            }
1926
1927            fn call(&mut self, req: PickRequest) -> Self::Future {
1928                Box::pin(async move {
1929                    let content = req.messages.first()
1930                    .and_then(|msg| match msg {
1931                        async_openai::types::ChatCompletionRequestMessage::User(user_msg) => {
1932                            match &user_msg.content {
1933                                async_openai::types::ChatCompletionRequestUserMessageContent::Text(text) => {
1934                                    Some(text.as_str())
1935                                }
1936                                _ => None,
1937                            }
1938                        }
1939                        _ => None,
1940                    })
1941                    .unwrap_or("");
1942
1943                    let agent = if content.contains("billing") {
1944                        "billing_agent"
1945                    } else if content.contains("technical") {
1946                        "tech_agent"
1947                    } else {
1948                        "triage_agent"
1949                    };
1950
1951                    Ok::<String, BoxError>(agent.to_string())
1952                })
1953            }
1954        }
1955
1956        fn mock_picker() -> MockPicker {
1957            MockPicker
1958        }
1959
1960        #[tokio::test]
1961        async fn handoff_coordinator_basic_operation() -> Result<(), BoxError> {
1962            let coordinator = GroupBuilder::new()
1963                .agent(
1964                    "triage_agent",
1965                    mock_agent("triage", "I'll handle your request"),
1966                )
1967                .agent(
1968                    "billing_agent",
1969                    mock_agent("billing", "Billing issue resolved"),
1970                )
1971                .picker(mock_picker())
1972                .handoff_policy(explicit_handoff_to("billing_agent"))
1973                .build();
1974
1975            let mut service = coordinator;
1976
1977            use async_openai::types::{
1978                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
1979            };
1980            let user_message = ChatCompletionRequestUserMessageArgs::default()
1981                .content("I have a general question")
1982                .build()?;
1983
1984            let request = CreateChatCompletionRequest {
1985                messages: vec![ChatCompletionRequestMessage::User(user_message)],
1986                model: "gpt-4o".to_string(),
1987                ..Default::default()
1988            };
1989
1990            let result = ServiceExt::ready(&mut service).await?.call(request).await?;
1991
1992            // Should route to triage_agent and return its response
1993            assert_eq!(result.messages.len(), 1);
1994            assert!(
1995                format!("{:?}", result.messages[0]).contains("[triage]: I'll handle your request")
1996            );
1997            assert_eq!(result.steps, 1);
1998            let policy = ValidationPolicy {
1999                allow_repeated_roles: true,
2000                require_user_first: false,
2001                require_user_present: false,
2002                ..Default::default()
2003            };
2004            assert!(validate_conversation(&result.messages, &policy).is_none());
2005
2006            Ok(())
2007        }
2008
2009        #[tokio::test]
2010        async fn handoff_coordinator_sequential_workflow() -> Result<(), BoxError> {
2011            let sequential_policy = sequential_handoff(vec![
2012                "researcher".to_string(),
2013                "writer".to_string(),
2014                "reviewer".to_string(),
2015            ]);
2016
2017            #[derive(Clone)]
2018            struct ResearchPicker;
2019            impl Service<PickRequest> for ResearchPicker {
2020                type Response = String;
2021                type Error = BoxError;
2022                type Future =
2023                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2024
2025                fn poll_ready(
2026                    &mut self,
2027                    _: &mut std::task::Context<'_>,
2028                ) -> std::task::Poll<Result<(), Self::Error>> {
2029                    std::task::Poll::Ready(Ok(()))
2030                }
2031
2032                fn call(&mut self, _req: PickRequest) -> Self::Future {
2033                    Box::pin(async move { Ok::<String, BoxError>("researcher".to_string()) })
2034                }
2035            }
2036
2037            let coordinator = GroupBuilder::new()
2038                .agent("researcher", mock_agent("researcher", "Research complete"))
2039                .agent("writer", mock_agent("writer", "Article written"))
2040                .agent("reviewer", mock_agent("reviewer", "Review complete"))
2041                .picker(ResearchPicker)
2042                .handoff_policy(sequential_policy)
2043                .build();
2044
2045            let mut service = coordinator;
2046
2047            use async_openai::types::{
2048                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2049            };
2050            let user_message = ChatCompletionRequestUserMessageArgs::default()
2051                .content("Write an article about AI")
2052                .build()?;
2053
2054            let request = CreateChatCompletionRequest {
2055                messages: vec![ChatCompletionRequestMessage::User(user_message)],
2056                model: "gpt-4o".to_string(),
2057                ..Default::default()
2058            };
2059
2060            let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2061
2062            // Should start with researcher and proceed through the sequence
2063            assert!(!result.messages.is_empty());
2064            // The sequential policy doesn't have automatic handoff implemented in our mock agents
2065            // So it should just return the researcher's response
2066            assert!(format!("{:?}", result.messages[0]).contains("[researcher]: Research complete"));
2067            let policy = ValidationPolicy {
2068                allow_repeated_roles: true,
2069                require_user_first: false,
2070                require_user_present: false,
2071                ..Default::default()
2072            };
2073            assert!(validate_conversation(&result.messages, &policy).is_none());
2074
2075            Ok(())
2076        }
2077
2078        #[tokio::test]
2079        async fn handoff_coordinator_picker_routing() -> Result<(), BoxError> {
2080            let coordinator = GroupBuilder::new()
2081                .agent("triage_agent", mock_agent("triage", "General help"))
2082                .agent("billing_agent", mock_agent("billing", "Billing help"))
2083                .agent("tech_agent", mock_agent("tech", "Technical help"))
2084                .picker(mock_picker())
2085                .handoff_policy(explicit_handoff_to("tech_agent"))
2086                .build();
2087
2088            let mut service = coordinator;
2089
2090            // Test billing routing
2091            use async_openai::types::{
2092                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2093            };
2094            let billing_message = ChatCompletionRequestUserMessageArgs::default()
2095                .content("I have a billing question")
2096                .build()?;
2097
2098            let billing_request = CreateChatCompletionRequest {
2099                messages: vec![ChatCompletionRequestMessage::User(billing_message)],
2100                model: "gpt-4o".to_string(),
2101                ..Default::default()
2102            };
2103
2104            let result = ServiceExt::ready(&mut service)
2105                .await?
2106                .call(billing_request)
2107                .await?;
2108
2109            // Should route directly to billing_agent
2110            assert!(format!("{:?}", result.messages[0]).contains("billing"));
2111            let policy = ValidationPolicy {
2112                allow_repeated_roles: true,
2113                require_user_first: false,
2114                require_user_present: false,
2115                ..Default::default()
2116            };
2117            assert!(validate_conversation(&result.messages, &policy).is_none());
2118
2119            // Test technical routing
2120            let tech_message = ChatCompletionRequestUserMessageArgs::default()
2121                .content("I have a technical issue")
2122                .build()?;
2123
2124            let tech_request = CreateChatCompletionRequest {
2125                messages: vec![ChatCompletionRequestMessage::User(tech_message)],
2126                model: "gpt-4o".to_string(),
2127                ..Default::default()
2128            };
2129
2130            let result2 = ServiceExt::ready(&mut service)
2131                .await?
2132                .call(tech_request)
2133                .await?;
2134
2135            // Should route directly to tech_agent
2136            assert!(format!("{:?}", result2.messages[0]).contains("tech"));
2137            let policy = ValidationPolicy {
2138                allow_repeated_roles: true,
2139                require_user_first: false,
2140                require_user_present: false,
2141                ..Default::default()
2142            };
2143            assert!(validate_conversation(&result2.messages, &policy).is_none());
2144
2145            Ok(())
2146        }
2147
2148        #[tokio::test]
2149        async fn handoff_coordinator_composite_policy() -> Result<(), BoxError> {
2150            let composite_policy = composite_handoff(vec![
2151                AnyHandoffPolicy::Explicit(explicit_handoff_to("specialist")),
2152                AnyHandoffPolicy::Sequential(sequential_handoff(vec![
2153                    "a".to_string(),
2154                    "b".to_string(),
2155                ])),
2156            ]);
2157
2158            #[derive(Clone)]
2159            struct TriagePicker;
2160            impl Service<PickRequest> for TriagePicker {
2161                type Response = String;
2162                type Error = BoxError;
2163                type Future =
2164                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2165
2166                fn poll_ready(
2167                    &mut self,
2168                    _: &mut std::task::Context<'_>,
2169                ) -> std::task::Poll<Result<(), Self::Error>> {
2170                    std::task::Poll::Ready(Ok(()))
2171                }
2172
2173                fn call(&mut self, _req: PickRequest) -> Self::Future {
2174                    Box::pin(async move { Ok::<String, BoxError>("triage".to_string()) })
2175                }
2176            }
2177
2178            let coordinator = GroupBuilder::new()
2179                .agent("triage", mock_agent("triage", "Triage response"))
2180                .agent(
2181                    "specialist",
2182                    mock_agent("specialist", "Specialist response"),
2183                )
2184                .agent("a", mock_agent("a", "Agent A response"))
2185                .agent("b", mock_agent("b", "Agent B response"))
2186                .picker(TriagePicker)
2187                .handoff_policy(composite_policy)
2188                .build();
2189
2190            // Verify handoff tools are exposed
2191            let tools = coordinator.handoff_tools();
2192            assert_eq!(tools.len(), 1); // Only explicit policy generates tools
2193            assert_eq!(tools[0].function.name, "handoff_to_specialist");
2194
2195            let mut service = coordinator;
2196
2197            use async_openai::types::{
2198                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2199            };
2200            let user_message = ChatCompletionRequestUserMessageArgs::default()
2201                .content("Help me")
2202                .build()?;
2203
2204            let request = CreateChatCompletionRequest {
2205                messages: vec![ChatCompletionRequestMessage::User(user_message)],
2206                model: "gpt-4o".to_string(),
2207                ..Default::default()
2208            };
2209
2210            let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2211
2212            // Should execute triage agent
2213            assert!(!result.messages.is_empty());
2214            assert!(format!("{:?}", result.messages[0]).contains("[triage]: Triage response"));
2215            let policy = ValidationPolicy {
2216                allow_repeated_roles: true,
2217                require_user_first: false,
2218                require_user_present: false,
2219                ..Default::default()
2220            };
2221            assert!(validate_conversation(&result.messages, &policy).is_none());
2222
2223            Ok(())
2224        }
2225
2226        #[tokio::test]
2227        async fn handoff_coordinator_error_handling() -> Result<(), BoxError> {
2228            #[derive(Clone)]
2229            struct AgentAPicker;
2230            impl Service<PickRequest> for AgentAPicker {
2231                type Response = String;
2232                type Error = BoxError;
2233                type Future =
2234                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2235
2236                fn poll_ready(
2237                    &mut self,
2238                    _: &mut std::task::Context<'_>,
2239                ) -> std::task::Poll<Result<(), Self::Error>> {
2240                    std::task::Poll::Ready(Ok(()))
2241                }
2242
2243                fn call(&mut self, _req: PickRequest) -> Self::Future {
2244                    Box::pin(async move { Ok::<String, BoxError>("agent_a".to_string()) })
2245                }
2246            }
2247
2248            let coordinator = GroupBuilder::new()
2249                .agent("agent_a", mock_agent("a", "Response A"))
2250                .picker(AgentAPicker)
2251                .handoff_policy(explicit_handoff_to("nonexistent_agent"))
2252                .build();
2253
2254            let mut service = coordinator;
2255
2256            use async_openai::types::{
2257                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2258            };
2259            let user_message = ChatCompletionRequestUserMessageArgs::default()
2260                .content("Test message")
2261                .build()?;
2262
2263            let request = CreateChatCompletionRequest {
2264                messages: vec![ChatCompletionRequestMessage::User(user_message)],
2265                model: "gpt-4o".to_string(),
2266                ..Default::default()
2267            };
2268
2269            // Should complete without handoff since explicit policy doesn't auto-handoff
2270            let result = ServiceExt::ready(&mut service).await?.call(request).await?;
2271
2272            assert_eq!(result.messages.len(), 1);
2273            assert!(format!("{:?}", result.messages[0]).contains("[a]: Response A"));
2274
2275            Ok(())
2276        }
2277
2278        #[tokio::test]
2279        async fn handoff_coordinator_max_handoffs_protection() -> Result<(), BoxError> {
2280            // Create a policy that always hands off to create infinite loop
2281            struct InfiniteHandoffPolicy;
2282
2283            impl HandoffPolicy for InfiniteHandoffPolicy {
2284                fn handoff_tools(&self) -> Vec<ChatCompletionTool> {
2285                    vec![]
2286                }
2287                fn handle_handoff_tool(
2288                    &self,
2289                    _: &ToolInvocation,
2290                ) -> Result<HandoffRequest, BoxError> {
2291                    Err("No tools".into())
2292                }
2293                fn should_handoff(&self, _: &LoopState, _: &StepOutcome) -> Option<HandoffRequest> {
2294                    Some(HandoffRequest {
2295                        target_agent: "agent_b".to_string(),
2296                        context: None,
2297                        reason: Some("Infinite handoff test".to_string()),
2298                    })
2299                }
2300                fn is_handoff_tool(&self, _: &str) -> bool {
2301                    false
2302                }
2303            }
2304
2305            impl Clone for InfiniteHandoffPolicy {
2306                fn clone(&self) -> Self {
2307                    InfiniteHandoffPolicy
2308                }
2309            }
2310
2311            #[derive(Clone)]
2312            struct StartPicker;
2313            impl Service<PickRequest> for StartPicker {
2314                type Response = String;
2315                type Error = BoxError;
2316                type Future =
2317                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2318
2319                fn poll_ready(
2320                    &mut self,
2321                    _: &mut std::task::Context<'_>,
2322                ) -> std::task::Poll<Result<(), Self::Error>> {
2323                    std::task::Poll::Ready(Ok(()))
2324                }
2325
2326                fn call(&mut self, _req: PickRequest) -> Self::Future {
2327                    Box::pin(async move { Ok::<String, BoxError>("agent_a".to_string()) })
2328                }
2329            }
2330
2331            let coordinator = GroupBuilder::new()
2332                .agent("agent_a", mock_agent("a", "Response A"))
2333                .agent("agent_b", mock_agent("b", "Response B"))
2334                .picker(StartPicker)
2335                .handoff_policy(InfiniteHandoffPolicy)
2336                .build();
2337
2338            let mut service = coordinator;
2339
2340            use async_openai::types::{
2341                ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2342            };
2343            let user_message = ChatCompletionRequestUserMessageArgs::default()
2344                .content("Test infinite handoff protection")
2345                .build()?;
2346
2347            let request = CreateChatCompletionRequest {
2348                messages: vec![ChatCompletionRequestMessage::User(user_message)],
2349                model: "gpt-4o".to_string(),
2350                ..Default::default()
2351            };
2352
2353            let result = ServiceExt::ready(&mut service).await?.call(request).await;
2354
2355            // Should error due to max handoffs exceeded
2356            assert!(result.is_err());
2357            assert!(result
2358                .unwrap_err()
2359                .to_string()
2360                .contains("Maximum handoffs exceeded"));
2361
2362            Ok(())
2363        }
2364
2365        #[tokio::test]
2366        async fn handoff_coordinator_explicit_tool_transition() -> Result<(), BoxError> {
2367            // Build two agents with custom providers (no network)
2368            // Agent "triage" emits a handoff tool call to "specialist"
2369            let tool_name = "handoff_to_specialist".to_string();
2370            let tc = async_openai::types::ChatCompletionMessageToolCall {
2371                id: "call_1".to_string(),
2372                r#type: async_openai::types::ChatCompletionToolType::Function,
2373                function: async_openai::types::FunctionCall {
2374                    name: tool_name.clone(),
2375                    arguments: "{\"reason\":\"escalate\"}".to_string(),
2376                },
2377            };
2378            let assistant_triage = async_openai::types::ChatCompletionResponseMessage {
2379                content: None,
2380                role: async_openai::types::Role::Assistant,
2381                tool_calls: Some(vec![tc]),
2382                function_call: None,
2383                refusal: None,
2384                audio: None,
2385            };
2386            let triage_provider = FixedProvider::new(ProviderResponse {
2387                assistant: assistant_triage,
2388                prompt_tokens: 1,
2389                completion_tokens: 1,
2390            });
2391
2392            // Agent "specialist" returns a plain assistant message
2393            let assistant_specialist = async_openai::types::ChatCompletionResponseMessage {
2394                content: Some("[specialist]: done".to_string()),
2395                role: async_openai::types::Role::Assistant,
2396                tool_calls: None,
2397                function_call: None,
2398                refusal: None,
2399                audio: None,
2400            };
2401            let specialist_provider = FixedProvider::new(ProviderResponse {
2402                assistant: assistant_specialist,
2403                prompt_tokens: 1,
2404                completion_tokens: 1,
2405            });
2406
2407            let client = Arc::new(Client::<OpenAIConfig>::new());
2408            let triage_agent = Agent::builder(client.clone())
2409                .model("gpt-4o")
2410                .handoff_policy(explicit_handoff_to("specialist").into())
2411                .with_provider(triage_provider)
2412                .policy(crate::CompositePolicy::new(vec![
2413                    crate::core::policies::max_steps(2),
2414                ]))
2415                .build();
2416            let specialist_agent = Agent::builder(client.clone())
2417                .model("gpt-4o")
2418                .with_provider(specialist_provider)
2419                .policy(crate::CompositePolicy::new(vec![
2420                    crate::core::policies::max_steps(1),
2421                ]))
2422                .build();
2423
2424            #[derive(Clone)]
2425            struct Picker;
2426            impl Service<PickRequest> for Picker {
2427                type Response = String;
2428                type Error = BoxError;
2429                type Future =
2430                    Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2431                fn poll_ready(
2432                    &mut self,
2433                    _cx: &mut std::task::Context<'_>,
2434                ) -> std::task::Poll<Result<(), Self::Error>> {
2435                    std::task::Poll::Ready(Ok(()))
2436                }
2437                fn call(&mut self, _req: PickRequest) -> Self::Future {
2438                    Box::pin(async move { Ok::<_, BoxError>("triage".to_string()) })
2439                }
2440            }
2441
2442            let coordinator = GroupBuilder::new()
2443                .agent("triage", triage_agent)
2444                .agent("specialist", specialist_agent)
2445                .picker(Picker)
2446                .handoff_policy(explicit_handoff_to("specialist"))
2447                .build();
2448
2449            let mut svc = coordinator;
2450            let user = async_openai::types::ChatCompletionRequestUserMessageArgs::default()
2451                .content("start")
2452                .build()?;
2453            let req = CreateChatCompletionRequest {
2454                messages: vec![async_openai::types::ChatCompletionRequestMessage::User(
2455                    user,
2456                )],
2457                model: "gpt-4o".to_string(),
2458                ..Default::default()
2459            };
2460
2461            let out = ServiceExt::ready(&mut svc).await?.call(req).await?;
2462            // Should include specialist message at the end after handoff
2463            assert!(format!("{:?}", out.messages).contains("[specialist]: done"));
2464            let policy = ValidationPolicy {
2465                allow_repeated_roles: true,
2466                ..Default::default()
2467            };
2468            assert!(validate_conversation(&out.messages, &policy).is_none());
2469            Ok(())
2470        }
2471
2472        proptest! {
2473            #[test]
2474            fn handoff_preserves_validity_for_valid_inputs(msgs in gen::valid_conversation(gen::GeneratorConfig::default())) {
2475                use async_openai::types::CreateChatCompletionRequest;
2476                // Build two agents with deterministic providers
2477                let tool_name = "handoff_to_specialist".to_string();
2478                let tc = async_openai::types::ChatCompletionMessageToolCall {
2479                    id: "call_prop".to_string(),
2480                    r#type: async_openai::types::ChatCompletionToolType::Function,
2481                    function: async_openai::types::FunctionCall {
2482                        name: tool_name.clone(),
2483                        arguments: "{\"reason\":\"pbt\"}".to_string(),
2484                    },
2485                };
2486                let assistant_triage = async_openai::types::ChatCompletionResponseMessage {
2487                    content: None,
2488                    role: async_openai::types::Role::Assistant,
2489                    tool_calls: Some(vec![tc]),
2490                    function_call: None,
2491                    refusal: None,
2492                    audio: None,
2493                };
2494                let triage_provider = FixedProvider::new(ProviderResponse {
2495                    assistant: assistant_triage,
2496                    prompt_tokens: 1,
2497                    completion_tokens: 1,
2498                });
2499
2500                let assistant_specialist = async_openai::types::ChatCompletionResponseMessage {
2501                    content: Some("[specialist]: ok".to_string()),
2502                    role: async_openai::types::Role::Assistant,
2503                    tool_calls: None,
2504                    function_call: None,
2505                    refusal: None,
2506                    audio: None,
2507                };
2508                let specialist_provider = FixedProvider::new(ProviderResponse {
2509                    assistant: assistant_specialist,
2510                    prompt_tokens: 1,
2511                    completion_tokens: 1,
2512                });
2513
2514                let client = Arc::new(Client::<OpenAIConfig>::new());
2515                let triage_agent = Agent::builder(client.clone())
2516                    .model("gpt-4o")
2517                    .handoff_policy(explicit_handoff_to("specialist").into())
2518                    .with_provider(triage_provider)
2519                    .policy(crate::CompositePolicy::new(vec![crate::core::policies::max_steps(2)]))
2520                    .build();
2521                let specialist_agent = Agent::builder(client.clone())
2522                    .model("gpt-4o")
2523                    .with_provider(specialist_provider)
2524                    .policy(crate::CompositePolicy::new(vec![crate::core::policies::max_steps(1)]))
2525                    .build();
2526
2527                #[derive(Clone)]
2528                struct Picker;
2529                impl Service<PickRequest> for Picker {
2530                    type Response = String;
2531                    type Error = BoxError;
2532                    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
2533                    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> { std::task::Poll::Ready(Ok(())) }
2534                    fn call(&mut self, _req: PickRequest) -> Self::Future { Box::pin(async move { Ok::<_, BoxError>("triage".to_string()) }) }
2535                }
2536
2537                let coordinator = GroupBuilder::new()
2538                    .agent("triage", triage_agent)
2539                    .agent("specialist", specialist_agent)
2540                    .picker(Picker)
2541                    .handoff_policy(explicit_handoff_to("specialist"))
2542                    .build();
2543
2544                let mut svc = coordinator;
2545                let req = CreateChatCompletionRequest { messages: msgs, model: "gpt-4o".to_string(), ..Default::default() };
2546                let result = futures::executor::block_on(async move { ServiceExt::ready(&mut svc).await.unwrap().call(req).await }).unwrap();
2547                let policy = ValidationPolicy { allow_repeated_roles: true, allow_system_anywhere: true, require_user_first: false, require_user_present: false, ..Default::default() };
2548                prop_assert!(validate_conversation(&result.messages, &policy).is_none());
2549            }
2550        }
2551    }
2552
2553    // ================================================================================================
2554    // Original Group Router Tests
2555    // ================================================================================================
2556
2557    #[tokio::test]
2558    async fn routes_to_named_agent() {
2559        let a: AgentSvc = BoxService::new(tower::service_fn(
2560            |_r: CreateChatCompletionRequest| async move {
2561                Ok::<_, BoxError>(AgentRun {
2562                    messages: vec![],
2563                    steps: 1,
2564                    stop: AgentStopReason::DoneNoToolCalls,
2565                })
2566            },
2567        ));
2568        let picker =
2569            tower::service_fn(|_pr: PickRequest| async move { Ok::<_, BoxError>("a".to_string()) });
2570        let router = GroupBuilder::new().agent("a", a).picker(picker).build();
2571        let mut svc = router;
2572        let req = CreateChatCompletionRequestArgs::default()
2573            .model("gpt-4o")
2574            .messages(vec![])
2575            .build()
2576            .unwrap();
2577        let run = ServiceExt::ready(&mut svc)
2578            .await
2579            .unwrap()
2580            .call(req)
2581            .await
2582            .unwrap();
2583        assert_eq!(run.steps, 1);
2584    }
2585}