trae_agent_rs_core/tools/
registry.rs1use crate::tools::{Tool, ToolExecutor};
4use std::collections::HashMap;
5
/// Registry of tool factories, keyed by the name each factory reports through
/// [`ToolFactory::tool_name`].
///
/// The registry stores factories rather than tool instances; tools are
/// obtained by asking the matching factory to `create` one. Use
/// [`ToolRegistry::new`] for an empty registry or [`ToolRegistry::default`]
/// for one pre-populated with the built-in tool factories.
pub struct ToolRegistry {
    /// Registered factories, indexed by tool name.
    factories: HashMap<String, Box<dyn ToolFactory>>,
}
10
/// Creates instances of one kind of [`Tool`] and describes it.
///
/// Implementations must be `Send + Sync`. The `impl_tool_factory!` macro
/// generates a unit-struct implementation of this trait.
pub trait ToolFactory: Send + Sync {
    /// Builds a new boxed instance of the tool.
    fn create(&self) -> Box<dyn Tool>;

    /// Name under which this factory is stored in a [`ToolRegistry`]; it
    /// should match the name reported by the tools it creates.
    fn tool_name(&self) -> &str;

    /// Human-readable description of the tool.
    fn tool_description(&self) -> &str;
}
22
23impl ToolRegistry {
24 pub fn new() -> Self {
26 Self {
27 factories: HashMap::new(),
28 }
29 }
30
31 pub fn register_factory(&mut self, factory: Box<dyn ToolFactory>) {
33 self.factories.insert(factory.tool_name().to_string(), factory);
34 }
35
36 pub fn create_tool(&self, name: &str) -> Option<Box<dyn Tool>> {
38 self.factories.get(name).map(|factory| factory.create())
39 }
40
41 pub fn list_tools(&self) -> Vec<&str> {
43 self.factories.keys().map(|s| s.as_str()).collect()
44 }
45
46 pub fn get_tool_info(&self, name: &str) -> Option<(&str, &str)> {
48 self.factories.get(name).map(|factory| {
49 (factory.tool_name(), factory.tool_description())
50 })
51 }
52
53 pub fn create_executor(&self, tool_names: &[String]) -> ToolExecutor {
55 let mut executor = ToolExecutor::new();
56
57 for name in tool_names {
58 if let Some(tool) = self.create_tool(name) {
59 executor.register_tool(tool);
60 }
61 }
62
63 executor
64 }
65
66 pub fn create_executor_with_all(&self) -> ToolExecutor {
68 let mut executor = ToolExecutor::new();
69
70 for factory in self.factories.values() {
71 executor.register_tool(factory.create());
72 }
73
74 executor
75 }
76}
77
78impl Default for ToolRegistry {
79 fn default() -> Self {
80 let mut registry = Self::new();
81
82 registry.register_factory(Box::new(crate::tools::builtin::BashToolFactory));
84 registry.register_factory(Box::new(crate::tools::builtin::EditToolFactory));
85 registry.register_factory(Box::new(crate::tools::builtin::ThinkingToolFactory));
86 registry.register_factory(Box::new(crate::tools::builtin::TaskDoneToolFactory));
87 registry.register_factory(Box::new(crate::tools::builtin::JsonEditToolFactory));
88 registry.register_factory(Box::new(crate::tools::builtin::CkgToolFactory));
89 registry.register_factory(Box::new(crate::tools::builtin::McpToolFactory));
90
91 registry
92 }
93}
94
/// Defines a unit struct implementing `ToolFactory` for a tool type.
///
/// Usage: `impl_tool_factory!(FactoryName, ToolType, "tool_name", "description");`
///
/// * `$factory` - name of the generated `pub struct`.
/// * `$tool` - tool type; `$tool::new()` must return a value that can be boxed
///   as `dyn Tool`.
/// * `$name` / `$description` - expressions usable as the `&str` returned by
///   `tool_name` / `tool_description` (typically string literals).
///
/// Outer attributes, including `///` doc comments and `#[derive(...)]`, may
/// precede `$factory`; they are applied to the generated struct.
#[macro_export]
macro_rules! impl_tool_factory {
    ($(#[$meta:meta])* $factory:ident, $tool:ident, $name:expr, $description:expr) => {
        $(#[$meta])*
        pub struct $factory;

        impl $crate::tools::ToolFactory for $factory {
            // Fully qualified so a caller-defined `Box` cannot shadow std's.
            fn create(&self) -> ::std::boxed::Box<dyn $crate::tools::Tool> {
                ::std::boxed::Box::new($tool::new())
            }

            fn tool_name(&self) -> &str {
                $name
            }

            fn tool_description(&self) -> &str {
                $description
            }
        }
    };
}
116
#[cfg(test)]
mod tests {
    use crate::tools::registry::ToolRegistry;

    /// Names of every tool the default registry is expected to provide.
    const BUILTIN_TOOLS: &[&str] = &[
        "bash",
        "str_replace_based_edit_tool",
        "sequentialthinking",
        "task_done",
        "json_edit_tool",
        "ckg_tool",
        "mcp_tool",
    ];

    #[test]
    fn test_default_registry_has_all_tools() {
        let registry = ToolRegistry::default();
        let tools = registry.list_tools();

        for expected_tool in BUILTIN_TOOLS {
            assert!(
                tools.contains(expected_tool),
                "Tool '{}' is not registered in the default registry",
                expected_tool
            );
        }

        assert_eq!(
            tools.len(),
            BUILTIN_TOOLS.len(),
            "Expected {} tools, but found {}. Tools: {:?}",
            BUILTIN_TOOLS.len(),
            tools.len(),
            tools
        );
    }

    #[test]
    fn test_tool_creation() {
        let registry = ToolRegistry::default();

        for &tool_name in BUILTIN_TOOLS {
            let tool = registry
                .create_tool(tool_name)
                .unwrap_or_else(|| panic!("Failed to create tool '{}'", tool_name));

            assert_eq!(
                tool.name(),
                tool_name,
                "Tool name mismatch for '{}'",
                tool_name
            );

            assert!(
                !tool.description().is_empty(),
                "Tool '{}' has empty description",
                tool_name
            );

            assert!(
                tool.parameters_schema().is_object(),
                "Tool '{}' parameters schema is not an object",
                tool_name
            );
        }
    }

    #[test]
    fn test_tool_info() {
        let registry = ToolRegistry::default();

        for tool_name in registry.list_tools() {
            let (name, description) = registry
                .get_tool_info(tool_name)
                .unwrap_or_else(|| panic!("Failed to get info for tool '{}'", tool_name));

            assert_eq!(name, tool_name);
            assert!(
                !description.is_empty(),
                "Tool '{}' has empty description",
                tool_name
            );
        }
    }

    #[test]
    fn test_unknown_tool_is_absent() {
        let registry = ToolRegistry::default();

        assert!(registry.create_tool("no_such_tool").is_none());
        assert!(registry.get_tool_info("no_such_tool").is_none());
    }

    #[test]
    fn test_new_registry_is_empty() {
        assert!(ToolRegistry::new().list_tools().is_empty());
    }

    #[test]
    fn test_executor_creation() {
        let registry = ToolRegistry::default();

        let tool_names = [
            "bash".to_string(),
            "str_replace_based_edit_tool".to_string(),
        ];
        let _executor = registry.create_executor(&tool_names);

        // Names without a registered factory are skipped rather than panicking.
        let _partial = registry.create_executor(&["no_such_tool".to_string()]);

        let _all_executor = registry.create_executor_with_all();
    }

    #[test]
    fn test_tool_examples() {
        let registry = ToolRegistry::default();

        for tool_name in registry.list_tools() {
            let tool = registry
                .create_tool(tool_name)
                .unwrap_or_else(|| panic!("Failed to create tool '{}'", tool_name));
            let examples = tool.examples();

            assert!(
                !examples.is_empty(),
                "Tool '{}' has no examples",
                tool_name
            );

            for (i, example) in examples.iter().enumerate() {
                assert!(
                    !example.description.is_empty(),
                    "Tool '{}' example {} has empty description",
                    tool_name,
                    i
                );

                assert!(
                    example.parameters.is_object(),
                    "Tool '{}' example {} parameters is not an object",
                    tool_name,
                    i
                );

                assert!(
                    !example.expected_result.is_empty(),
                    "Tool '{}' example {} has empty expected result",
                    tool_name,
                    i
                );
            }
        }
    }

    #[test]
    fn test_tool_parameter_schemas() {
        let registry = ToolRegistry::default();

        for tool_name in registry.list_tools() {
            let tool = registry
                .create_tool(tool_name)
                .unwrap_or_else(|| panic!("Failed to create tool '{}'", tool_name));
            let schema = tool.parameters_schema();

            let schema_obj = schema
                .as_object()
                .unwrap_or_else(|| panic!("Tool '{}' schema is not an object", tool_name));

            if let Some(type_val) = schema_obj.get("type") {
                assert_eq!(
                    type_val.as_str(),
                    Some("object"),
                    "Tool '{}' schema type is not 'object'",
                    tool_name
                );
            }

            if let Some(properties) = schema_obj.get("properties") {
                let props = properties.as_object().unwrap_or_else(|| {
                    panic!("Tool '{}' schema properties is not an object", tool_name)
                });

                assert!(
                    !props.is_empty(),
                    "Tool '{}' has no properties in schema",
                    tool_name
                );
            }
        }
    }
}