Skip to main content

saorsa_agent/
tool.rs

1//! Tool trait and registry for agent tool execution.
2
3use std::collections::HashMap;
4
5use saorsa_ai::ToolDefinition;
6
7use crate::error::Result;
8
9/// A tool that the agent can execute.
10#[async_trait::async_trait]
11pub trait Tool: Send + Sync {
12    /// The unique name of the tool.
13    fn name(&self) -> &str;
14
15    /// A human-readable description of what the tool does.
16    fn description(&self) -> &str;
17
18    /// JSON Schema describing the tool's input parameters.
19    fn input_schema(&self) -> serde_json::Value;
20
21    /// Execute the tool with the given JSON input and return the result as a string.
22    async fn execute(&self, input: serde_json::Value) -> Result<String>;
23
24    /// Convert this tool to a `ToolDefinition` for the LLM API.
25    fn to_definition(&self) -> ToolDefinition {
26        ToolDefinition::new(self.name(), self.description(), self.input_schema())
27    }
28}
29
30/// Registry for managing available tools.
31pub struct ToolRegistry {
32    tools: HashMap<String, Box<dyn Tool>>,
33}
34
35impl ToolRegistry {
36    /// Create a new empty tool registry.
37    pub fn new() -> Self {
38        Self {
39            tools: HashMap::new(),
40        }
41    }
42
43    /// Register a tool. Replaces any existing tool with the same name.
44    pub fn register(&mut self, tool: Box<dyn Tool>) {
45        self.tools.insert(tool.name().to_string(), tool);
46    }
47
48    /// Look up a tool by name.
49    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
50        self.tools.get(name).map(AsRef::as_ref)
51    }
52
53    /// Get all tool definitions for the LLM API.
54    pub fn definitions(&self) -> Vec<ToolDefinition> {
55        self.tools.values().map(|t| t.to_definition()).collect()
56    }
57
58    /// Get the names of all registered tools.
59    pub fn names(&self) -> Vec<&str> {
60        self.tools.keys().map(String::as_str).collect()
61    }
62
63    /// Return the number of registered tools.
64    pub fn len(&self) -> usize {
65        self.tools.len()
66    }
67
68    /// Check if the registry is empty.
69    pub fn is_empty(&self) -> bool {
70        self.tools.is_empty()
71    }
72}
73
74impl Default for ToolRegistry {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Tool`] for the tests: it returns the `"text"` field of its input.
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes input back"
        }

        fn input_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "text": {"type": "string"}
                },
                "required": ["text"]
            })
        }

        async fn execute(&self, input: serde_json::Value) -> Result<String> {
            // Fall back to a placeholder when "text" is absent or not a string.
            let field = input.get("text");
            let echoed = field
                .and_then(serde_json::Value::as_str)
                .unwrap_or("(empty)");
            Ok(String::from(echoed))
        }
    }

    /// A registered tool is counted and can be found by name; unknown names are not.
    #[test]
    fn registry_register_and_get() {
        let mut tools = ToolRegistry::new();
        tools.register(Box::new(EchoTool));

        assert!(!tools.is_empty());
        assert_eq!(tools.len(), 1);
        assert!(tools.get("nonexistent").is_none());
        assert!(tools.get("echo").is_some());
    }

    /// The registry exposes one definition per registered tool.
    #[test]
    fn registry_definitions() {
        let mut tools = ToolRegistry::new();
        tools.register(Box::new(EchoTool));

        let definitions = tools.definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "echo");
    }

    /// The name of a registered tool appears in the name listing.
    #[test]
    fn registry_names() {
        let mut tools = ToolRegistry::new();
        tools.register(Box::new(EchoTool));

        assert!(tools.names().iter().any(|name| *name == "echo"));
    }

    /// The default `to_definition` carries over the tool's name and description.
    #[test]
    fn tool_to_definition() {
        let definition = EchoTool.to_definition();
        assert_eq!(definition.name, "echo");
        assert_eq!(definition.description, "Echoes input back");
    }

    /// `Default` yields a registry with no tools.
    #[test]
    fn registry_default() {
        assert!(ToolRegistry::default().is_empty());
    }

    /// Executing the echo tool succeeds and returns the supplied text.
    #[tokio::test]
    async fn tool_execute() {
        let input = serde_json::json!({"text": "hello"});
        let outcome = EchoTool.execute(input).await;
        // One comparison covers both "is Ok" and "has the expected payload".
        assert_eq!(outcome.ok().as_deref(), Some("hello"));
    }
}