vtcode_core/tools/handlers/
router.rs1use crate::config::constants::tools;
13use crate::types::CompactStr;
14use hashbrown::HashMap;
15use std::sync::Arc;
16
17use async_trait::async_trait;
18
19use crate::tools::tool_intent;
20
21use super::tool_handler::{
22 ConfiguredToolSpec, ToolCallError, ToolHandler, ToolInvocation, ToolKind, ToolOutput,
23 ToolPayload, ToolSession, ToolSpec, TurnContext,
24};
25
26#[derive(Clone, Debug)]
28pub struct ToolCall {
29 pub tool_name: String,
31 pub call_id: String,
33 pub payload: ToolPayload,
35}
36
37struct DispatchEntry {
38 canonical_name: String,
39 handler: Arc<dyn ToolHandler>,
40}
41
42pub struct DispatchRegistry {
44 handlers: HashMap<CompactStr, DispatchEntry>,
45}
46
47fn normalize_router_tool_name(tool_name: &str) -> Option<String> {
48 let lowered = tool_name.trim().to_ascii_lowercase();
49 if lowered.is_empty() {
50 return None;
51 }
52
53 let normalized = lowered
54 .replace([' ', '-'], "_")
55 .replace(['(', ')', '\'', '"'], "");
56
57 let mapped = match normalized.as_str() {
58 alias if tool_intent::canonical_unified_exec_tool_name(alias).is_some() => {
59 tools::UNIFIED_EXEC
60 }
61 "exec_code" | "run_code" | "run_command" | "run_command_pty" => tools::UNIFIED_EXEC,
62 "search_text" => tools::GREP_FILE,
63 tools::READ_FILE => tools::READ_FILE,
64 tools::WRITE_FILE => tools::WRITE_FILE,
65 tools::EDIT_FILE => tools::EDIT_FILE,
66 tools::LIST_FILES => tools::LIST_FILES,
67 _ => normalized.as_str(),
68 };
69
70 if mapped == lowered {
71 None
72 } else {
73 Some(mapped.to_string())
74 }
75}
76
77fn suggest_similar_tool_names(
78 requested_tool_name: &str,
79 handlers: &HashMap<CompactStr, DispatchEntry>,
80) -> Vec<String> {
81 let requested_lower = requested_tool_name.to_ascii_lowercase();
82 let normalized = normalize_router_tool_name(requested_tool_name).unwrap_or_default();
83
84 let mut available: Vec<CompactStr> = handlers.keys().cloned().collect();
85 available.sort_unstable();
86
87 available
88 .into_iter()
89 .filter(|candidate| {
90 let c: &str = candidate;
91 c.contains(&requested_lower)
92 || requested_lower.contains(c)
93 || (!normalized.is_empty() && (c.contains(&*normalized) || normalized.contains(c)))
94 })
95 .take(3)
96 .map(|c| c.to_string())
97 .collect()
98}
99
100impl DispatchRegistry {
101 pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
102 let handlers: HashMap<CompactStr, DispatchEntry> = handlers
103 .into_iter()
104 .map(|(name, handler)| {
105 (
106 CompactStr::from(name.clone()),
107 DispatchEntry {
108 canonical_name: name,
109 handler,
110 },
111 )
112 })
113 .collect();
114 Self { handlers }
115 }
116
117 pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
118 self.handlers.get(name).map(|entry| entry.handler.clone())
119 }
120
121 pub fn resolve_tool_name(&self, requested_name: &str) -> Result<&str, ToolCallError> {
122 self.resolve_entry(requested_name)
123 .map(|entry| entry.canonical_name.as_str())
124 }
125
126 pub async fn dispatch(&self, invocation: ToolInvocation) -> Result<ToolOutput, ToolCallError> {
128 let entry = self.resolve_entry(&invocation.tool_name)?;
129 let handler = &entry.handler;
130
131 if !handler.matches_kind(&invocation.payload) {
132 return Err(ToolCallError::respond(format!(
133 "Tool {} invoked with incompatible payload type",
134 invocation.tool_name
135 )));
136 }
137
138 handler.handle(invocation).await
139 }
140
141 fn resolve_entry(&self, requested_name: &str) -> Result<&DispatchEntry, ToolCallError> {
142 let normalized_name = normalize_router_tool_name(requested_name);
143 self.handlers
144 .get(requested_name)
145 .or_else(|| {
146 normalized_name
147 .as_deref()
148 .and_then(|candidate| self.handlers.get(candidate))
149 })
150 .ok_or_else(|| {
151 let suggested = suggest_similar_tool_names(requested_name, &self.handlers);
152 let normalized_hint = normalized_name
153 .as_deref()
154 .filter(|candidate| *candidate != requested_name)
155 .map(|candidate| format!(" Normalized as '{candidate}'."))
156 .unwrap_or_default();
157 let suggestion_hint = if suggested.is_empty() {
158 String::new()
159 } else {
160 format!(" Did you mean: {}?", suggested.join(", "))
161 };
162 ToolCallError::respond(format!(
163 "Unknown tool: {requested_name}.{normalized_hint}{suggestion_hint}"
164 ))
165 })
166 }
167}
168
169pub struct DispatchRegistryBuilder {
171 handlers: HashMap<CompactStr, DispatchEntry>,
172 specs: Vec<ConfiguredToolSpec>,
173}
174
175impl Default for DispatchRegistryBuilder {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181impl DispatchRegistryBuilder {
182 pub fn new() -> Self {
183 Self {
184 handlers: HashMap::new(),
185 specs: Vec::new(),
186 }
187 }
188
189 pub fn push_spec(&mut self, spec: ToolSpec) -> &mut Self {
191 self.push_spec_with_parallel_support(spec, false)
192 }
193
194 pub fn push_spec_with_parallel_support(
196 &mut self,
197 spec: ToolSpec,
198 supports_parallel_tool_calls: bool,
199 ) -> &mut Self {
200 self.specs
201 .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
202 self
203 }
204
205 pub fn register_handler(
207 &mut self,
208 name: impl Into<String>,
209 handler: Arc<dyn ToolHandler>,
210 ) -> &mut Self {
211 let name = name.into();
212 self.register_route(name.clone(), name, handler)
213 }
214
215 pub fn register_route(
217 &mut self,
218 name: impl Into<String>,
219 canonical_name: impl Into<String>,
220 handler: Arc<dyn ToolHandler>,
221 ) -> &mut Self {
222 let name = name.into();
223 let canonical_name = canonical_name.into();
224 let previous = self.handlers.insert(
225 CompactStr::from(&*name),
226 DispatchEntry {
227 canonical_name: canonical_name.clone(),
228 handler: Arc::new(RouteAliasHandler {
229 canonical_name,
230 inner: handler,
231 }),
232 },
233 );
234 if previous.is_some() {
235 tracing::warn!("Overwriting handler for tool");
236 }
237 self
238 }
239
240 pub fn register_aliases(&mut self, names: &[&str], handler: Arc<dyn ToolHandler>) -> &mut Self {
242 for name in names {
243 self.register_handler((*name).to_string(), handler.clone());
244 }
245 self
246 }
247
248 pub fn build(self) -> (Vec<ConfiguredToolSpec>, DispatchRegistry) {
250 let registry = DispatchRegistry {
251 handlers: self.handlers,
252 };
253 (self.specs, registry)
254 }
255}
256
257pub struct ToolRouter {
264 registry: DispatchRegistry,
265 specs: Vec<ConfiguredToolSpec>,
266}
267
268impl ToolRouter {
269 pub fn from_builder(builder: DispatchRegistryBuilder) -> Self {
271 let (specs, registry) = builder.build();
272 Self { registry, specs }
273 }
274
275 pub fn specs(&self) -> Vec<ToolSpec> {
277 self.specs.iter().map(|c| c.spec.clone()).collect()
278 }
279
280 pub fn configured_specs(&self) -> &[ConfiguredToolSpec] {
282 &self.specs
283 }
284
285 pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
287 self.specs
288 .iter()
289 .filter(|c| c.supports_parallel_tool_calls)
290 .any(|c| c.spec.name() == tool_name)
291 }
292
293 pub fn resolve_tool_name(&self, tool_name: &str) -> Result<&str, ToolCallError> {
295 self.registry.resolve_tool_name(tool_name)
296 }
297
298 pub fn build_tool_call(
302 name: String,
303 call_id: String,
304 arguments: String,
305 mcp_prefix: Option<&str>,
306 ) -> Result<ToolCall, ToolCallError> {
307 if let Some(prefix) = mcp_prefix
309 && name.starts_with(prefix)
310 {
311 let parts: Vec<&str> = name.splitn(2, '/').collect();
312 if parts.len() == 2 {
313 return Ok(ToolCall {
314 tool_name: name.clone(),
315 call_id,
316 payload: ToolPayload::Mcp {
317 arguments: Some(serde_json::from_str(&arguments).unwrap_or_default()),
318 },
319 });
320 }
321 }
322
323 Ok(ToolCall {
325 tool_name: name,
326 call_id,
327 payload: ToolPayload::Function { arguments },
328 })
329 }
330
331 pub async fn dispatch_tool_call(
333 &self,
334 session: Arc<dyn ToolSession>,
335 turn: Arc<TurnContext>,
336 call: ToolCall,
337 ) -> Result<ToolOutput, ToolCallError> {
338 let invocation = ToolInvocation {
339 session,
340 turn,
341 tracker: None,
342 call_id: call.call_id,
343 tool_name: call.tool_name,
344 payload: call.payload,
345 };
346
347 self.registry.dispatch(invocation).await
348 }
349
350 #[cold]
352 pub fn failure_response(_call_id: String, error: ToolCallError) -> ToolOutput {
353 ToolOutput::error(error.to_string())
354 }
355}
356
357struct RouteAliasHandler {
358 canonical_name: String,
359 inner: Arc<dyn ToolHandler>,
360}
361
362#[async_trait]
363impl ToolHandler for RouteAliasHandler {
364 fn kind(&self) -> ToolKind {
365 self.inner.kind()
366 }
367
368 fn matches_kind(&self, payload: &ToolPayload) -> bool {
369 self.inner.matches_kind(payload)
370 }
371
372 async fn is_mutating(&self, invocation: &ToolInvocation) -> bool {
373 self.inner.is_mutating(invocation).await
374 }
375
376 async fn handle(&self, mut invocation: ToolInvocation) -> Result<ToolOutput, ToolCallError> {
377 invocation.tool_name = self.canonical_name.clone();
378 self.inner.handle(invocation).await
379 }
380}
381
382#[async_trait]
384pub trait ToolRouterProvider: Send + Sync {
385 async fn get_tool_router(&self) -> Arc<ToolRouter>;
387}
388
#[cfg(test)]
mod tests {
    use super::super::tool_handler::{ResponsesApiTool, ToolKind};
    use super::*;
    use serde_json::json;

    /// Handler stub that echoes the tool name it was invoked with.
    struct MockHandler;

    #[async_trait]
    impl ToolHandler for MockHandler {
        fn kind(&self) -> ToolKind {
            ToolKind::Function
        }

        async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, ToolCallError> {
            Ok(ToolOutput::simple(format!(
                "Handled: {}",
                invocation.tool_name
            )))
        }
    }

    /// Builds a non-strict function spec with an empty object schema.
    fn function_spec(name: &str, description: &str) -> ToolSpec {
        ToolSpec::Function(ResponsesApiTool {
            name: name.to_string(),
            description: description.to_string(),
            parameters: json!({"type": "object"}),
            strict: false,
        })
    }

    #[test]
    fn test_build_tool_call_function() {
        let call = ToolRouter::build_tool_call(
            "test_tool".to_string(),
            "call-1".to_string(),
            r#"{"arg": "value"}"#.to_string(),
            None,
        )
        .unwrap();

        assert_eq!(call.tool_name, "test_tool");
        assert_eq!(call.call_id, "call-1");
        assert!(matches!(call.payload, ToolPayload::Function { .. }));
    }

    #[test]
    fn test_build_tool_call_mcp() {
        let call = ToolRouter::build_tool_call(
            "mcp_server/do_thing".to_string(),
            "call-2".to_string(),
            r#"{"arg": "value"}"#.to_string(),
            Some("mcp_server"),
        )
        .unwrap();

        assert_eq!(call.tool_name, "mcp_server/do_thing");
        assert!(matches!(
            call.payload,
            ToolPayload::Mcp { arguments: Some(_) }
        ));
    }

    #[test]
    fn test_registry_builder() {
        let mut builder = DispatchRegistryBuilder::new();
        builder
            .push_spec_with_parallel_support(function_spec("test_tool", "A test tool"), true)
            .register_handler("test_tool", Arc::new(MockHandler));

        let (specs, registry) = builder.build();

        assert_eq!(specs.len(), 1);
        assert!(specs[0].supports_parallel_tool_calls);
        assert!(registry.handler("test_tool").is_some());
    }

    #[test]
    fn test_router_parallel_support() {
        let mut builder = DispatchRegistryBuilder::new();
        builder
            .push_spec_with_parallel_support(
                function_spec("parallel_tool", "Supports parallel"),
                true,
            )
            .register_handler("parallel_tool", Arc::new(MockHandler));

        let router = ToolRouter::from_builder(builder);

        assert!(router.tool_supports_parallel("parallel_tool"));
        assert!(!router.tool_supports_parallel("nonexistent"));
    }

    #[test]
    fn test_register_route_resolves_alias_to_canonical_name() {
        let mut builder = DispatchRegistryBuilder::new();
        builder.register_route("alias_tool", "canonical_tool", Arc::new(MockHandler));

        let (_, registry) = builder.build();

        assert_eq!(
            registry.resolve_tool_name("alias_tool").unwrap(),
            "canonical_tool"
        );
        assert!(registry.resolve_tool_name("missing_tool").is_err());
    }

    #[test]
    fn test_normalize_router_tool_name_exec_code_label() {
        for label in ["Exec code", "run command (PTY)", "bash", "container.exec"] {
            assert_eq!(
                normalize_router_tool_name(label).as_deref(),
                Some("unified_exec"),
                "label {label:?} should normalize to unified_exec"
            );
        }
    }

    #[test]
    fn test_suggest_similar_tool_names_uses_normalized_form() {
        let mut handlers = HashMap::new();
        handlers.insert(
            CompactStr::from("unified_exec"),
            DispatchEntry {
                canonical_name: "unified_exec".to_string(),
                handler: Arc::new(MockHandler) as Arc<dyn ToolHandler>,
            },
        );

        let suggestions = suggest_similar_tool_names("Exec code", &handlers);
        assert_eq!(suggestions, vec!["unified_exec"]);
    }
}