1use std::collections::HashMap;
4
5use saorsa_ai::ToolDefinition;
6
7use crate::error::Result;
8
9#[async_trait::async_trait]
11pub trait Tool: Send + Sync {
12 fn name(&self) -> &str;
14
15 fn description(&self) -> &str;
17
18 fn input_schema(&self) -> serde_json::Value;
20
21 async fn execute(&self, input: serde_json::Value) -> Result<String>;
23
24 fn to_definition(&self) -> ToolDefinition {
26 ToolDefinition::new(self.name(), self.description(), self.input_schema())
27 }
28}
29
30pub struct ToolRegistry {
32 tools: HashMap<String, Box<dyn Tool>>,
33}
34
35impl ToolRegistry {
36 pub fn new() -> Self {
38 Self {
39 tools: HashMap::new(),
40 }
41 }
42
43 pub fn register(&mut self, tool: Box<dyn Tool>) {
45 self.tools.insert(tool.name().to_string(), tool);
46 }
47
48 pub fn get(&self, name: &str) -> Option<&dyn Tool> {
50 self.tools.get(name).map(AsRef::as_ref)
51 }
52
53 pub fn definitions(&self) -> Vec<ToolDefinition> {
55 self.tools.values().map(|t| t.to_definition()).collect()
56 }
57
58 pub fn names(&self) -> Vec<&str> {
60 self.tools.keys().map(String::as_str).collect()
61 }
62
63 pub fn len(&self) -> usize {
65 self.tools.len()
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.tools.is_empty()
71 }
72}
73
74impl Default for ToolRegistry {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double: returns the `text` field of its input, or `"(empty)"` when
    /// that field is missing or not a string.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "text": {"type": "string"}
                },
                "required": ["text"]
            })
        }

        async fn execute(&self, input: serde_json::Value) -> Result<String> {
            let text = input
                .get("text")
                .and_then(serde_json::Value::as_str)
                .unwrap_or("(empty)");
            Ok(text.to_string())
        }
    }

    #[test]
    fn registry_register_and_get() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
        assert!(registry.get("echo").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn registry_register_same_name_replaces() {
        // Registering under an existing name overwrites rather than duplicates.
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        registry.register(Box::new(EchoTool));
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.names(), vec!["echo"]);
    }

    #[test]
    fn registry_definitions() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        let defs = registry.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].name, "echo");
    }

    #[test]
    fn registry_names() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));
        let names = registry.names();
        assert!(names.contains(&"echo"));
    }

    #[test]
    fn tool_to_definition() {
        let tool = EchoTool;
        let def = tool.to_definition();
        assert_eq!(def.name, "echo");
        assert_eq!(def.description, "Echoes input back");
    }

    #[test]
    fn registry_default() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
    }

    #[tokio::test]
    async fn tool_execute() {
        let tool = EchoTool;
        // `ok()` + `as_deref()` pins the exact output without requiring the
        // error type to implement `Debug`.
        let output = tool
            .execute(serde_json::json!({"text": "hello"}))
            .await
            .ok();
        assert_eq!(output.as_deref(), Some("hello"));
    }

    #[tokio::test]
    async fn tool_execute_missing_text_falls_back() {
        let tool = EchoTool;
        let output = tool.execute(serde_json::json!({})).await.ok();
        assert_eq!(output.as_deref(), Some("(empty)"));
    }
}