rustant_tools/
registry.rs1use async_trait::async_trait;
8use rustant_core::error::ToolError;
9use rustant_core::types::{RiskLevel, ToolDefinition, ToolOutput};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tracing::{debug, info};
14
15#[async_trait]
17pub trait Tool: Send + Sync {
18 fn name(&self) -> &str;
20
21 fn description(&self) -> &str;
23
24 fn parameters_schema(&self) -> serde_json::Value;
26
27 async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError>;
29
30 fn risk_level(&self) -> RiskLevel;
32
33 fn timeout(&self) -> Duration {
35 Duration::from_secs(30)
36 }
37}
38
39#[derive(Clone)]
41pub struct ToolRegistry {
42 tools: HashMap<String, Arc<dyn Tool>>,
43}
44
45impl ToolRegistry {
46 pub fn new() -> Self {
47 Self {
48 tools: HashMap::new(),
49 }
50 }
51
52 pub fn register(&mut self, tool: Arc<dyn Tool>) -> Result<(), ToolError> {
54 let name = tool.name().to_string();
55 if self.tools.contains_key(&name) {
56 return Err(ToolError::AlreadyRegistered { name });
57 }
58 debug!(tool = %name, "Registering tool");
59 self.tools.insert(name, tool);
60 Ok(())
61 }
62
63 pub fn unregister(&mut self, name: &str) -> Result<(), ToolError> {
65 if self.tools.remove(name).is_none() {
66 return Err(ToolError::NotFound {
67 name: name.to_string(),
68 });
69 }
70 debug!(tool = %name, "Unregistered tool");
71 Ok(())
72 }
73
74 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
76 self.tools.get(name).cloned()
77 }
78
79 pub fn list_definitions(&self) -> Vec<ToolDefinition> {
81 self.tools
82 .values()
83 .map(|tool| ToolDefinition {
84 name: tool.name().to_string(),
85 description: tool.description().to_string(),
86 parameters: tool.parameters_schema(),
87 })
88 .collect()
89 }
90
91 pub fn list_names(&self) -> Vec<String> {
93 self.tools.keys().cloned().collect()
94 }
95
96 pub fn get_risk_level(&self, name: &str) -> Option<RiskLevel> {
98 self.tools.get(name).map(|t| t.risk_level())
99 }
100
101 pub fn get_parameters_schema(&self, name: &str) -> Option<serde_json::Value> {
103 self.tools.get(name).map(|t| t.parameters_schema())
104 }
105
106 pub fn len(&self) -> usize {
108 self.tools.len()
109 }
110
111 pub fn is_empty(&self) -> bool {
113 self.tools.is_empty()
114 }
115
116 pub async fn execute(
118 &self,
119 name: &str,
120 args: serde_json::Value,
121 ) -> Result<ToolOutput, ToolError> {
122 let tool = self.tools.get(name).ok_or_else(|| ToolError::NotFound {
123 name: name.to_string(),
124 })?;
125
126 let timeout = tool.timeout();
127 info!(tool = %name, timeout_secs = timeout.as_secs(), "Executing tool");
128
129 match tokio::time::timeout(timeout, tool.execute(args)).await {
130 Ok(result) => result,
131 Err(_) => Err(ToolError::Timeout {
132 name: name.to_string(),
133 timeout_secs: timeout.as_secs(),
134 }),
135 }
136 }
137}
138
139impl Default for ToolRegistry {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that returns its `text` argument prefixed with `"Echo: "`.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the input text back"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "text": { "type": "string", "description": "Text to echo" }
                },
                "required": ["text"]
            })
        }

        async fn execute(&self, args: serde_json::Value) -> Result<ToolOutput, ToolError> {
            let text = args["text"]
                .as_str()
                .ok_or_else(|| ToolError::InvalidArguments {
                    name: "echo".to_string(),
                    reason: "missing 'text' parameter".to_string(),
                })?;
            Ok(ToolOutput::text(format!("Echo: {}", text)))
        }

        fn risk_level(&self) -> RiskLevel {
            RiskLevel::ReadOnly
        }
    }

    /// Test tool whose `execute` sleeps far longer than its own timeout, used
    /// to exercise the registry's timeout path.
    struct SlowTool;

    #[async_trait]
    impl Tool for SlowTool {
        fn name(&self) -> &str {
            "slow"
        }

        fn description(&self) -> &str {
            "A tool that takes forever"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolOutput, ToolError> {
            tokio::time::sleep(Duration::from_secs(60)).await;
            Ok(ToolOutput::text("done"))
        }

        fn risk_level(&self) -> RiskLevel {
            RiskLevel::ReadOnly
        }

        // Short timeout so the timeout test finishes quickly.
        fn timeout(&self) -> Duration {
            Duration::from_millis(100)
        }
    }

    #[test]
    fn test_registry_new() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_default_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_register_tool() {
        let mut registry = ToolRegistry::new();
        let tool: Arc<dyn Tool> = Arc::new(EchoTool);
        registry.register(tool).unwrap();

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.get("echo").is_some());
    }

    #[test]
    fn test_register_duplicate() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let result = registry.register(Arc::new(EchoTool));
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::AlreadyRegistered { name } => assert_eq!(name, "echo"),
            _ => panic!("Expected AlreadyRegistered error"),
        }
        // The failed registration must not disturb the existing entry.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_unregister_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();
        assert_eq!(registry.len(), 1);

        registry.unregister("echo").unwrap();
        assert_eq!(registry.len(), 0);
        assert!(registry.get("echo").is_none());
    }

    #[test]
    fn test_unregister_nonexistent() {
        let mut registry = ToolRegistry::new();
        let result = registry.unregister("nonexistent");
        assert!(result.is_err());
        // Pin the variant so an unrelated error cannot satisfy the test.
        match result.unwrap_err() {
            ToolError::NotFound { name } => assert_eq!(name, "nonexistent"),
            other => panic!("Expected NotFound error, got: {:?}", other),
        }
    }

    #[test]
    fn test_list_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let defs = registry.list_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
        assert_eq!(defs[0].description, "Echoes the input text back");
    }

    #[test]
    fn test_list_names() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let names = registry.list_names();
        assert_eq!(names, vec!["echo"]);
    }

    #[test]
    fn test_get_risk_level() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        assert!(matches!(
            registry.get_risk_level("echo"),
            Some(RiskLevel::ReadOnly)
        ));
        assert!(registry.get_risk_level("missing").is_none());
    }

    #[test]
    fn test_get_parameters_schema() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let schema = registry
            .get_parameters_schema("echo")
            .expect("echo is registered");
        assert_eq!(schema["type"], "object");
        assert!(registry.get_parameters_schema("missing").is_none());
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        let result = registry
            .execute("echo", serde_json::json!({"text": "hello"}))
            .await
            .unwrap();
        assert_eq!(result.content, "Echo: hello");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let result = registry.execute("missing", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::NotFound { name } => assert_eq!(name, "missing"),
            _ => panic!("Expected NotFound error"),
        }
    }

    #[tokio::test]
    async fn test_execute_invalid_args() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(EchoTool)).unwrap();

        // No "text" key, so the tool itself rejects the call and the registry
        // must pass that error through unchanged.
        let result = registry.execute("echo", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::InvalidArguments { name, .. } => assert_eq!(name, "echo"),
            other => panic!("Expected InvalidArguments error, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_execute_timeout() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(SlowTool)).unwrap();

        let result = registry.execute("slow", serde_json::json!({})).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ToolError::Timeout { name, .. } => assert_eq!(name, "slow"),
            e => panic!("Expected Timeout error, got: {:?}", e),
        }
    }

    #[test]
    fn test_get_nonexistent() {
        let registry = ToolRegistry::new();
        assert!(registry.get("missing").is_none());
    }
}