1use crate::{
4 ExecutionContext, FromToolCall, ProviderSchema, Tool, ToolCall, ToolError, ToolExecutor,
5 ToolMetadata, ToolResult, ToolSchema,
6};
7use dashmap::DashMap;
8use serde_json::Value;
9use std::collections::HashSet;
10use std::sync::Arc;
11
/// Registry of type-erased tools, indexed by name, tag, and category.
///
/// All maps are `DashMap`s, so every registry operation takes `&self`.
pub struct ToolRegistry {
    /// Tool name -> type-erased tool handle.
    tools: DashMap<String, Arc<dyn DynamicTool>>,
    /// Tag -> names of the registered tools carrying that tag.
    tags: DashMap<String, HashSet<String>>,
    /// Category -> names of the registered tools in that category.
    categories: DashMap<String, HashSet<String>>,
}
18
19impl ToolRegistry {
20 pub fn new() -> Self {
21 Self {
22 tools: DashMap::new(),
23 tags: DashMap::new(),
24 categories: DashMap::new(),
25 }
26 }
27
28 pub fn register<T>(&self, tool: T) -> ToolResult<()>
29 where
30 T: Tool + FromToolCall + ToolExecutor + 'static,
31 T::Output: serde::Serialize,
32 {
33 let wrapper = DynamicToolWrapper::new(tool);
34 let metadata = wrapper.metadata();
35 let name = metadata.name.clone();
36
37 if self.tools.contains_key(&name) {
38 return Err(ToolError::execution_failed(format!(
39 "Tool '{}' already registered",
40 name
41 )));
42 }
43
44 for tag in &metadata.tags {
45 self.tags
46 .entry(tag.clone())
47 .or_insert_with(HashSet::new)
48 .insert(name.clone());
49 }
50
51 if let Some(category) = &metadata.category {
52 self.categories
53 .entry(category.clone())
54 .or_insert_with(HashSet::new)
55 .insert(name.clone());
56 }
57
58 self.tools.insert(name, Arc::new(wrapper));
59 Ok(())
60 }
61
62 pub fn get(&self, name: &str) -> Option<Arc<dyn DynamicTool>> {
63 self.tools.get(name).map(|entry| Arc::clone(entry.value()))
64 }
65
66 pub fn list_tools(&self) -> Vec<String> {
67 self.tools.iter().map(|entry| entry.key().clone()).collect()
68 }
69
70 pub fn list_metadata(&self) -> Vec<ToolMetadata> {
71 self.tools
72 .iter()
73 .map(|entry| entry.value().metadata())
74 .collect()
75 }
76
77 pub fn find_by_tag(&self, tag: &str) -> Vec<String> {
78 self.tags
79 .get(tag)
80 .map(|entry| entry.value().iter().cloned().collect())
81 .unwrap_or_default()
82 }
83
84 pub fn find_by_category(&self, category: &str) -> Vec<String> {
85 self.categories
86 .get(category)
87 .map(|entry| entry.value().iter().cloned().collect())
88 .unwrap_or_default()
89 }
90
91 pub fn export_schemas(&self, provider: Provider) -> Vec<Value> {
92 self.tools
93 .iter()
94 .map(|entry| {
95 let schema = entry.value().schema();
96 match provider {
97 Provider::OpenAI => schema.to_openai_schema(),
98 Provider::Anthropic => schema.to_anthropic_schema(),
99 Provider::Gemini => schema.to_gemini_schema(),
100 Provider::Generic => schema.to_json_schema(),
101 }
102 })
103 .collect()
104 }
105
106 pub fn len(&self) -> usize {
107 self.tools.len()
108 }
109
110 pub fn is_empty(&self) -> bool {
111 self.tools.is_empty()
112 }
113}
114
115impl Default for ToolRegistry {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
/// Target schema format for [`ToolRegistry::export_schemas`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Provider {
    /// Exported via `ToolSchema::to_openai_schema`.
    OpenAI,
    /// Exported via `ToolSchema::to_anthropic_schema`.
    Anthropic,
    /// Exported via `ToolSchema::to_gemini_schema`.
    Gemini,
    /// Exported via `ToolSchema::to_json_schema`.
    Generic,
}
129
/// Type-erased interface over a registered tool.
///
/// Lets tools of different concrete types be stored together as
/// `Arc<dyn DynamicTool>` inside [`ToolRegistry`]. The `Send + Sync` supertraits
/// let those shared handles be used from multiple threads.
pub trait DynamicTool: Send + Sync {
    /// Returns the tool's metadata (the registry reads its name, tags, and category).
    fn metadata(&self) -> ToolMetadata;
    /// Returns the tool's provider-neutral schema, which
    /// [`ToolRegistry::export_schemas`] converts to a provider-specific format.
    fn schema(&self) -> ToolSchema;
    /// Executes the tool with untyped JSON `args` under `ctx`.
    ///
    /// Returns a boxed, `Send` future resolving to a JSON `Value` result. A boxed
    /// future is used instead of `async fn` so the trait stays usable as a trait object.
    fn execute_dynamic<'a>(
        &'a self,
        args: Value,
        ctx: &'a ExecutionContext,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult<Value>> + Send + 'a>>;
}
140
/// Adapter that exposes a concrete tool `T` through the [`DynamicTool`] trait.
struct DynamicToolWrapper<T> {
    /// The wrapped tool, held behind an `Arc`.
    tool: Arc<T>,
}
145
146impl<T> DynamicToolWrapper<T> {
147 fn new(tool: T) -> Self {
148 Self {
149 tool: Arc::new(tool),
150 }
151 }
152}
153
154impl<T> DynamicTool for DynamicToolWrapper<T>
155where
156 T: Tool + FromToolCall + ToolExecutor + Send + Sync + 'static,
157 T::Output: serde::Serialize,
158{
159 fn metadata(&self) -> ToolMetadata {
160 self.tool.metadata()
161 }
162
163 fn schema(&self) -> ToolSchema {
164 self.tool.schema()
165 }
166
167 fn execute_dynamic<'a>(
168 &'a self,
169 args: Value,
170 ctx: &'a ExecutionContext,
171 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult<Value>> + Send + 'a>> {
172 Box::pin(async move {
173 let tool_call = ToolCall::new(self.tool.name(), args);
174 let instance = T::from_tool_call(&tool_call)?;
175 let result = instance.execute_tool(ctx).await?;
176 serde_json::to_value(&result).map_err(|e| ToolError::Serialization(e))
177 })
178 }
179}