1use std::collections::HashMap;
30use std::sync::Arc;
31use thulp_core::{Error, Result, ToolDefinition};
32use tokio::sync::RwLock;
33
34pub struct ToolRegistry {
42 tools: Arc<RwLock<HashMap<String, ToolDefinition>>>,
44
45 tags: Arc<RwLock<HashMap<String, Vec<String>>>>,
47}
48
49impl ToolRegistry {
50 pub fn new() -> Self {
52 Self {
53 tools: Arc::new(RwLock::new(HashMap::new())),
54 tags: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 pub async fn register(&self, tool: ToolDefinition) -> Result<()> {
60 let mut tools = self.tools.write().await;
61 tools.insert(tool.name.clone(), tool);
62 Ok(())
63 }
64
65 pub async fn register_many(&self, tools: Vec<ToolDefinition>) -> Result<()> {
67 let mut registry = self.tools.write().await;
68 for tool in tools {
69 registry.insert(tool.name.clone(), tool);
70 }
71 Ok(())
72 }
73
74 pub async fn unregister(&self, name: &str) -> Result<Option<ToolDefinition>> {
76 let mut tools = self.tools.write().await;
77 Ok(tools.remove(name))
78 }
79
80 pub async fn get(&self, name: &str) -> Result<Option<ToolDefinition>> {
82 let tools = self.tools.read().await;
83 Ok(tools.get(name).cloned())
84 }
85
86 pub async fn list(&self) -> Result<Vec<ToolDefinition>> {
88 let tools = self.tools.read().await;
89 Ok(tools.values().cloned().collect())
90 }
91
92 pub async fn count(&self) -> usize {
94 let tools = self.tools.read().await;
95 tools.len()
96 }
97
98 pub async fn clear(&self) {
100 let mut tools = self.tools.write().await;
101 let mut tags = self.tags.write().await;
102 tools.clear();
103 tags.clear();
104 }
105
106 pub async fn contains(&self, name: &str) -> bool {
108 let tools = self.tools.read().await;
109 tools.contains_key(name)
110 }
111
112 pub async fn tag(&self, tool_name: &str, tag: &str) -> Result<()> {
114 let tools = self.tools.read().await;
115 if !tools.contains_key(tool_name) {
116 return Err(Error::InvalidConfig(format!(
117 "Tool '{}' not found in registry",
118 tool_name
119 )));
120 }
121 drop(tools);
122
123 let mut tags = self.tags.write().await;
124 tags.entry(tag.to_string())
125 .or_insert_with(Vec::new)
126 .push(tool_name.to_string());
127 Ok(())
128 }
129
130 pub async fn find_by_tag(&self, tag: &str) -> Result<Vec<ToolDefinition>> {
132 let tags = self.tags.read().await;
133 let tool_names = match tags.get(tag) {
134 Some(names) => names.clone(),
135 None => return Ok(Vec::new()),
136 };
137 drop(tags);
138
139 let tools = self.tools.read().await;
140 let mut results = Vec::new();
141 for name in tool_names {
142 if let Some(tool) = tools.get(&name) {
143 results.push(tool.clone());
144 }
145 }
146 Ok(results)
147 }
148}
149
150impl Default for ToolRegistry {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use thulp_core::Parameter;

    /// Builds a minimal tool named `name` with one required string parameter.
    fn create_test_tool(name: &str) -> ToolDefinition {
        ToolDefinition::builder(name)
            .description(format!("Test tool: {}", name))
            .parameter(Parameter::required_string("test_param"))
            .build()
    }

    #[tokio::test]
    async fn registry_creation() {
        let registry = ToolRegistry::new();
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn default_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert_eq!(registry.count().await, 0);
        assert!(registry.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn register_and_get_tool() {
        let registry = ToolRegistry::new();
        registry
            .register(create_test_tool("test_tool"))
            .await
            .unwrap();

        let retrieved = registry
            .get("test_tool")
            .await
            .unwrap()
            .expect("tool was just registered");
        assert_eq!(retrieved.name, "test_tool");
    }

    #[tokio::test]
    async fn get_missing_tool_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").await.unwrap().is_none());
        assert!(!registry.contains("missing").await);
    }

    #[tokio::test]
    async fn register_replaces_tool_with_same_name() {
        let registry = ToolRegistry::new();
        registry.register(create_test_tool("dup")).await.unwrap();
        registry.register(create_test_tool("dup")).await.unwrap();
        assert_eq!(registry.count().await, 1);
    }

    #[tokio::test]
    async fn register_many_tools() {
        let registry = ToolRegistry::new();
        let tools = vec![
            create_test_tool("tool1"),
            create_test_tool("tool2"),
            create_test_tool("tool3"),
        ];

        registry.register_many(tools).await.unwrap();

        assert_eq!(registry.count().await, 3);
        assert!(registry.contains("tool1").await);
        assert!(registry.contains("tool2").await);
        assert!(registry.contains("tool3").await);
    }

    #[tokio::test]
    async fn unregister_tool() {
        let registry = ToolRegistry::new();
        registry
            .register(create_test_tool("test_tool"))
            .await
            .unwrap();
        assert_eq!(registry.count().await, 1);

        let removed = registry.unregister("test_tool").await.unwrap();
        assert!(removed.is_some());
        assert_eq!(registry.count().await, 0);
        assert!(!registry.contains("test_tool").await);
    }

    #[tokio::test]
    async fn unregister_missing_tool_returns_none() {
        let registry = ToolRegistry::new();
        assert!(registry.unregister("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn list_tools() {
        let registry = ToolRegistry::new();
        let tools = vec![create_test_tool("tool1"), create_test_tool("tool2")];

        registry.register_many(tools).await.unwrap();

        // `list` order is unspecified, so compare sorted names.
        let mut names: Vec<String> = registry
            .list()
            .await
            .unwrap()
            .into_iter()
            .map(|tool| tool.name)
            .collect();
        names.sort();
        assert_eq!(names, ["tool1", "tool2"]);
    }

    #[tokio::test]
    async fn clear_registry() {
        let registry = ToolRegistry::new();
        let tools = vec![create_test_tool("tool1"), create_test_tool("tool2")];

        registry.register_many(tools).await.unwrap();
        assert_eq!(registry.count().await, 2);

        registry.clear().await;
        assert_eq!(registry.count().await, 0);
    }

    #[tokio::test]
    async fn clear_also_removes_tags() {
        let registry = ToolRegistry::new();
        registry.register(create_test_tool("tool1")).await.unwrap();
        registry.tag("tool1", "filesystem").await.unwrap();

        registry.clear().await;

        // Re-registering the same name must not resurrect the cleared tag.
        registry.register(create_test_tool("tool1")).await.unwrap();
        assert!(registry.find_by_tag("filesystem").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn tag_and_find_tools() {
        let registry = ToolRegistry::new();
        registry.register(create_test_tool("tool1")).await.unwrap();
        registry.register(create_test_tool("tool2")).await.unwrap();
        registry.register(create_test_tool("tool3")).await.unwrap();

        registry.tag("tool1", "filesystem").await.unwrap();
        registry.tag("tool2", "filesystem").await.unwrap();
        registry.tag("tool3", "network").await.unwrap();

        let filesystem_tools = registry.find_by_tag("filesystem").await.unwrap();
        assert_eq!(filesystem_tools.len(), 2);

        let network_tools = registry.find_by_tag("network").await.unwrap();
        assert_eq!(network_tools.len(), 1);
    }

    #[tokio::test]
    async fn find_by_unknown_tag_returns_empty() {
        let registry = ToolRegistry::new();
        registry.register(create_test_tool("tool1")).await.unwrap();
        assert!(registry.find_by_tag("nope").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn tag_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let result = registry.tag("nonexistent", "test").await;
        assert!(matches!(result, Err(Error::InvalidConfig(_))));
    }
}