1use std::collections::HashSet;
8use std::sync::Arc;
9
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use tokio_util::sync::CancellationToken;
14
15use crate::registry::AgentRegistry;
16use crate::tool::{AgentTool, AgentToolResult, ToolFuture, validated_schema_for};
17use crate::types::LlmMessage;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct TransferSignal {
29 target_agent: String,
30 reason: String,
31 #[serde(default, skip_serializing_if = "Option::is_none")]
32 context_summary: Option<String>,
33 #[serde(default)]
34 conversation_history: Vec<LlmMessage>,
35 #[serde(default, skip_serializing_if = "Option::is_none")]
36 transfer_chain: Option<TransferChain>,
37}
38
39impl TransferSignal {
40 pub fn new(target_agent: impl Into<String>, reason: impl Into<String>) -> Self {
42 Self {
43 target_agent: target_agent.into(),
44 reason: reason.into(),
45 context_summary: None,
46 conversation_history: Vec::new(),
47 transfer_chain: None,
48 }
49 }
50
51 #[must_use]
53 pub fn with_context_summary(mut self, summary: impl Into<String>) -> Self {
54 self.context_summary = Some(summary.into());
55 self
56 }
57
58 #[must_use]
63 pub fn with_conversation_history(mut self, history: Vec<LlmMessage>) -> Self {
64 self.conversation_history = history;
65 self
66 }
67
68 #[must_use]
73 pub fn with_transfer_chain(mut self, chain: TransferChain) -> Self {
74 self.transfer_chain = Some(chain);
75 self
76 }
77
78 pub fn target_agent(&self) -> &str {
80 &self.target_agent
81 }
82
83 pub fn reason(&self) -> &str {
85 &self.reason
86 }
87
88 pub fn context_summary(&self) -> Option<&str> {
90 self.context_summary.as_deref()
91 }
92
93 pub fn conversation_history(&self) -> &[LlmMessage] {
95 &self.conversation_history
96 }
97
98 pub const fn transfer_chain(&self) -> Option<&TransferChain> {
100 self.transfer_chain.as_ref()
101 }
102}
103
104#[derive(Debug, Clone, thiserror::Error)]
108pub enum TransferError {
109 #[error("circular transfer detected: agent '{agent_name}' already in chain {chain:?}")]
111 CircularTransfer {
112 agent_name: String,
113 chain: Vec<String>,
114 },
115 #[error("max transfer depth exceeded: depth {depth} >= max {max}")]
117 MaxDepthExceeded { depth: usize, max: usize },
118}
119
120#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct TransferChain {
128 chain: Vec<String>,
129 max_depth: usize,
130}
131
132impl TransferChain {
133 pub const fn new(max_depth: usize) -> Self {
135 Self {
136 chain: Vec::new(),
137 max_depth,
138 }
139 }
140
141 pub fn push(&mut self, agent_name: impl Into<String>) -> Result<(), TransferError> {
146 let name = agent_name.into();
147 if self.chain.contains(&name) {
148 return Err(TransferError::CircularTransfer {
149 agent_name: name,
150 chain: self.chain.clone(),
151 });
152 }
153 if self.chain.len() >= self.max_depth {
154 return Err(TransferError::MaxDepthExceeded {
155 depth: self.chain.len(),
156 max: self.max_depth,
157 });
158 }
159 self.chain.push(name);
160 Ok(())
161 }
162
163 pub const fn depth(&self) -> usize {
165 self.chain.len()
166 }
167
168 pub fn contains(&self, agent_name: &str) -> bool {
170 self.chain.iter().any(|n| n == agent_name)
171 }
172
173 pub fn chain(&self) -> &[String] {
175 &self.chain
176 }
177}
178
179impl Default for TransferChain {
180 fn default() -> Self {
181 Self::new(5)
182 }
183}
184
185#[derive(Deserialize, JsonSchema)]
189#[schemars(deny_unknown_fields)]
190struct TransferParams {
191 agent_name: String,
193 reason: String,
195 context_summary: Option<String>,
197}
198
199pub struct TransferToAgentTool {
207 registry: Arc<AgentRegistry>,
208 allowed_targets: Option<HashSet<String>>,
209 schema: Value,
210}
211
212#[allow(dead_code)]
215impl TransferToAgentTool {
216 pub fn new(registry: Arc<AgentRegistry>) -> Self {
218 Self {
219 registry,
220 allowed_targets: None,
221 schema: validated_schema_for::<TransferParams>(),
222 }
223 }
224
225 pub fn with_allowed_targets(
227 registry: Arc<AgentRegistry>,
228 targets: impl IntoIterator<Item = impl Into<String>>,
229 ) -> Self {
230 Self {
231 registry,
232 allowed_targets: Some(targets.into_iter().map(Into::into).collect()),
233 schema: validated_schema_for::<TransferParams>(),
234 }
235 }
236}
237
238impl AgentTool for TransferToAgentTool {
239 #[allow(clippy::unnecessary_literal_bound)]
240 fn name(&self) -> &str {
241 "transfer_to_agent"
242 }
243
244 #[allow(clippy::unnecessary_literal_bound)]
245 fn label(&self) -> &str {
246 "Transfer to Agent"
247 }
248
249 #[allow(clippy::unnecessary_literal_bound)]
250 fn description(&self) -> &str {
251 "Transfer the conversation to another agent. Use when the user's request \
252 is better handled by a different specialist agent."
253 }
254
255 fn parameters_schema(&self) -> &Value {
256 &self.schema
257 }
258
259 fn execute(
260 &self,
261 _tool_call_id: &str,
262 params: Value,
263 cancellation_token: CancellationToken,
264 _on_update: Option<Box<dyn Fn(AgentToolResult) + Send + Sync>>,
265 _state: std::sync::Arc<std::sync::RwLock<crate::SessionState>>,
266 _credential: Option<crate::credential::ResolvedCredential>,
267 ) -> ToolFuture<'_> {
268 Box::pin(async move {
269 let parsed: TransferParams = match serde_json::from_value(params) {
270 Ok(p) => p,
271 Err(e) => return AgentToolResult::error(format!("invalid parameters: {e}")),
272 };
273
274 if cancellation_token.is_cancelled() {
275 return AgentToolResult::error("cancelled");
276 }
277
278 if let Some(ref allowed) = self.allowed_targets
280 && !allowed.contains(&parsed.agent_name)
281 {
282 let mut sorted: Vec<&String> = allowed.iter().collect();
283 sorted.sort();
284 return AgentToolResult::error(format!(
285 "transfer to '{}' not allowed. Allowed targets: {sorted:?}",
286 parsed.agent_name,
287 ));
288 }
289
290 if self.registry.get(&parsed.agent_name).is_none() {
292 return AgentToolResult::error(format!(
293 "agent '{}' not found in registry",
294 parsed.agent_name
295 ));
296 }
297
298 let mut signal = TransferSignal::new(&parsed.agent_name, &parsed.reason);
300 if let Some(summary) = parsed.context_summary {
301 signal = signal.with_context_summary(summary);
302 }
303
304 AgentToolResult::transfer(signal)
305 })
306 }
307}
308
309const _: () = {
312 const fn assert_send_sync<T: Send + Sync>() {}
313 assert_send_sync::<TransferSignal>();
314 assert_send_sync::<TransferChain>();
315 assert_send_sync::<TransferError>();
316 assert_send_sync::<TransferToAgentTool>();
317};
318
#[cfg(test)]
mod tests {
    use super::*;

    // ----- TransferSignal --------------------------------------------------

    #[test]
    fn transfer_signal_new_sets_target_and_reason() {
        let signal = TransferSignal::new("billing", "billing issue");
        assert_eq!(signal.target_agent(), "billing");
        assert_eq!(signal.reason(), "billing issue");
        assert_eq!(signal.context_summary(), None);
        assert!(signal.conversation_history().is_empty());
        assert!(signal.transfer_chain().is_none());
    }

    #[test]
    fn transfer_signal_with_context_summary() {
        let signal = TransferSignal::new("billing", "billing issue")
            .with_context_summary("User has a $50 charge they dispute");
        assert_eq!(
            signal.context_summary(),
            Some("User has a $50 charge they dispute")
        );
    }

    #[test]
    fn transfer_signal_with_conversation_history() {
        use crate::types::{ContentBlock, UserMessage};

        let msg = LlmMessage::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "hello".into(),
            }],
            timestamp: 0,
            cache_hint: None,
        });
        let signal = TransferSignal::new("tech", "tech issue").with_conversation_history(vec![msg]);
        assert_eq!(signal.conversation_history().len(), 1);
    }

    #[test]
    fn transfer_signal_serde_roundtrip() {
        let mut chain = TransferChain::new(3);
        chain.push("support").unwrap();
        chain.push("billing").unwrap();
        let signal = TransferSignal::new("billing", "billing issue")
            .with_context_summary("User disputes charge")
            .with_transfer_chain(chain);

        let json = serde_json::to_string(&signal).unwrap();
        let parsed: TransferSignal = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.target_agent(), "billing");
        assert_eq!(parsed.reason(), "billing issue");
        assert_eq!(parsed.context_summary(), Some("User disputes charge"));
        assert!(parsed.conversation_history().is_empty());
        let chain = parsed.transfer_chain().expect("expected transfer chain");
        assert_eq!(chain.chain(), &["support", "billing"]);
    }

    #[test]
    fn transfer_signal_deserialize_without_optional_fields() {
        let json = r#"{"target_agent":"billing","reason":"billing issue"}"#;
        let parsed: TransferSignal = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.target_agent(), "billing");
        assert_eq!(parsed.reason(), "billing issue");
        assert_eq!(parsed.context_summary(), None);
        assert!(parsed.conversation_history().is_empty());
        assert!(parsed.transfer_chain().is_none());
    }

    #[test]
    fn transfer_signal_serde_skips_none_context_summary() {
        let signal = TransferSignal::new("billing", "billing issue");
        let json = serde_json::to_value(&signal).unwrap();
        let object = json.as_object().unwrap();
        assert!(!object.contains_key("context_summary"));
        assert!(!object.contains_key("transfer_chain"));
    }

    #[test]
    fn transfer_signal_builder_chain() {
        let signal = TransferSignal::new("target", "reason")
            .with_context_summary("summary")
            .with_conversation_history(vec![]);
        assert_eq!(signal.target_agent(), "target");
        assert_eq!(signal.reason(), "reason");
        assert_eq!(signal.context_summary(), Some("summary"));
        assert!(signal.conversation_history().is_empty());
    }

    // ----- TransferChain ---------------------------------------------------

    #[test]
    fn transfer_chain_rejects_circular() {
        let mut chain = TransferChain::default();
        chain.push("agent-a").unwrap();
        chain.push("agent-b").unwrap();
        let err = chain.push("agent-a").unwrap_err();
        assert!(matches!(err, TransferError::CircularTransfer { .. }));
    }

    #[test]
    fn transfer_chain_rejects_max_depth() {
        let mut chain = TransferChain::new(2);
        chain.push("a").unwrap();
        chain.push("b").unwrap();
        let err = chain.push("c").unwrap_err();
        assert!(matches!(
            err,
            TransferError::MaxDepthExceeded { depth: 2, max: 2 }
        ));
    }

    #[test]
    fn transfer_chain_allows_new_agent() {
        let mut chain = TransferChain::default();
        assert!(chain.push("agent-a").is_ok());
        assert!(chain.push("agent-b").is_ok());
        assert!(chain.push("agent-c").is_ok());
    }

    #[test]
    fn transfer_chain_default_max_depth() {
        // The default chain accepts exactly five agents; the sixth is rejected.
        let mut chain = TransferChain::default();
        for i in 0..5 {
            chain.push(format!("agent-{i}")).unwrap();
        }
        let err = chain.push("agent-5").unwrap_err();
        assert!(matches!(err, TransferError::MaxDepthExceeded { .. }));
    }

    #[test]
    fn transfer_chain_contains_and_depth() {
        let mut chain = TransferChain::default();
        assert_eq!(chain.depth(), 0);
        assert!(!chain.contains("a"));

        chain.push("a").unwrap();
        assert_eq!(chain.depth(), 1);
        assert!(chain.contains("a"));
        assert!(!chain.contains("b"));

        chain.push("b").unwrap();
        assert_eq!(chain.depth(), 2);
        assert!(chain.contains("b"));
        assert_eq!(chain.chain(), &["a", "b"]);
    }

    #[test]
    fn transfer_chain_self_transfer_is_circular() {
        let mut chain = TransferChain::default();
        chain.push("support").unwrap();
        let err = chain.push("support").unwrap_err();
        assert!(
            matches!(err, TransferError::CircularTransfer { agent_name, .. } if agent_name == "support")
        );
    }

    #[test]
    fn transfer_signal_carries_full_context() {
        let signal = TransferSignal::new("billing", "billing question")
            .with_context_summary("User asked about invoice #123");
        assert_eq!(signal.target_agent(), "billing");
        assert_eq!(signal.reason(), "billing question");
        assert_eq!(
            signal.context_summary(),
            Some("User asked about invoice #123")
        );
    }

    // ----- TransferToAgentTool (requires the `testkit` feature) -----------

    #[cfg(feature = "testkit")]
    mod transfer_tool_tests {
        use super::*;
        use crate::agent::{Agent, AgentOptions};
        use crate::registry::AgentRegistry;
        use crate::testing::SimpleMockStreamFn;
        use crate::tool::{AgentTool, AgentToolResult};
        use crate::types::{ContentBlock, ModelSpec};
        use std::sync::{Arc, RwLock};
        use tokio_util::sync::CancellationToken;

        /// Builds a minimal agent backed by a mock stream that replies "hi".
        fn dummy_agent() -> Agent {
            Agent::new(AgentOptions::new(
                "test",
                ModelSpec::new("test", "test-model"),
                Arc::new(SimpleMockStreamFn::from_text("hi")),
                crate::agent::default_convert,
            ))
        }

        /// Returns a registry holding one dummy agent under each of `names`.
        fn registry_with(names: &[&str]) -> Arc<AgentRegistry> {
            let registry = Arc::new(AgentRegistry::new());
            for &name in names {
                registry.register(name, dummy_agent());
            }
            registry
        }

        /// Runs `tool` with `params` and `token`, a fresh session state, no
        /// update callback, and no credential.
        async fn run(
            tool: &TransferToAgentTool,
            params: serde_json::Value,
            token: CancellationToken,
        ) -> AgentToolResult {
            tool.execute(
                "tc-1",
                params,
                token,
                None,
                Arc::new(RwLock::new(crate::SessionState::default())),
                None,
            )
            .await
        }

        /// Returns the text of the result's first content block, panicking if
        /// that block is not text.
        fn first_text(result: &AgentToolResult) -> &str {
            match &result.content[0] {
                ContentBlock::Text { text } => text,
                _ => panic!("expected text content block"),
            }
        }

        #[tokio::test]
        async fn transfer_tool_validates_target_and_returns_signal() {
            let tool = TransferToAgentTool::new(registry_with(&["billing"]));
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "billing question"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(!result.is_error);
            assert!(result.is_transfer());
            let signal = result.transfer_signal.unwrap();
            assert_eq!(signal.target_agent(), "billing");
            assert_eq!(signal.reason(), "billing question");
            assert_eq!(signal.context_summary(), None);
            assert!(signal.conversation_history().is_empty());
        }

        #[tokio::test]
        async fn transfer_tool_target_not_found_returns_error() {
            let tool = TransferToAgentTool::new(registry_with(&[]));
            let params = serde_json::json!({
                "agent_name": "nonexistent",
                "reason": "test"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(result.is_error);
            assert!(!result.is_transfer());
            let text = first_text(&result);
            assert!(
                text.contains("not found in registry"),
                "expected 'not found in registry', got: {text}"
            );
        }

        #[tokio::test]
        async fn transfer_tool_includes_context_summary() {
            let tool = TransferToAgentTool::new(registry_with(&["billing"]));
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "billing dispute",
                "context_summary": "User has a $50 charge they want to dispute"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(!result.is_error);
            let signal = result.transfer_signal.unwrap();
            assert_eq!(
                signal.context_summary(),
                Some("User has a $50 charge they want to dispute")
            );
        }

        #[tokio::test]
        async fn transfer_tool_result_text_format() {
            let tool = TransferToAgentTool::new(registry_with(&["billing"]));
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "billing question"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert_eq!(first_text(&result), "Transfer to billing initiated.");
        }

        #[tokio::test]
        async fn transfer_tool_allowed_targets_restricts() {
            let registry = registry_with(&["billing", "tech"]);
            let tool = TransferToAgentTool::with_allowed_targets(registry, ["billing"]);
            let params = serde_json::json!({
                "agent_name": "tech",
                "reason": "tech question"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(result.is_error);
            let text = first_text(&result);
            assert!(
                text.contains("not allowed"),
                "expected 'not allowed', got: {text}"
            );
        }

        #[tokio::test]
        async fn transfer_tool_allowed_targets_permits() {
            let tool =
                TransferToAgentTool::with_allowed_targets(registry_with(&["billing"]), ["billing"]);
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "billing question"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(!result.is_error);
            assert!(result.is_transfer());
        }

        #[tokio::test]
        async fn transfer_tool_empty_allowed_targets_rejects_all() {
            // An empty allow-list must reject even agents that are registered.
            let tool = TransferToAgentTool::with_allowed_targets(
                registry_with(&["billing"]),
                std::iter::empty::<String>(),
            );
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "test"
            });

            let result = run(&tool, params, CancellationToken::new()).await;

            assert!(result.is_error);
            let text = first_text(&result);
            assert!(
                text.contains("not allowed"),
                "expected 'not allowed', got: {text}"
            );
        }

        #[tokio::test]
        async fn transfer_tool_respects_cancellation() {
            let tool = TransferToAgentTool::new(registry_with(&["billing"]));
            let params = serde_json::json!({
                "agent_name": "billing",
                "reason": "test"
            });

            let token = CancellationToken::new();
            token.cancel();

            let result = run(&tool, params, token).await;

            assert!(result.is_error);
            assert_eq!(first_text(&result), "cancelled");
        }
    }
}