spec_ai_core/tools/
mod.rs1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13 AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14 FileWriteTool, GenerateCodeTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool,
15 SearchTool, ShellTool,
16};
17
18#[cfg(feature = "api")]
19use self::builtin::WebSearchTool;
20
21#[cfg(feature = "web-scraping")]
22use self::builtin::WebScraperTool;
23use crate::agent::model::ModelProvider;
24use crate::embeddings::EmbeddingsClient;
25use crate::persistence::Persistence;
26
27pub use plugin_adapter::PluginToolAdapter;
28
29#[cfg(feature = "openai")]
30use async_openai::types::ChatCompletionTool;
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ToolResult {
35 pub success: bool,
37 pub output: String,
39 pub error: Option<String>,
41}
42
43impl ToolResult {
44 pub fn success(output: impl Into<String>) -> Self {
46 Self {
47 success: true,
48 output: output.into(),
49 error: None,
50 }
51 }
52
53 pub fn failure(error: impl Into<String>) -> Self {
55 Self {
56 success: false,
57 output: String::new(),
58 error: Some(error.into()),
59 }
60 }
61}
62
63#[async_trait]
65pub trait Tool: Send + Sync {
66 fn name(&self) -> &str;
68
69 fn description(&self) -> &str;
71
72 fn parameters(&self) -> Value;
74
75 async fn execute(&self, args: Value) -> Result<ToolResult>;
77}
78
79pub struct ToolRegistry {
81 tools: HashMap<String, Arc<dyn Tool>>,
82}
83
84impl ToolRegistry {
85 pub fn new() -> Self {
87 Self {
88 tools: HashMap::new(),
89 }
90 }
91
92 #[allow(unused_variables)]
97 pub fn with_builtin_tools(
98 persistence: Option<Arc<Persistence>>,
99 embeddings: Option<EmbeddingsClient>,
100 code_model_provider: Option<Arc<dyn ModelProvider>>,
101 ) -> Self {
102 let mut registry = Self::new();
103
104 registry.register(Arc::new(EchoTool::new()));
106 registry.register(Arc::new(MathTool::new()));
107 registry.register(Arc::new(FileReadTool::new()));
108 registry.register(Arc::new(FileExtractTool::new()));
109 registry.register(Arc::new(FileWriteTool::new()));
110 registry.register(Arc::new(PromptUserTool::new()));
111 registry.register(Arc::new(SearchTool::new()));
112 registry.register(Arc::new(GrepTool::new()));
113 registry.register(Arc::new(RgTool::new()));
114 registry.register(Arc::new(CodeSearchTool::new()));
115 registry.register(Arc::new(BashTool::new()));
116 registry.register(Arc::new(ShellTool::new()));
117 if let Some(provider) = code_model_provider {
118 registry.register(Arc::new(GenerateCodeTool::new(provider)));
119 }
120
121 #[cfg(feature = "api")]
123 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
124
125 #[cfg(feature = "web-scraping")]
127 registry.register(Arc::new(WebScraperTool::new()));
128
129 if let Some(persistence) = persistence {
130 registry.register(Arc::new(GraphTool::new(persistence.clone())));
131 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
132 persistence,
133 )));
134 } else {
135 registry.register(Arc::new(AudioTranscriptionTool::new()));
136 }
137
138 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
139 for name in registry.tools.keys() {
140 tracing::debug!(" - Tool: {}", name);
141 }
142
143 registry
144 }
145
146 pub fn register(&mut self, tool: Arc<dyn Tool>) {
148 let name = tool.name().to_string();
149 self.tools.insert(name, tool);
150 }
151
152 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
154 self.tools.get(name).cloned()
155 }
156
157 pub fn list(&self) -> Vec<&str> {
159 self.tools.keys().map(|s| s.as_str()).collect()
160 }
161
162 pub fn has(&self, name: &str) -> bool {
164 self.tools.contains_key(name)
165 }
166
167 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
169 let tool = self
170 .get(name)
171 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
172
173 debug!("Executing tool '{}'", name);
174 let result = tool.execute(args).await;
175 match &result {
176 Ok(res) => {
177 debug!(
178 "Tool '{}' completed: success={}, error={:?}",
179 name, res.success, res.error
180 );
181 }
182 Err(err) => {
183 debug!("Tool '{}' failed to execute: {}", name, err);
184 }
185 }
186 result
187 }
188
189 pub fn len(&self) -> usize {
191 self.tools.len()
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.tools.is_empty()
197 }
198
199 pub fn load_plugins(
208 &mut self,
209 dir: &std::path::Path,
210 allow_override: bool,
211 ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
212 use spec_ai_plugin::{expand_tilde, PluginLoader};
213
214 let expanded_dir = expand_tilde(dir);
215
216 let mut loader = PluginLoader::new();
217 let stats = loader.load_directory(&expanded_dir)?;
218
219 for (tool_ref, plugin_name) in loader.all_tools() {
221 let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
222 Ok(a) => a,
223 Err(e) => {
224 tracing::warn!(
225 "Failed to create adapter for tool from {}: {}",
226 plugin_name,
227 e
228 );
229 continue;
230 }
231 };
232
233 let tool_name = adapter.name().to_string();
234
235 if self.has(&tool_name) {
237 if allow_override {
238 tracing::info!(
239 "Plugin tool '{}' from '{}' overriding built-in tool",
240 tool_name,
241 plugin_name
242 );
243 } else {
244 tracing::warn!(
245 "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
246 tool_name,
247 plugin_name
248 );
249 continue;
250 }
251 }
252
253 tracing::debug!(
254 "Registering plugin tool '{}' from '{}'",
255 tool_name,
256 plugin_name
257 );
258 self.register(Arc::new(adapter));
259 }
260
261 Ok(stats)
262 }
263
264 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
269 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
270 use crate::agent::function_calling::tool_to_openai_function;
271
272 self.tools
273 .values()
274 .map(|tool| {
275 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
276 })
277 .collect()
278 }
279}
280
281impl Default for ToolRegistry {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool named `dummy` that always succeeds with a fixed output.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    /// Second tool that deliberately reuses the `dummy` name, used to check
    /// that registering a duplicate name replaces the earlier entry.
    struct ReplacementTool;

    #[async_trait]
    impl Tool for ReplacementTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A replacement for the dummy tool"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("replacement output"))
        }
    }

    /// Builds a registry that contains only [`DummyTool`].
    fn registry_with_dummy() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry
    }

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        let tool = Arc::new(DummyTool);
        registry.register(tool.clone());

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        // Unknown names must not resolve.
        assert!(!registry.has("missing"));
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_list_tools() {
        let registry = registry_with_dummy();

        let names = registry.list();
        assert_eq!(names.len(), 1);
        assert!(names.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_register_replaces_tool_with_same_name() {
        let mut registry = registry_with_dummy();
        registry.register(Arc::new(ReplacementTool));

        // Same name, so the entry is replaced rather than added.
        assert_eq!(registry.len(), 1);
        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert_eq!(result.output, "replacement output");
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let registry = registry_with_dummy();

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let result = registry.execute("nonexistent", Value::Null).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}
368}