spec_ai_core/tools/
mod.rs1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13 AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14 FileWriteTool, GraphTool, MathTool, PromptUserTool, SearchTool, ShellTool,
15};
16
17#[cfg(feature = "api")]
18use self::builtin::WebSearchTool;
19
20#[cfg(feature = "web-scraping")]
21use self::builtin::WebScraperTool;
22use crate::embeddings::EmbeddingsClient;
23use crate::persistence::Persistence;
24
25pub use plugin_adapter::PluginToolAdapter;
26
27#[cfg(feature = "openai")]
28use async_openai::types::ChatCompletionTool;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ToolResult {
33 pub success: bool,
35 pub output: String,
37 pub error: Option<String>,
39}
40
41impl ToolResult {
42 pub fn success(output: impl Into<String>) -> Self {
44 Self {
45 success: true,
46 output: output.into(),
47 error: None,
48 }
49 }
50
51 pub fn failure(error: impl Into<String>) -> Self {
53 Self {
54 success: false,
55 output: String::new(),
56 error: Some(error.into()),
57 }
58 }
59}
60
61#[async_trait]
63pub trait Tool: Send + Sync {
64 fn name(&self) -> &str;
66
67 fn description(&self) -> &str;
69
70 fn parameters(&self) -> Value;
72
73 async fn execute(&self, args: Value) -> Result<ToolResult>;
75}
76
77pub struct ToolRegistry {
79 tools: HashMap<String, Arc<dyn Tool>>,
80}
81
82impl ToolRegistry {
83 pub fn new() -> Self {
85 Self {
86 tools: HashMap::new(),
87 }
88 }
89
90 #[allow(unused_variables)]
95 pub fn with_builtin_tools(
96 persistence: Option<Arc<Persistence>>,
97 embeddings: Option<EmbeddingsClient>,
98 ) -> Self {
99 let mut registry = Self::new();
100
101 registry.register(Arc::new(EchoTool::new()));
103 registry.register(Arc::new(MathTool::new()));
104 registry.register(Arc::new(FileReadTool::new()));
105 registry.register(Arc::new(FileExtractTool::new()));
106 registry.register(Arc::new(FileWriteTool::new()));
107 registry.register(Arc::new(PromptUserTool::new()));
108 registry.register(Arc::new(SearchTool::new()));
109 registry.register(Arc::new(CodeSearchTool::new()));
110 registry.register(Arc::new(BashTool::new()));
111 registry.register(Arc::new(ShellTool::new()));
112
113 #[cfg(feature = "api")]
115 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
116
117 #[cfg(feature = "web-scraping")]
119 registry.register(Arc::new(WebScraperTool::new()));
120
121 if let Some(persistence) = persistence {
122 registry.register(Arc::new(GraphTool::new(persistence.clone())));
123 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
124 persistence,
125 )));
126 } else {
127 registry.register(Arc::new(AudioTranscriptionTool::new()));
128 }
129
130 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
131 for name in registry.tools.keys() {
132 tracing::debug!(" - Tool: {}", name);
133 }
134
135 registry
136 }
137
138 pub fn register(&mut self, tool: Arc<dyn Tool>) {
140 let name = tool.name().to_string();
141 self.tools.insert(name, tool);
142 }
143
144 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
146 self.tools.get(name).cloned()
147 }
148
149 pub fn list(&self) -> Vec<&str> {
151 self.tools.keys().map(|s| s.as_str()).collect()
152 }
153
154 pub fn has(&self, name: &str) -> bool {
156 self.tools.contains_key(name)
157 }
158
159 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
161 let tool = self
162 .get(name)
163 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
164
165 debug!("Executing tool '{}'", name);
166 let result = tool.execute(args).await;
167 match &result {
168 Ok(res) => {
169 debug!(
170 "Tool '{}' completed: success={}, error={:?}",
171 name, res.success, res.error
172 );
173 }
174 Err(err) => {
175 debug!("Tool '{}' failed to execute: {}", name, err);
176 }
177 }
178 result
179 }
180
181 pub fn len(&self) -> usize {
183 self.tools.len()
184 }
185
186 pub fn is_empty(&self) -> bool {
188 self.tools.is_empty()
189 }
190
191 pub fn load_plugins(
200 &mut self,
201 dir: &std::path::Path,
202 allow_override: bool,
203 ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
204 use spec_ai_plugin::{expand_tilde, PluginLoader};
205
206 let expanded_dir = expand_tilde(dir);
207
208 let mut loader = PluginLoader::new();
209 let stats = loader.load_directory(&expanded_dir)?;
210
211 for (tool_ref, plugin_name) in loader.all_tools() {
213 let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
214 Ok(a) => a,
215 Err(e) => {
216 tracing::warn!(
217 "Failed to create adapter for tool from {}: {}",
218 plugin_name,
219 e
220 );
221 continue;
222 }
223 };
224
225 let tool_name = adapter.name().to_string();
226
227 if self.has(&tool_name) {
229 if allow_override {
230 tracing::info!(
231 "Plugin tool '{}' from '{}' overriding built-in tool",
232 tool_name,
233 plugin_name
234 );
235 } else {
236 tracing::warn!(
237 "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
238 tool_name,
239 plugin_name
240 );
241 continue;
242 }
243 }
244
245 tracing::debug!(
246 "Registering plugin tool '{}' from '{}'",
247 tool_name,
248 plugin_name
249 );
250 self.register(Arc::new(adapter));
251 }
252
253 Ok(stats)
254 }
255
256 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
261 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
262 use crate::agent::function_calling::tool_to_openai_function;
263
264 self.tools
265 .values()
266 .map(|tool| {
267 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
268 })
269 .collect()
270 }
271}
272
273impl Default for ToolRegistry {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool named "dummy" that always succeeds with "dummy output".
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    /// Tool with a caller-chosen name and output.
    ///
    /// Used to exercise name collisions and multi-tool registries.
    struct LabeledTool {
        name: &'static str,
        output: &'static str,
    }

    #[async_trait]
    impl Tool for LabeledTool {
        fn name(&self) -> &str {
            self.name
        }

        fn description(&self) -> &str {
            "A configurable tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success(self.output))
        }
    }

    // Tests that never await use a plain `#[test]`. Only those calling the async
    // `execute` need a Tokio runtime.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        let tool = Arc::new(DummyTool);

        registry.register(tool.clone());

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[test]
    fn test_list_multiple_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(LabeledTool {
            name: "other",
            output: "",
        }));

        // Sort locally so the assertion does not depend on the registry's ordering.
        let mut names = registry.list();
        names.sort_unstable();
        assert_eq!(names, vec!["dummy", "other"]);
    }

    #[test]
    fn test_new_and_default_are_empty() {
        let fresh = ToolRegistry::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert!(!fresh.has("dummy"));
        assert!(fresh.get("dummy").is_none());
        assert!(fresh.list().is_empty());

        let defaulted = ToolRegistry::default();
        assert!(defaulted.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
    }

    #[tokio::test]
    async fn test_register_replaces_existing_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(LabeledTool {
            name: "dummy",
            output: "replacement output",
        }));

        assert_eq!(registry.len(), 1);
        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert_eq!(result.output, "replacement output");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Tool not found: nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}