// synaps_cli/tools/registry.rs

1//! Tool registry — maintains name→tool map and cached JSON schema for the API.
2use std::sync::Arc;
3use serde_json::Value;
4use std::collections::{HashMap, HashSet};
5use crate::tools::Tool;
6
/// Registry of available tools. Maintains a name→tool map and a cached JSON schema
/// array that gets sent to the API.
///
/// The registry has no internal locking; the original author's note says it is
/// shared as `Arc<RwLock<ToolRegistry>>` (that wrapper is not visible in this file).
///
/// Runtime tool names and input property names may contain characters the API
/// does not accept, so the registry also keeps the reverse mappings needed to
/// translate API-side names back to the names tools actually use.
#[derive(Clone)]
pub struct ToolRegistry {
    /// Registered tools keyed by their runtime name (`Tool::name()`).
    tools: HashMap<String, Arc<dyn Tool>>,
    /// Cached Anthropic-compatible schema — rebuilt on register(), shared via Arc for zero-copy reads.
    cached_schema: Arc<Vec<Value>>,
    /// Mapping from API-safe tool names back to their runtime names.
    api_to_runtime_names: HashMap<String, String>,
    /// Per-tool mapping from API-safe input property names back to runtime names.
    /// Keyed by the tool's API-safe name, not its runtime name.
    input_name_maps: HashMap<String, SchemaNameMap>,
}
19
/// Reverse name map for one level of a sanitized JSON schema.
///
/// Mirrors the schema's nesting: `children` follows object properties and
/// `items` follows array element schemas, so an input value can be walked in
/// parallel with the map to restore runtime property names.
#[derive(Clone, Debug, Default)]
struct SchemaNameMap {
    /// API-safe property name → runtime property name for this object level.
    api_to_runtime: HashMap<String, String>,
    /// Nested maps for properties that carry their own mappings, keyed by the
    /// property's API-safe name.
    children: HashMap<String, SchemaNameMap>,
    /// Name map for array `items` schemas (objects inside arrays).
    items: Option<Box<SchemaNameMap>>,
}
27
28impl Default for ToolRegistry {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl ToolRegistry {
35    pub fn new() -> Self {
36        let tools: Vec<Arc<dyn Tool>> = vec![
37            Arc::new(crate::tools::bash::BashTool),
38            Arc::new(crate::tools::read::ReadTool),
39            Arc::new(crate::tools::write::WriteTool),
40            Arc::new(crate::tools::edit::EditTool),
41            Arc::new(crate::tools::grep::GrepTool),
42            Arc::new(crate::tools::find::FindTool),
43            Arc::new(crate::tools::ls::LsTool),
44            Arc::new(crate::tools::subagent::SubagentTool),
45            Arc::new(crate::tools::subagent::start::SubagentStartTool),
46            Arc::new(crate::tools::subagent::status::SubagentStatusTool),
47            Arc::new(crate::tools::subagent::steer::SubagentSteerTool),
48            Arc::new(crate::tools::subagent::collect::SubagentCollectTool),
49            Arc::new(crate::tools::subagent::resume::SubagentResumeTool),
50            Arc::new(crate::tools::shell::ShellStartTool),
51            Arc::new(crate::tools::shell::ShellSendTool),
52            Arc::new(crate::tools::shell::ShellEndTool),
53        ];
54        Self::from_tools(tools)
55    }
56
57    /// Empty registry for tests and narrow embedded runtimes that want to opt in
58    /// to specific tools explicitly.
59    pub fn empty() -> Self {
60        Self::from_tools(Vec::new())
61    }
62
63    /// Registry without subagent tool — used for subagent runtimes to prevent recursion.
64    pub fn without_subagent() -> Self {
65        let tools: Vec<Arc<dyn Tool>> = vec![
66            Arc::new(crate::tools::bash::BashTool),
67            Arc::new(crate::tools::read::ReadTool),
68            Arc::new(crate::tools::write::WriteTool),
69            Arc::new(crate::tools::edit::EditTool),
70            Arc::new(crate::tools::grep::GrepTool),
71            Arc::new(crate::tools::find::FindTool),
72            Arc::new(crate::tools::ls::LsTool),
73            Arc::new(crate::tools::shell::ShellStartTool),
74            Arc::new(crate::tools::shell::ShellSendTool),
75            Arc::new(crate::tools::shell::ShellEndTool),
76        ];
77        Self::from_tools(tools)
78    }
79
80    /// Built-in subagent registry plus extension tools merged in from a shared
81    /// registry. Used by subagents that need to invoke extension-provided
82    /// tools while still excluding the recursive `subagent_*` tools.
83    ///
84    /// Only tools whose `Tool::extension_id()` returns `Some(_)` are merged;
85    /// built-ins inside `extension_tools` are ignored to avoid duplicating or
86    /// shadowing the canonical built-in instances.
87    pub fn without_subagent_with_extensions(extension_tools: &ToolRegistry) -> Self {
88        let mut combined = Self::without_subagent();
89        for tool in extension_tools.tools.values() {
90            if tool.extension_id().is_some() {
91                combined.tools.insert(tool.name().to_string(), tool.clone());
92            }
93        }
94        combined.rebuild_schema();
95        combined
96    }
97
98    fn from_tools(tool_list: Vec<Arc<dyn Tool>>) -> Self {
99        let mut registry = ToolRegistry {
100            tools: HashMap::new(),
101            cached_schema: Arc::new(Vec::new()),
102            api_to_runtime_names: HashMap::new(),
103            input_name_maps: HashMap::new(),
104        };
105        // Insert all tools first, then rebuild schema once.
106        // Calling register() in a loop would rebuild_schema() on every
107        // iteration, making initialization O(n²) with MCP tool counts.
108        for tool in tool_list {
109            let name = tool.name().to_string();
110            registry.tools.insert(name, tool);
111        }
112        registry.rebuild_schema();
113        registry
114    }
115
116    fn api_safe_name(name: &str, used: &HashSet<String>) -> String {
117        Self::api_safe_identifier(name, used, 128, false)
118    }
119
120    fn api_safe_property_name(name: &str, used: &HashSet<String>) -> String {
121        Self::api_safe_identifier(name, used, 64, true)
122    }
123
124    fn api_safe_identifier(name: &str, used: &HashSet<String>, max_len: usize, allow_dot: bool) -> String {
125        let mut sanitized = String::with_capacity(name.len());
126        for ch in name.chars() {
127            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' || (allow_dot && ch == '.') {
128                sanitized.push(ch);
129            } else {
130                sanitized.push('_');
131            }
132        }
133        if sanitized.is_empty() {
134            sanitized.push_str("field");
135        }
136        if sanitized.len() > max_len {
137            sanitized.truncate(max_len);
138        }
139
140        let base = sanitized.clone();
141        let mut suffix = 2;
142        while used.contains(&sanitized) {
143            let suffix_str = format!("_{suffix}");
144            let keep = max_len.saturating_sub(suffix_str.len());
145            sanitized = format!("{}{}", &base[..base.len().min(keep)], suffix_str);
146            suffix += 1;
147        }
148        sanitized
149    }
150
151    fn sanitize_schema(mut schema: Value) -> (Value, SchemaNameMap) {
152        let mut map = SchemaNameMap::default();
153        let Some(obj) = schema.as_object_mut() else {
154            return (schema, map);
155        };
156
157        let mut required_name_map = HashMap::new();
158        if let Some(props_value) = obj.get_mut("properties") {
159            if let Some(props) = props_value.as_object_mut() {
160                let original = std::mem::take(props);
161                let mut used = HashSet::new();
162                for (runtime_name, child_schema) in original {
163                    let api_name = Self::api_safe_property_name(&runtime_name, &used);
164                    used.insert(api_name.clone());
165                    required_name_map.insert(runtime_name.clone(), api_name.clone());
166                    map.api_to_runtime.insert(api_name.clone(), runtime_name);
167
168                    let (sanitized_child, child_map) = Self::sanitize_schema(child_schema);
169                    if !child_map.api_to_runtime.is_empty() || !child_map.children.is_empty() {
170                        map.children.insert(api_name.clone(), child_map);
171                    }
172                    props.insert(api_name, sanitized_child);
173                }
174            }
175        }
176
177        if let Some(required) = obj.get_mut("required").and_then(Value::as_array_mut) {
178            for item in required.iter_mut() {
179                if let Some(name) = item.as_str() {
180                    if let Some(api_name) = required_name_map.get(name) {
181                        *item = Value::String(api_name.clone());
182                    }
183                }
184            }
185        }
186
187        // Recurse into array item schemas; store the child map so
188        // translate_input_names can reverse-map property names inside array elements.
189        if let Some(items) = obj.get_mut("items") {
190            let (sanitized_items, items_map) = Self::sanitize_schema(std::mem::take(items));
191            if !items_map.api_to_runtime.is_empty() || !items_map.children.is_empty() || items_map.items.is_some() {
192                map.items = Some(Box::new(items_map));
193            }
194            *items = sanitized_items;
195        }
196
197        (schema, map)
198    }
199
200    fn translate_input_names(input: Value, map: &SchemaNameMap) -> Value {
201        match input {
202            Value::Object(obj) => {
203                let mut out = serde_json::Map::new();
204                for (api_name, value) in obj {
205                    let runtime_name = map.api_to_runtime.get(&api_name).cloned().unwrap_or_else(|| api_name.clone());
206                    let value = if let Some(child) = map.children.get(&api_name) {
207                        Self::translate_input_names(value, child)
208                    } else {
209                        value
210                    };
211                    out.insert(runtime_name, value);
212                }
213                Value::Object(out)
214            }
215            Value::Array(arr) => {
216                // If the schema had an items map, apply it to each array element.
217                if let Some(items_map) = &map.items {
218                    Value::Array(arr.into_iter().map(|v| Self::translate_input_names(v, items_map)).collect())
219                } else {
220                    Value::Array(arr)
221                }
222            }
223            other => other,
224        }
225    }
226
227    fn rebuild_schema(&mut self) {
228        let mut used = HashSet::new();
229        let mut api_to_runtime_names = HashMap::new();
230        let mut input_name_maps = HashMap::new();
231        let mut schema = Vec::with_capacity(self.tools.len());
232
233        // Sort by runtime name for deterministic API name assignment.
234        // HashMap iteration is random, so without sorting, collision suffixes
235        // (_2, _3) could change between rebuilds, breaking in-flight conversations.
236        let mut sorted_tools: Vec<_> = self.tools.values().collect();
237        sorted_tools.sort_by_key(|t| t.name().to_string());
238
239        for tool in sorted_tools {
240            let runtime_name = tool.name();
241            let api_name = Self::api_safe_name(runtime_name, &used);
242            used.insert(api_name.clone());
243            api_to_runtime_names.insert(api_name.clone(), runtime_name.to_string());
244            let (input_schema, input_map) = Self::sanitize_schema(tool.parameters());
245            input_name_maps.insert(api_name.clone(), input_map);
246            schema.push(serde_json::json!({
247                "name": api_name,
248                "description": tool.description(),
249                "input_schema": input_schema
250            }));
251        }
252
253        self.api_to_runtime_names = api_to_runtime_names;
254        self.input_name_maps = input_name_maps;
255        self.cached_schema = Arc::new(schema);
256    }
257
258    /// Register an additional tool at runtime (e.g. MCP tools, custom tools).
259    /// If a tool with the same name exists, it is replaced.
260    pub fn register(&mut self, tool: Arc<dyn Tool>) {
261        let name = tool.name().to_string();
262        self.tools.insert(name, tool);
263        self.rebuild_schema();
264    }
265
266    pub fn get(&self, name: &str) -> Option<&Arc<dyn Tool>> {
267        let runtime_name = self.api_to_runtime_names.get(name).map(String::as_str).unwrap_or(name);
268        self.tools.get(runtime_name)
269    }
270
271    pub fn runtime_name_for_api<'a>(&'a self, name: &'a str) -> &'a str {
272        self.api_to_runtime_names.get(name).map(String::as_str).unwrap_or(name)
273    }
274
275    pub fn translate_input_for_api_tool(&self, tool_name: &str, input: Value) -> Value {
276        if let Some(map) = self.input_name_maps.get(tool_name) {
277            Self::translate_input_names(input, map)
278        } else {
279            input
280        }
281    }
282
283    pub fn tools_schema(&self) -> Arc<Vec<Value>> {
284        Arc::clone(&self.cached_schema)
285    }
286
287    /// Return runtime names of tools owned by the given extension id, sorted ascending.
288    /// Built-in tools (which return `None` from `Tool::extension_id`) are excluded.
289    pub fn tool_names_for_extension(&self, extension_id: &str) -> Vec<String> {
290        let mut names: Vec<String> = self
291            .tools
292            .values()
293            .filter(|t| t.extension_id() == Some(extension_id))
294            .map(|t| t.name().to_string())
295            .collect();
296        names.sort();
297        names
298    }
299}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Result, ToolContext};
    use serde_json::json;

    /// Minimal tool distinguished only by its runtime name.
    struct NamedTool(&'static str);

    #[async_trait::async_trait]
    impl Tool for NamedTool {
        fn name(&self) -> &str { self.0 }
        fn description(&self) -> &str { "test tool" }
        fn parameters(&self) -> Value { json!({"type": "object"}) }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
    }

    /// Tool whose input schema has property names that need sanitizing:
    /// one that is too long and contains illegal characters, and one nested object.
    struct SchemaTool;

    #[async_trait::async_trait]
    impl Tool for SchemaTool {
        fn name(&self) -> &str { "schema_tool" }
        fn description(&self) -> &str { "schema tool" }
        fn parameters(&self) -> Value {
            json!({
                "type": "object",
                "properties": {
                    "bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going": {"type": "string"},
                    "nested:obj": {
                        "type": "object",
                        "properties": {"inner/key": {"type": "string"}},
                        "required": ["inner/key"]
                    }
                },
                "required": [
                    "bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going",
                    "nested:obj"
                ]
            })
        }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
    }

    /// Tool with a configurable runtime name (`.0`) and owning extension id (`.1`).
    struct OwnedTool(&'static str, Option<&'static str>);

    #[async_trait::async_trait]
    impl Tool for OwnedTool {
        fn name(&self) -> &str { self.0 }
        fn description(&self) -> &str { "owned" }
        fn parameters(&self) -> Value { json!({"type": "object"}) }
        async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
            Ok("ok".to_string())
        }
        fn extension_id(&self) -> Option<&str> { self.1 }
    }

    #[test]
    fn tool_schema_uses_api_safe_names_and_maps_back() {
        let registry = ToolRegistry::from_tools(vec![Arc::new(NamedTool("plugin:skill.tool"))]);

        assert_eq!(registry.tools_schema()[0]["name"], "plugin_skill_tool");
        assert!(registry.get("plugin:skill.tool").is_some());
        assert!(registry.get("plugin_skill_tool").is_some());
        assert_eq!(registry.runtime_name_for_api("plugin_skill_tool"), "plugin:skill.tool");
    }

    #[test]
    fn tool_schema_disambiguates_sanitized_name_collisions() {
        let registry = ToolRegistry::from_tools(vec![
            Arc::new(NamedTool("a:b")),
            Arc::new(NamedTool("a.b")),
        ]);
        let names: HashSet<String> = registry.tools_schema().iter()
            .filter_map(|s| s["name"].as_str().map(str::to_string))
            .collect();

        assert_eq!(names.len(), 2);
        assert!(names.contains("a_b"));
        assert!(names.contains("a_b_2"));
        assert!(registry.get("a_b").is_some());
        assert!(registry.get("a_b_2").is_some());
    }

    #[test]
    fn tool_schema_truncates_long_names_to_anthropic_limit() {
        // NamedTool needs a `&'static str`; leaking one small string is fine in a test.
        let long = "x".repeat(140);
        let leaked: &'static str = Box::leak(long.into_boxed_str());
        let registry = ToolRegistry::from_tools(vec![Arc::new(NamedTool(leaked))]);
        let schema = registry.tools_schema();
        let name = schema[0]["name"].as_str().unwrap();

        assert_eq!(name.len(), 128);
        assert!(name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-'));
        assert!(registry.get(name).is_some());
    }

    #[test]
    fn tool_schema_sanitizes_input_property_names_and_translates_inputs_back() {
        let registry = ToolRegistry::from_tools(vec![Arc::new(SchemaTool)]);
        let schema = registry.tools_schema();
        let input_schema = &schema[0]["input_schema"];
        let props = input_schema["properties"].as_object().unwrap();

        assert!(props.keys().all(|k| k.len() <= 64 && k.chars().all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.')));
        assert_eq!(input_schema["required"].as_array().unwrap()[0].as_str().unwrap().len(), 64);
        assert_eq!(input_schema["required"][1], "nested_obj");
        assert!(props["nested_obj"]["properties"].as_object().unwrap().contains_key("inner_key"));
        assert_eq!(props["nested_obj"]["required"][0], "inner_key");

        let first_required = input_schema["required"][0].as_str().unwrap();
        let translated = registry.translate_input_for_api_tool("schema_tool", json!({
            first_required: "value",
            "nested_obj": {"inner_key": "nested"}
        }));

        assert_eq!(translated["bad:key/that/is/far/too/long/for/anthropic/property/names/and/keeps/going"], "value");
        assert_eq!(translated["nested:obj"]["inner/key"], "nested");
    }

    #[test]
    fn test_tool_registry_new() {
        let registry = ToolRegistry::new();

        // 7 base tools + 6 subagent tools + 3 shell tools.
        assert_eq!(registry.tools_schema().len(), 16);

        // Should not find nonexistent tool
        assert!(registry.get("nonexistent").is_none());

        // Verify the expected tools are present
        assert!(registry.get("bash").is_some());
        assert!(registry.get("read").is_some());
        assert!(registry.get("write").is_some());
        assert!(registry.get("edit").is_some());
        assert!(registry.get("grep").is_some());
        assert!(registry.get("find").is_some());
        assert!(registry.get("ls").is_some());
        assert!(registry.get("subagent").is_some());
    }

    #[test]
    fn test_tool_registry_without_subagent() {
        let registry = ToolRegistry::without_subagent();

        // Should have 10 tools without subagent (7 base + 3 shell)
        assert_eq!(registry.tools_schema().len(), 10);

        // Should not have subagent tool
        assert!(registry.get("subagent").is_none());

        // Verify all expected tools are present except subagent
        assert!(registry.get("bash").is_some());
        assert!(registry.get("read").is_some());
        assert!(registry.get("write").is_some());
        assert!(registry.get("edit").is_some());
        assert!(registry.get("grep").is_some());
        assert!(registry.get("find").is_some());
        assert!(registry.get("ls").is_some());
    }

    #[test]
    fn test_tool_registry_register() {
        let mut registry = ToolRegistry::without_subagent();
        let initial_count = registry.tools_schema().len();

        // Register a brand-new tool under a name no built-in uses.
        struct TestTool;
        #[async_trait::async_trait]
        impl Tool for TestTool {
            fn name(&self) -> &str { "test_tool" }
            fn description(&self) -> &str { "A test tool" }
            fn parameters(&self) -> Value { json!({"type": "object"}) }
            async fn execute(&self, _params: Value, _ctx: ToolContext) -> Result<String> {
                Ok("test result".to_string())
            }
        }

        registry.register(Arc::new(TestTool));

        // Should have one more tool now
        assert_eq!(registry.tools_schema().len(), initial_count + 1);

        // Should find the new tool
        assert!(registry.get("test_tool").is_some());
    }

    #[test]
    fn tool_names_for_extension_filters_by_owner_and_sorts() {
        let mut registry = ToolRegistry::without_subagent();
        registry.register(Arc::new(OwnedTool("alpha:zed", Some("alpha"))));
        registry.register(Arc::new(OwnedTool("alpha:bar", Some("alpha"))));
        registry.register(Arc::new(OwnedTool("beta:thing", Some("beta"))));

        assert_eq!(
            registry.tool_names_for_extension("alpha"),
            vec!["alpha:bar".to_string(), "alpha:zed".to_string()]
        );
        assert_eq!(
            registry.tool_names_for_extension("beta"),
            vec!["beta:thing".to_string()]
        );
        assert!(registry.tool_names_for_extension("ghost").is_empty());
        // Built-in tools (no owner) must not leak.
        assert!(registry.tool_names_for_extension("bash").is_empty());
    }

    #[test]
    fn without_subagent_excludes_subagent_tools() {
        let registry = ToolRegistry::without_subagent();
        assert!(registry.get("subagent").is_none());
        assert!(registry.get("subagent_start").is_none());
        assert!(registry.get("subagent_status").is_none());
        assert!(registry.get("subagent_steer").is_none());
        assert!(registry.get("subagent_collect").is_none());
        assert!(registry.get("subagent_resume").is_none());
        // Built-ins remain.
        assert!(registry.get("bash").is_some());
        assert!(registry.get("read").is_some());
    }

    #[test]
    fn without_subagent_with_extensions_includes_extension_tools() {
        let mut other = ToolRegistry::empty();
        other.register(Arc::new(OwnedTool("alpha:do_thing", Some("alpha"))));

        let merged = ToolRegistry::without_subagent_with_extensions(&other);

        // Extension tool present.
        assert!(merged.get("alpha:do_thing").is_some());
        // Built-ins still present.
        assert!(merged.get("bash").is_some());
        assert!(merged.get("read").is_some());
        // Subagent tools still absent.
        assert!(merged.get("subagent_start").is_none());
    }

    #[test]
    fn without_subagent_with_extensions_excludes_built_ins_from_other_registry() {
        // `other` simulates a shared registry that already holds built-ins
        // (e.g. the extension manager's tools registry). Only tools with an
        // extension owner must be merged — built-ins must NOT be re-added or
        // overwritten with a foreign instance.
        let other = ToolRegistry::new();

        let merged = ToolRegistry::without_subagent_with_extensions(&other);

        // Only one instance of `bash`, and it's the built-in (no extension_id).
        let bash = merged.get("bash").expect("bash present");
        assert!(bash.extension_id().is_none());
        // No subagent tools leaked from `other`.
        assert!(merged.get("subagent_start").is_none());
        assert!(merged.get("subagent").is_none());
    }

    #[test]
    fn without_subagent_with_extensions_does_not_overwrite_existing_builtin() {
        // Scope of this test: merging an extension tool with a *distinct* name
        // leaves the built-ins untouched. The merge does not guard against an
        // extension-owned tool that reuses a built-in's name — such a tool
        // would replace the built-in — and that case is not exercised here.
        let mut other = ToolRegistry::empty();
        other.register(Arc::new(OwnedTool("ext:custom", Some("ext"))));

        let merged = ToolRegistry::without_subagent_with_extensions(&other);
        assert!(merged.get("ext:custom").is_some());
        assert!(merged.get("bash").is_some());
        assert!(merged.get("bash").unwrap().extension_id().is_none());
    }
}