1use std::sync::Arc;
3use serde_json::Value;
4use std::collections::{HashMap, HashSet};
5use crate::tools::Tool;
6
7#[derive(Clone)]
10pub struct ToolRegistry {
11 tools: HashMap<String, Arc<dyn Tool>>,
12 cached_schema: Arc<Vec<Value>>,
14 api_to_runtime_names: HashMap<String, String>,
16 input_name_maps: HashMap<String, SchemaNameMap>,
18}
19
/// Tree of property-name translations that mirrors the nesting of a JSON schema.
#[derive(Debug, Default, Clone)]
struct SchemaNameMap {
    /// Translations for the element schema of an array (`items`), if any.
    items: Option<Box<SchemaNameMap>>,
    /// Translations for nested property schemas, keyed by API-facing property name.
    children: HashMap<String, SchemaNameMap>,
    /// API-facing property name -> original property name at this nesting level.
    api_to_runtime: HashMap<String, String>,
}
27
28impl Default for ToolRegistry {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl ToolRegistry {
35 pub fn new() -> Self {
36 let tools: Vec<Arc<dyn Tool>> = vec![
37 Arc::new(crate::tools::bash::BashTool),
38 Arc::new(crate::tools::read::ReadTool),
39 Arc::new(crate::tools::write::WriteTool),
40 Arc::new(crate::tools::edit::EditTool),
41 Arc::new(crate::tools::grep::GrepTool),
42 Arc::new(crate::tools::find::FindTool),
43 Arc::new(crate::tools::ls::LsTool),
44 Arc::new(crate::tools::subagent::SubagentTool),
45 Arc::new(crate::tools::subagent::start::SubagentStartTool),
46 Arc::new(crate::tools::subagent::status::SubagentStatusTool),
47 Arc::new(crate::tools::subagent::steer::SubagentSteerTool),
48 Arc::new(crate::tools::subagent::collect::SubagentCollectTool),
49 Arc::new(crate::tools::subagent::resume::SubagentResumeTool),
50 Arc::new(crate::tools::shell::ShellStartTool),
51 Arc::new(crate::tools::shell::ShellSendTool),
52 Arc::new(crate::tools::shell::ShellEndTool),
53 ];
54 Self::from_tools(tools)
55 }
56
57 pub fn empty() -> Self {
60 Self::from_tools(Vec::new())
61 }
62
63 pub fn without_subagent() -> Self {
65 let tools: Vec<Arc<dyn Tool>> = vec![
66 Arc::new(crate::tools::bash::BashTool),
67 Arc::new(crate::tools::read::ReadTool),
68 Arc::new(crate::tools::write::WriteTool),
69 Arc::new(crate::tools::edit::EditTool),
70 Arc::new(crate::tools::grep::GrepTool),
71 Arc::new(crate::tools::find::FindTool),
72 Arc::new(crate::tools::ls::LsTool),
73 Arc::new(crate::tools::shell::ShellStartTool),
74 Arc::new(crate::tools::shell::ShellSendTool),
75 Arc::new(crate::tools::shell::ShellEndTool),
76 ];
77 Self::from_tools(tools)
78 }
79
80 pub fn without_subagent_with_extensions(extension_tools: &ToolRegistry) -> Self {
88 let mut combined = Self::without_subagent();
89 for tool in extension_tools.tools.values() {
90 if tool.extension_id().is_some() {
91 combined.tools.insert(tool.name().to_string(), tool.clone());
92 }
93 }
94 combined.rebuild_schema();
95 combined
96 }
97
98 fn from_tools(tool_list: Vec<Arc<dyn Tool>>) -> Self {
99 let mut registry = ToolRegistry {
100 tools: HashMap::new(),
101 cached_schema: Arc::new(Vec::new()),
102 api_to_runtime_names: HashMap::new(),
103 input_name_maps: HashMap::new(),
104 };
105 for tool in tool_list {
109 let name = tool.name().to_string();
110 registry.tools.insert(name, tool);
111 }
112 registry.rebuild_schema();
113 registry
114 }
115
116 fn api_safe_name(name: &str, used: &HashSet<String>) -> String {
117 Self::api_safe_identifier(name, used, 128, false)
118 }
119
120 fn api_safe_property_name(name: &str, used: &HashSet<String>) -> String {
121 Self::api_safe_identifier(name, used, 64, true)
122 }
123
124 fn api_safe_identifier(name: &str, used: &HashSet<String>, max_len: usize, allow_dot: bool) -> String {
125 let mut sanitized = String::with_capacity(name.len());
126 for ch in name.chars() {
127 if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' || (allow_dot && ch == '.') {
128 sanitized.push(ch);
129 } else {
130 sanitized.push('_');
131 }
132 }
133 if sanitized.is_empty() {
134 sanitized.push_str("field");
135 }
136 if sanitized.len() > max_len {
137 sanitized.truncate(max_len);
138 }
139
140 let base = sanitized.clone();
141 let mut suffix = 2;
142 while used.contains(&sanitized) {
143 let suffix_str = format!("_{suffix}");
144 let keep = max_len.saturating_sub(suffix_str.len());
145 sanitized = format!("{}{}", &base[..base.len().min(keep)], suffix_str);
146 suffix += 1;
147 }
148 sanitized
149 }
150
151 fn sanitize_schema(mut schema: Value) -> (Value, SchemaNameMap) {
152 let mut map = SchemaNameMap::default();
153 let Some(obj) = schema.as_object_mut() else {
154 return (schema, map);
155 };
156
157 let mut required_name_map = HashMap::new();
158 if let Some(props_value) = obj.get_mut("properties") {
159 if let Some(props) = props_value.as_object_mut() {
160 let original = std::mem::take(props);
161 let mut used = HashSet::new();
162 for (runtime_name, child_schema) in original {
163 let api_name = Self::api_safe_property_name(&runtime_name, &used);
164 used.insert(api_name.clone());
165 required_name_map.insert(runtime_name.clone(), api_name.clone());
166 map.api_to_runtime.insert(api_name.clone(), runtime_name);
167
168 let (sanitized_child, child_map) = Self::sanitize_schema(child_schema);
169 if !child_map.api_to_runtime.is_empty() || !child_map.children.is_empty() {
170 map.children.insert(api_name.clone(), child_map);
171 }
172 props.insert(api_name, sanitized_child);
173 }
174 }
175 }
176
177 if let Some(required) = obj.get_mut("required").and_then(Value::as_array_mut) {
178 for item in required.iter_mut() {
179 if let Some(name) = item.as_str() {
180 if let Some(api_name) = required_name_map.get(name) {
181 *item = Value::String(api_name.clone());
182 }
183 }
184 }
185 }
186
187 if let Some(items) = obj.get_mut("items") {
190 let (sanitized_items, items_map) = Self::sanitize_schema(std::mem::take(items));
191 if !items_map.api_to_runtime.is_empty() || !items_map.children.is_empty() || items_map.items.is_some() {
192 map.items = Some(Box::new(items_map));
193 }
194 *items = sanitized_items;
195 }
196
197 (schema, map)
198 }
199
200 fn translate_input_names(input: Value, map: &SchemaNameMap) -> Value {
201 match input {
202 Value::Object(obj) => {
203 let mut out = serde_json::Map::new();
204 for (api_name, value) in obj {
205 let runtime_name = map.api_to_runtime.get(&api_name).cloned().unwrap_or_else(|| api_name.clone());
206 let value = if let Some(child) = map.children.get(&api_name) {
207 Self::translate_input_names(value, child)
208 } else {
209 value
210 };
211 out.insert(runtime_name, value);
212 }
213 Value::Object(out)
214 }
215 Value::Array(arr) => {
216 if let Some(items_map) = &map.items {
218 Value::Array(arr.into_iter().map(|v| Self::translate_input_names(v, items_map)).collect())
219 } else {
220 Value::Array(arr)
221 }
222 }
223 other => other,
224 }
225 }
226
227 fn rebuild_schema(&mut self) {
228 let mut used = HashSet::new();
229 let mut api_to_runtime_names = HashMap::new();
230 let mut input_name_maps = HashMap::new();
231 let mut schema = Vec::with_capacity(self.tools.len());
232
233 let mut sorted_tools: Vec<_> = self.tools.values().collect();
237 sorted_tools.sort_by_key(|t| t.name().to_string());
238
239 for tool in sorted_tools {
240 let runtime_name = tool.name();
241 let api_name = Self::api_safe_name(runtime_name, &used);
242 used.insert(api_name.clone());
243 api_to_runtime_names.insert(api_name.clone(), runtime_name.to_string());
244 let (input_schema, input_map) = Self::sanitize_schema(tool.parameters());
245 input_name_maps.insert(api_name.clone(), input_map);
246 schema.push(serde_json::json!({
247 "name": api_name,
248 "description": tool.description(),
249 "input_schema": input_schema
250 }));
251 }
252
253 self.api_to_runtime_names = api_to_runtime_names;
254 self.input_name_maps = input_name_maps;
255 self.cached_schema = Arc::new(schema);
256 }
257
258 pub fn register(&mut self, tool: Arc<dyn Tool>) {
261 let name = tool.name().to_string();
262 self.tools.insert(name, tool);
263 self.rebuild_schema();
264 }
265
266 pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
267 let runtime_name = self.api_to_runtime_names.get(name).map(String::as_str).unwrap_or(name);
268 self.tools.get(runtime_name)
269 }
270
271 pub fn runtime_name_for_api<'a>(&'a self, name: &'a str) -> &'a str {
272 self.api_to_runtime_names.get(name).map(String::as_str).unwrap_or(name)
273 }
274
275 pub fn translate_input_for_api_tool(&self, tool_name: &str, input: Value) -> Value {
276 if let Some(map) = self.input_name_maps.get(tool_name) {
277 Self::translate_input_names(input, map)
278 } else {
279 input
280 }
281 }
282
283 pub fn tools_schema(&self) -> Arc<Vec<Value>> {
284 Arc::clone(&self.cached_schema)
285 }
286
287 pub fn tool_names_for_extension(&self, extension_id: &str) -> Vec<String> {
290 let mut names: Vec<String> = self
291 .tools
292 .values()
293 .filter(|t| t.extension_id() == Some(extension_id))
294 .map(|t| t.name().to_string())
295 .collect();
296 names.sort();
297 names
298 }
299}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Result, ToolContext};
    use serde_json::json;

    /// Built-in tools that every registry flavour is expected to expose.
    const CORE_TOOL_NAMES: [&str; 7] = ["bash", "read", "write", "edit", "grep", "find", "ls"];

    /// Subagent tool names that `without_subagent` registries must not expose.
    const SUBAGENT_TOOL_NAMES: [&str; 6] = [
        "subagent",
        "subagent_start",
        "subagent_status",
        "subagent_steer",
        "subagent_collect",
        "subagent_resume",
    ];

    /// Minimal tool whose only varying property is its name.
    struct NamedTool(&'static str);

    #[async_trait::async_trait]
    impl Tool for NamedTool {
        fn name(&self) -> &str {
            self.0
        }
        fn description(&self) -> &str {
            "test tool"
        }
        fn parameters(&self) -> Value {
            json!({"type": "object"})
        }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
    }

    /// Tool with a name and an optional owning extension id.
    struct OwnedTool(&'static str, Option<&'static str>);

    #[async_trait::async_trait]
    impl Tool for OwnedTool {
        fn name(&self) -> &str {
            self.0
        }
        fn description(&self) -> &str {
            "owned"
        }
        fn parameters(&self) -> Value {
            json!({"type": "object"})
        }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
        fn extension_id(&self) -> Option<&str> {
            self.1
        }
    }

    /// Tool whose input schema uses property names the API rejects
    /// (`:` and `/` characters, a name longer than 64 bytes, a nested object).
    struct SchemaTool;

    #[async_trait::async_trait]
    impl Tool for SchemaTool {
        fn name(&self) -> &str {
            "schema_tool"
        }
        fn description(&self) -> &str {
            "schema tool"
        }
        fn parameters(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going": {"type": "string"},
                    "nested:obj": {
                        "type": "object",
                        "properties": {"inner/key": {"type": "string"}},
                        "required": ["inner/key"]
                    }
                },
                "required": [
                    "bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going",
                    "nested:obj"
                ]
            })
        }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
    }

    #[test]
    fn tool_schema_uses_api_safe_names_and_maps_back() {
        let registry = ToolRegistry::from_tools(vec![Arc::new(NamedTool("plugin:skill.tool"))]);

        assert_eq!(registry.tools_schema()[0]["name"], "plugin_skill_tool");
        assert!(registry.get("plugin:skill.tool").is_some());
        assert!(registry.get("plugin_skill_tool").is_some());
        assert_eq!(registry.runtime_name_for_api("plugin_skill_tool"), "plugin:skill.tool");
    }

    #[test]
    fn tool_schema_disambiguates_sanitized_name_collisions() {
        let registry = ToolRegistry::from_tools(vec![
            Arc::new(NamedTool("a:b")),
            Arc::new(NamedTool("a.b")),
        ]);
        let names: HashSet<String> = registry
            .tools_schema()
            .iter()
            .filter_map(|s| s["name"].as_str().map(str::to_string))
            .collect();

        assert_eq!(names.len(), 2);
        assert!(names.contains("a_b"));
        assert!(names.contains("a_b_2"));
        assert!(registry.get("a_b").is_some());
        assert!(registry.get("a_b_2").is_some());
    }

    #[test]
    fn tool_schema_truncates_long_names_to_anthropic_limit() {
        let long = "x".repeat(140);
        let leaked: &'static str = Box::leak(long.into_boxed_str());
        let registry = ToolRegistry::from_tools(vec![Arc::new(NamedTool(leaked))]);
        let schema = registry.tools_schema();
        let name = schema[0]["name"].as_str().unwrap();

        assert_eq!(name.len(), 128);
        assert!(name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'));
        assert!(registry.get(name).is_some());
    }

    #[test]
    fn tool_schema_sanitizes_input_property_names_and_translates_inputs_back() {
        let registry = ToolRegistry::from_tools(vec![Arc::new(SchemaTool)]);
        let schema = registry.tools_schema();
        let input_schema = &schema[0]["input_schema"];
        let props = input_schema["properties"].as_object().unwrap();

        assert!(props.keys().all(|k| {
            k.len() <= 64
                && k.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.')
        }));
        assert_eq!(input_schema["required"].as_array().unwrap()[0].as_str().unwrap().len(), 64);
        assert_eq!(input_schema["required"][1], "nested_obj");
        assert!(props["nested_obj"]["properties"].as_object().unwrap().contains_key("inner_key"));
        assert_eq!(props["nested_obj"]["required"][0], "inner_key");

        let first_required = input_schema["required"][0].as_str().unwrap();
        let translated = registry.translate_input_for_api_tool("schema_tool", json!({
            first_required: "value",
            "nested_obj": {"inner_key": "nested"}
        }));

        assert_eq!(translated["bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going"], "value");
        assert_eq!(translated["nested:obj"]["inner/key"], "nested");
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();

        assert_eq!(registry.tools_schema().len(), 16);
        assert!(registry.get("nonexistent").is_none());
        for name in CORE_TOOL_NAMES {
            assert!(registry.get(name).is_some(), "missing built-in tool {name}");
        }
        assert!(registry.get("subagent").is_some());
    }

    #[test]
    fn test_tool_registry_without_subagent() {
        let registry = ToolRegistry::without_subagent();

        assert_eq!(registry.tools_schema().len(), 10);
        assert!(registry.get("subagent").is_none());
        for name in CORE_TOOL_NAMES {
            assert!(registry.get(name).is_some(), "missing built-in tool {name}");
        }
    }

    #[test]
    fn test_tool_registry_register() {
        let mut registry = ToolRegistry::without_subagent();
        let initial_count = registry.tools_schema().len();

        struct TestTool;

        #[async_trait::async_trait]
        impl Tool for TestTool {
            fn name(&self) -> &str {
                "test_tool"
            }
            fn description(&self) -> &str {
                "A test tool"
            }
            fn parameters(&self) -> Value {
                json!({"type": "object"})
            }
            async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
                Ok("test result".to_string())
            }
        }

        registry.register(Arc::new(TestTool));

        assert_eq!(registry.tools_schema().len(), initial_count + 1);
        assert!(registry.get("test_tool").is_some());
    }

    #[test]
    fn tool_names_for_extension_filters_by_owner_and_sorts() {
        let mut registry = ToolRegistry::without_subagent();
        registry.register(Arc::new(OwnedTool("alpha:zed", Some("alpha"))));
        registry.register(Arc::new(OwnedTool("alpha:bar", Some("alpha"))));
        registry.register(Arc::new(OwnedTool("beta:thing", Some("beta"))));

        assert_eq!(
            registry.tool_names_for_extension("alpha"),
            vec!["alpha:bar".to_string(), "alpha:zed".to_string()]
        );
        assert_eq!(
            registry.tool_names_for_extension("beta"),
            vec!["beta:thing".to_string()]
        );
        assert!(registry.tool_names_for_extension("ghost").is_empty());
        assert!(registry.tool_names_for_extension("bash").is_empty());
    }

    #[test]
    fn without_subagent_excludes_subagent_tools() {
        let registry = ToolRegistry::without_subagent();

        for name in SUBAGENT_TOOL_NAMES {
            assert!(registry.get(name).is_none(), "unexpected subagent tool {name}");
        }
        assert!(registry.get("bash").is_some());
        assert!(registry.get("read").is_some());
    }

    #[test]
    fn without_subagent_with_extensions_includes_extension_tools() {
        let mut other = ToolRegistry::empty();
        other.register(Arc::new(OwnedTool("alpha:do_thing", Some("alpha"))));

        let merged = ToolRegistry::without_subagent_with_extensions(&other);

        assert!(merged.get("alpha:do_thing").is_some());
        assert!(merged.get("bash").is_some());
        assert!(merged.get("read").is_some());
        assert!(merged.get("subagent_start").is_none());
    }

    #[test]
    fn without_subagent_with_extensions_excludes_built_ins_from_other_registry() {
        let other = ToolRegistry::new();

        let merged = ToolRegistry::without_subagent_with_extensions(&other);

        let bash = merged.get("bash").expect("bash present");
        assert!(bash.extension_id().is_none());
        assert!(merged.get("subagent_start").is_none());
        assert!(merged.get("subagent").is_none());
    }

    #[test]
    fn without_subagent_with_extensions_does_not_overwrite_existing_builtin() {
        let mut other = ToolRegistry::empty();
        other.register(Arc::new(OwnedTool("ext:custom", Some("ext"))));

        let merged = ToolRegistry::without_subagent_with_extensions(&other);

        assert!(merged.get("ext:custom").is_some());
        let bash = merged.get("bash").expect("bash present");
        assert!(bash.extension_id().is_none());
    }
}
591}