1use std::collections::HashMap;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8use crate::error::{AgnoError, Result};
9
10#[async_trait]
11pub trait Tool: Send + Sync {
12 fn name(&self) -> &str;
13 fn description(&self) -> &str;
14
15 fn parameters(&self) -> Option<Value> {
17 None
18 }
19 async fn call(&self, input: Value) -> Result<Value>;
20}
21
22#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
24pub struct ToolDescription {
25 pub name: String,
26 pub description: String,
27 pub parameters: Option<Value>,
28}
29
30#[derive(Default, Clone)]
31pub struct ToolRegistry {
32 tools: HashMap<String, Arc<dyn Tool>>,
33}
34
35impl ToolRegistry {
36 pub fn new() -> Self {
37 Self {
38 tools: HashMap::new(),
39 }
40 }
41
42 pub fn register<T: Tool + 'static>(&mut self, tool: T) {
43 self.tools.insert(tool.name().to_string(), Arc::new(tool));
44 }
45
46 pub fn names(&self) -> Vec<String> {
47 self.tools.keys().cloned().collect()
48 }
49
50 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
52 self.tools.get(name).cloned()
53 }
54
55 pub fn describe(&self) -> Vec<ToolDescription> {
56 let mut descriptions: Vec<ToolDescription> = self
57 .tools
58 .values()
59 .map(|tool| ToolDescription {
60 name: tool.name().to_string(),
61 description: tool.description().to_string(),
62 parameters: tool.parameters(),
63 })
64 .collect();
65
66 descriptions.sort_by(|a, b| a.name.cmp(&b.name));
67 descriptions
68 }
69
70 pub async fn call(&self, name: &str, input: Value) -> Result<Value> {
71 let tool = self
72 .tools
73 .get(name)
74 .ok_or_else(|| AgnoError::ToolNotFound(name.to_string()))?;
75 tool.call(input)
76 .await
77 .map_err(|source| AgnoError::ToolInvocation {
78 name: name.to_string(),
79 source: Box::new(source),
80 })
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    /// Test tool that returns its input unchanged and advertises a parameter schema.
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes whatever input is provided"
        }

        fn parameters(&self) -> Option<Value> {
            Some(serde_json::json!({
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            }))
        }

        async fn call(&self, input: Value) -> Result<Value> {
            Ok(input)
        }
    }

    /// Test tool whose invocation always fails, used to check that the registry
    /// wraps tool errors.
    struct Failing;

    #[async_trait]
    impl Tool for Failing {
        fn name(&self) -> &str {
            "failing"
        }

        fn description(&self) -> &str {
            "Always returns an error"
        }

        async fn call(&self, _input: Value) -> Result<Value> {
            Err(AgnoError::ToolNotFound("inner".to_string()))
        }
    }

    #[tokio::test]
    async fn describes_registered_tools_with_parameters() {
        let mut registry = ToolRegistry::new();
        registry.register(Echo);

        let descriptions = registry.describe();
        assert_eq!(descriptions.len(), 1);
        let desc = &descriptions[0];
        assert_eq!(desc.name, "echo");
        let params = desc
            .parameters
            .as_ref()
            .expect("echo declares a parameter schema");
        assert!(params.get("properties").is_some());

        let output = registry
            .call("echo", serde_json::json!({"text":"hi"}))
            .await
            .unwrap();
        assert_eq!(output["text"], "hi");
    }

    #[test]
    fn describes_tools_in_deterministic_order() {
        struct Second;

        #[async_trait]
        impl Tool for Second {
            fn name(&self) -> &str {
                "second"
            }

            fn description(&self) -> &str {
                "Second tool"
            }

            async fn call(&self, input: Value) -> Result<Value> {
                Ok(input)
            }
        }

        // Register out of alphabetical order so the assertion checks sorting
        // rather than insertion order.
        let mut registry = ToolRegistry::new();
        registry.register(Second);
        registry.register(Echo);

        let names: Vec<String> = registry.describe().into_iter().map(|d| d.name).collect();
        assert_eq!(names, vec!["echo", "second"]);
    }

    #[tokio::test]
    async fn calling_unknown_tool_reports_not_found() {
        let registry = ToolRegistry::new();

        let err = registry
            .call("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, AgnoError::ToolNotFound(ref name) if name == "missing"));
    }

    #[tokio::test]
    async fn tool_errors_are_wrapped_with_tool_name() {
        let mut registry = ToolRegistry::new();
        registry.register(Failing);

        let err = registry
            .call("failing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(matches!(err, AgnoError::ToolInvocation { ref name, .. } if name == "failing"));
    }
}