spec_ai_core/tools/
mod.rs

1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13    AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14    FileWriteTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool, SearchTool, ShellTool,
15};
16
17#[cfg(feature = "api")]
18use self::builtin::WebSearchTool;
19
20#[cfg(feature = "web-scraping")]
21use self::builtin::WebScraperTool;
22use crate::embeddings::EmbeddingsClient;
23use crate::persistence::Persistence;
24
25pub use plugin_adapter::PluginToolAdapter;
26
27#[cfg(feature = "openai")]
28use async_openai::types::ChatCompletionTool;
29
30/// Result of tool execution
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ToolResult {
33    /// Whether execution succeeded
34    pub success: bool,
35    /// Output from the tool
36    pub output: String,
37    /// Error message if execution failed
38    pub error: Option<String>,
39}
40
41impl ToolResult {
42    /// Create a successful result
43    pub fn success(output: impl Into<String>) -> Self {
44        Self {
45            success: true,
46            output: output.into(),
47            error: None,
48        }
49    }
50
51    /// Create a failure result
52    pub fn failure(error: impl Into<String>) -> Self {
53        Self {
54            success: false,
55            output: String::new(),
56            error: Some(error.into()),
57        }
58    }
59}
60
/// Interface implemented by every tool the agent can invoke.
///
/// The `Send + Sync` supertraits allow implementations to be shared as
/// `Arc<dyn Tool>`, which is the form in which [`ToolRegistry`] stores them.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name identifying the tool.
    ///
    /// [`ToolRegistry::register`] uses this value as the lookup key, so
    /// registering two tools that report the same name keeps only the one
    /// registered last.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does
    fn description(&self) -> &str;

    /// JSON Schema describing the tool's parameters, returned as a
    /// `serde_json` [`Value`]
    fn parameters(&self) -> Value;

    /// Run the tool with `args` and report the outcome.
    ///
    /// Implementations can signal a problem either by returning `Err` or by
    /// returning `Ok` with a [`ToolResult`] whose `success` is `false`; this
    /// trait does not prescribe which one to use.
    async fn execute(&self, args: Value) -> Result<ToolResult>;
}
76
/// Registry for managing and executing tools.
///
/// Tools are stored as shared `Arc<dyn Tool>` handles and looked up by the
/// name each tool reported when it was registered.
pub struct ToolRegistry {
    // Key: the string returned by `Tool::name()` at registration time.
    tools: HashMap<String, Arc<dyn Tool>>,
}
81
82impl ToolRegistry {
83    /// Create a new empty tool registry
84    pub fn new() -> Self {
85        Self {
86            tools: HashMap::new(),
87        }
88    }
89
90    /// Create a registry populated with all built-in tools.
91    ///
92    /// Tools that require persistence (e.g., `graph`) are only registered when
93    /// an [`Arc<Persistence>`] is provided.
94    #[allow(unused_variables)]
95    pub fn with_builtin_tools(
96        persistence: Option<Arc<Persistence>>,
97        embeddings: Option<EmbeddingsClient>,
98    ) -> Self {
99        let mut registry = Self::new();
100
101        // Register all built-in tools
102        registry.register(Arc::new(EchoTool::new()));
103        registry.register(Arc::new(MathTool::new()));
104        registry.register(Arc::new(FileReadTool::new()));
105        registry.register(Arc::new(FileExtractTool::new()));
106        registry.register(Arc::new(FileWriteTool::new()));
107        registry.register(Arc::new(PromptUserTool::new()));
108        registry.register(Arc::new(SearchTool::new()));
109        registry.register(Arc::new(GrepTool::new()));
110        registry.register(Arc::new(RgTool::new()));
111        registry.register(Arc::new(CodeSearchTool::new()));
112        registry.register(Arc::new(BashTool::new()));
113        registry.register(Arc::new(ShellTool::new()));
114
115        // Register web search if api feature is enabled
116        #[cfg(feature = "api")]
117        registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
118
119        // Register web scraper if feature is enabled
120        #[cfg(feature = "web-scraping")]
121        registry.register(Arc::new(WebScraperTool::new()));
122
123        if let Some(persistence) = persistence {
124            registry.register(Arc::new(GraphTool::new(persistence.clone())));
125            registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
126                persistence,
127            )));
128        } else {
129            registry.register(Arc::new(AudioTranscriptionTool::new()));
130        }
131
132        tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
133        for name in registry.tools.keys() {
134            tracing::debug!("  - Tool: {}", name);
135        }
136
137        registry
138    }
139
140    /// Register a tool in the registry
141    pub fn register(&mut self, tool: Arc<dyn Tool>) {
142        let name = tool.name().to_string();
143        self.tools.insert(name, tool);
144    }
145
146    /// Get a tool by name
147    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
148        self.tools.get(name).cloned()
149    }
150
151    /// List all registered tool names
152    pub fn list(&self) -> Vec<&str> {
153        self.tools.keys().map(|s| s.as_str()).collect()
154    }
155
156    /// Check if a tool is registered
157    pub fn has(&self, name: &str) -> bool {
158        self.tools.contains_key(name)
159    }
160
161    /// Execute a tool by name with the given arguments
162    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
163        let tool = self
164            .get(name)
165            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
166
167        debug!("Executing tool '{}'", name);
168        let result = tool.execute(args).await;
169        match &result {
170            Ok(res) => {
171                debug!(
172                    "Tool '{}' completed: success={}, error={:?}",
173                    name, res.success, res.error
174                );
175            }
176            Err(err) => {
177                debug!("Tool '{}' failed to execute: {}", name, err);
178            }
179        }
180        result
181    }
182
183    /// Get the number of registered tools
184    pub fn len(&self) -> usize {
185        self.tools.len()
186    }
187
188    /// Check if the registry is empty
189    pub fn is_empty(&self) -> bool {
190        self.tools.is_empty()
191    }
192
193    /// Load plugins from a directory and register their tools
194    ///
195    /// # Arguments
196    /// * `dir` - Directory containing plugin libraries
197    /// * `allow_override` - Whether plugins can override built-in tools
198    ///
199    /// # Returns
200    /// Statistics about the loading process
201    pub fn load_plugins(
202        &mut self,
203        dir: &std::path::Path,
204        allow_override: bool,
205    ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
206        use spec_ai_plugin::{expand_tilde, PluginLoader};
207
208        let expanded_dir = expand_tilde(dir);
209
210        let mut loader = PluginLoader::new();
211        let stats = loader.load_directory(&expanded_dir)?;
212
213        // Register tools from plugins
214        for (tool_ref, plugin_name) in loader.all_tools() {
215            let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
216                Ok(a) => a,
217                Err(e) => {
218                    tracing::warn!(
219                        "Failed to create adapter for tool from {}: {}",
220                        plugin_name,
221                        e
222                    );
223                    continue;
224                }
225            };
226
227            let tool_name = adapter.name().to_string();
228
229            // Check for conflicts with built-in tools
230            if self.has(&tool_name) {
231                if allow_override {
232                    tracing::info!(
233                        "Plugin tool '{}' from '{}' overriding built-in tool",
234                        tool_name,
235                        plugin_name
236                    );
237                } else {
238                    tracing::warn!(
239                        "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
240                        tool_name,
241                        plugin_name
242                    );
243                    continue;
244                }
245            }
246
247            tracing::debug!(
248                "Registering plugin tool '{}' from '{}'",
249                tool_name,
250                plugin_name
251            );
252            self.register(Arc::new(adapter));
253        }
254
255        Ok(stats)
256    }
257
258    /// Convert all tools in the registry to OpenAI ChatCompletionTool format.
259    ///
260    /// Used by providers that support native function calling (OpenAI-compatible,
261    /// including MLX and LM Studio when enabled).
262    #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
263    pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
264        use crate::agent::function_calling::tool_to_openai_function;
265
266        self.tools
267            .values()
268            .map(|tool| {
269                tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
270            })
271            .collect()
272    }
273}
274
275impl Default for ToolRegistry {
276    fn default() -> Self {
277        Self::new()
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Tool`] implementation that always succeeds with a fixed output.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    /// Registry containing only [`DummyTool`].
    fn registry_with_dummy() -> ToolRegistry {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry
    }

    // The tests below exercise synchronous APIs only, so they run as plain
    // `#[test]` functions instead of starting an async runtime.

    #[test]
    fn test_new_and_default_registries_are_empty() {
        for registry in [ToolRegistry::new(), ToolRegistry::default()] {
            assert!(registry.is_empty());
            assert_eq!(registry.len(), 0);
            assert!(registry.list().is_empty());
        }
    }

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        let tool = Arc::new(DummyTool);

        registry.register(tool.clone());

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_get_unknown_tool_returns_none() {
        let registry = registry_with_dummy();

        assert!(!registry.has("missing"));
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_register_same_name_replaces_existing() {
        let mut registry = registry_with_dummy();

        // Tools are keyed by name, so a second registration must not add an entry.
        registry.register(Arc::new(DummyTool));

        assert_eq!(registry.len(), 1);
        assert!(registry.has("dummy"));
    }

    #[test]
    fn test_list_tools() {
        let registry = registry_with_dummy();

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let registry = registry_with_dummy();

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();

        let result = registry.execute("nonexistent", Value::Null).await;
        assert!(result.is_err());

        let message = result.unwrap_err().to_string();
        assert!(message.contains("Tool not found"));
        assert!(message.contains("nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}