spec_ai_core/tools/
mod.rs

1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13    AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14    FileWriteTool, GenerateCodeTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool,
15    SearchTool, ShellTool,
16};
17
18#[cfg(feature = "api")]
19use self::builtin::WebSearchTool;
20
21#[cfg(feature = "web-scraping")]
22use self::builtin::WebScraperTool;
23use crate::agent::model::ModelProvider;
24use crate::embeddings::EmbeddingsClient;
25use crate::persistence::Persistence;
26
27pub use plugin_adapter::PluginToolAdapter;
28
29#[cfg(feature = "openai")]
30use async_openai::types::ChatCompletionTool;
31
32/// Result of tool execution
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ToolResult {
35    /// Whether execution succeeded
36    pub success: bool,
37    /// Output from the tool
38    pub output: String,
39    /// Error message if execution failed
40    pub error: Option<String>,
41}
42
43impl ToolResult {
44    /// Create a successful result
45    pub fn success(output: impl Into<String>) -> Self {
46        Self {
47            success: true,
48            output: output.into(),
49            error: None,
50        }
51    }
52
53    /// Create a failure result
54    pub fn failure(error: impl Into<String>) -> Self {
55        Self {
56            success: false,
57            output: String::new(),
58            error: Some(error.into()),
59        }
60    }
61}
62
63/// Trait for all tools that can be executed by the agent
64#[async_trait]
65pub trait Tool: Send + Sync {
66    /// Unique name of the tool
67    fn name(&self) -> &str;
68
69    /// Human-readable description of what the tool does
70    fn description(&self) -> &str;
71
72    /// JSON Schema describing the tool's parameters
73    fn parameters(&self) -> Value;
74
75    /// Execute the tool with the given arguments
76    async fn execute(&self, args: Value) -> Result<ToolResult>;
77}
78
79/// Registry for managing and executing tools
80pub struct ToolRegistry {
81    tools: HashMap<String, Arc<dyn Tool>>,
82}
83
84impl ToolRegistry {
85    /// Create a new empty tool registry
86    pub fn new() -> Self {
87        Self {
88            tools: HashMap::new(),
89        }
90    }
91
92    /// Create a registry populated with all built-in tools.
93    ///
94    /// Tools that require persistence (e.g., `graph`) are only registered when
95    /// an [`Arc<Persistence>`] is provided.
96    #[allow(unused_variables)]
97    pub fn with_builtin_tools(
98        persistence: Option<Arc<Persistence>>,
99        embeddings: Option<EmbeddingsClient>,
100        code_model_provider: Option<Arc<dyn ModelProvider>>,
101    ) -> Self {
102        let mut registry = Self::new();
103
104        // Register all built-in tools
105        registry.register(Arc::new(EchoTool::new()));
106        registry.register(Arc::new(MathTool::new()));
107        registry.register(Arc::new(FileReadTool::new()));
108        registry.register(Arc::new(FileExtractTool::new()));
109        registry.register(Arc::new(FileWriteTool::new()));
110        registry.register(Arc::new(PromptUserTool::new()));
111        registry.register(Arc::new(SearchTool::new()));
112        registry.register(Arc::new(GrepTool::new()));
113        registry.register(Arc::new(RgTool::new()));
114        registry.register(Arc::new(CodeSearchTool::new()));
115        registry.register(Arc::new(BashTool::new()));
116        registry.register(Arc::new(ShellTool::new()));
117        if let Some(provider) = code_model_provider {
118            registry.register(Arc::new(GenerateCodeTool::new(provider)));
119        }
120
121        // Register web search if api feature is enabled
122        #[cfg(feature = "api")]
123        registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
124
125        // Register web scraper if feature is enabled
126        #[cfg(feature = "web-scraping")]
127        registry.register(Arc::new(WebScraperTool::new()));
128
129        if let Some(persistence) = persistence {
130            registry.register(Arc::new(GraphTool::new(persistence.clone())));
131            registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
132                persistence,
133            )));
134        } else {
135            registry.register(Arc::new(AudioTranscriptionTool::new()));
136        }
137
138        tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
139        for name in registry.tools.keys() {
140            tracing::debug!("  - Tool: {}", name);
141        }
142
143        registry
144    }
145
146    /// Register a tool in the registry
147    pub fn register(&mut self, tool: Arc<dyn Tool>) {
148        let name = tool.name().to_string();
149        self.tools.insert(name, tool);
150    }
151
152    /// Get a tool by name
153    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
154        self.tools.get(name).cloned()
155    }
156
157    /// List all registered tool names
158    pub fn list(&self) -> Vec<&str> {
159        self.tools.keys().map(|s| s.as_str()).collect()
160    }
161
162    /// Check if a tool is registered
163    pub fn has(&self, name: &str) -> bool {
164        self.tools.contains_key(name)
165    }
166
167    /// Execute a tool by name with the given arguments
168    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
169        let tool = self
170            .get(name)
171            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
172
173        debug!("Executing tool '{}'", name);
174        let result = tool.execute(args).await;
175        match &result {
176            Ok(res) => {
177                debug!(
178                    "Tool '{}' completed: success={}, error={:?}",
179                    name, res.success, res.error
180                );
181            }
182            Err(err) => {
183                debug!("Tool '{}' failed to execute: {}", name, err);
184            }
185        }
186        result
187    }
188
189    /// Get the number of registered tools
190    pub fn len(&self) -> usize {
191        self.tools.len()
192    }
193
194    /// Check if the registry is empty
195    pub fn is_empty(&self) -> bool {
196        self.tools.is_empty()
197    }
198
199    /// Load plugins from a directory and register their tools
200    ///
201    /// # Arguments
202    /// * `dir` - Directory containing plugin libraries
203    /// * `allow_override` - Whether plugins can override built-in tools
204    ///
205    /// # Returns
206    /// Statistics about the loading process
207    pub fn load_plugins(
208        &mut self,
209        dir: &std::path::Path,
210        allow_override: bool,
211    ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
212        use spec_ai_plugin::{expand_tilde, PluginLoader};
213
214        let expanded_dir = expand_tilde(dir);
215
216        let mut loader = PluginLoader::new();
217        let stats = loader.load_directory(&expanded_dir)?;
218
219        // Register tools from plugins
220        for (tool_ref, plugin_name) in loader.all_tools() {
221            let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
222                Ok(a) => a,
223                Err(e) => {
224                    tracing::warn!(
225                        "Failed to create adapter for tool from {}: {}",
226                        plugin_name,
227                        e
228                    );
229                    continue;
230                }
231            };
232
233            let tool_name = adapter.name().to_string();
234
235            // Check for conflicts with built-in tools
236            if self.has(&tool_name) {
237                if allow_override {
238                    tracing::info!(
239                        "Plugin tool '{}' from '{}' overriding built-in tool",
240                        tool_name,
241                        plugin_name
242                    );
243                } else {
244                    tracing::warn!(
245                        "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
246                        tool_name,
247                        plugin_name
248                    );
249                    continue;
250                }
251            }
252
253            tracing::debug!(
254                "Registering plugin tool '{}' from '{}'",
255                tool_name,
256                plugin_name
257            );
258            self.register(Arc::new(adapter));
259        }
260
261        Ok(stats)
262    }
263
264    /// Convert all tools in the registry to OpenAI ChatCompletionTool format.
265    ///
266    /// Used by providers that support native function calling (OpenAI-compatible,
267    /// including MLX and LM Studio when enabled).
268    #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
269    pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
270        use crate::agent::function_calling::tool_to_openai_function;
271
272        self.tools
273            .values()
274            .map(|tool| {
275                tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
276            })
277            .collect()
278    }
279}
280
281impl Default for ToolRegistry {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool with a fixed name and output, used to exercise the registry.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Synchronous registry operations need no async runtime, so they use plain
    // `#[test]`; only tests that await `execute` use `#[tokio::test]`.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_register_same_name_replaces_existing() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        assert_eq!(registry.len(), 1);
        assert_eq!(registry.list(), vec!["dummy"]);
    }

    #[test]
    fn test_new_and_default_are_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.has("dummy"));
        assert!(registry.get("dummy").is_none());

        assert!(ToolRegistry::default().is_empty());
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .expect_err("executing an unregistered tool must fail");
        assert!(err.to_string().contains("nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}