Skip to main content

spec_ai/spec_ai_core/tools/mod.rs

1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13    AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14    FileWriteTool, GenerateCodeTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool,
15    SearchTool, ShellTool,
16};
17
18#[cfg(feature = "api")]
19use self::builtin::WebSearchTool;
20
21#[cfg(feature = "web-scraping")]
22use self::builtin::WebScraperTool;
23use crate::spec_ai_core::agent::model::ModelProvider;
24use crate::spec_ai_core::agent::safety::RunSafetyBudget;
25use crate::spec_ai_core::embeddings::EmbeddingsClient;
26use crate::spec_ai_core::persistence::Persistence;
27
28pub use plugin_adapter::PluginToolAdapter;
29
30#[cfg(feature = "openai")]
31use async_openai::types::chat::ChatCompletionTool;
32
33/// Result of tool execution
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ToolResult {
36    /// Whether execution succeeded
37    pub success: bool,
38    /// Output from the tool
39    pub output: String,
40    /// Error message if execution failed
41    pub error: Option<String>,
42}
43
44/// Per-invocation context passed to tools that need run-level guardrails.
45#[derive(Clone, Default)]
46pub struct ToolExecutionContext {
47    pub safety: Option<RunSafetyBudget>,
48    pub delegation_depth: usize,
49}
50
51impl ToolResult {
52    /// Create a successful result
53    pub fn success(output: impl Into<String>) -> Self {
54        Self {
55            success: true,
56            output: output.into(),
57            error: None,
58        }
59    }
60
61    /// Create a failure result
62    pub fn failure(error: impl Into<String>) -> Self {
63        Self {
64            success: false,
65            output: String::new(),
66            error: Some(error.into()),
67        }
68    }
69}
70
71/// Trait for all tools that can be executed by the agent
72#[async_trait]
73pub trait Tool: Send + Sync {
74    /// Unique name of the tool
75    fn name(&self) -> &str;
76
77    /// Human-readable description of what the tool does
78    fn description(&self) -> &str;
79
80    /// JSON Schema describing the tool's parameters
81    fn parameters(&self) -> Value;
82
83    /// Execute the tool with the given arguments
84    async fn execute(&self, args: Value) -> Result<ToolResult>;
85
86    /// Execute with run context. Tools that do not need context use `execute`.
87    async fn execute_with_context(
88        &self,
89        args: Value,
90        _context: ToolExecutionContext,
91    ) -> Result<ToolResult> {
92        self.execute(args).await
93    }
94}
95
96/// Registry for managing and executing tools
97pub struct ToolRegistry {
98    tools: HashMap<String, Arc<dyn Tool>>,
99}
100
101impl ToolRegistry {
102    /// Create a new empty tool registry
103    pub fn new() -> Self {
104        Self {
105            tools: HashMap::new(),
106        }
107    }
108
109    /// Create a registry populated with all built-in tools.
110    ///
111    /// Tools that require persistence (e.g., `graph`) are only registered when
112    /// an [`Arc<Persistence>`] is provided.
113    #[allow(unused_variables)]
114    pub fn with_builtin_tools(
115        persistence: Option<Arc<Persistence>>,
116        embeddings: Option<EmbeddingsClient>,
117        code_model_provider: Option<Arc<dyn ModelProvider>>,
118    ) -> Self {
119        let mut registry = Self::new();
120
121        // Register all built-in tools
122        registry.register(Arc::new(EchoTool::new()));
123        registry.register(Arc::new(MathTool::new()));
124        registry.register(Arc::new(FileReadTool::new()));
125        registry.register(Arc::new(FileExtractTool::new()));
126        registry.register(Arc::new(FileWriteTool::new()));
127        registry.register(Arc::new(PromptUserTool::new()));
128        registry.register(Arc::new(SearchTool::new()));
129        registry.register(Arc::new(GrepTool::new()));
130        registry.register(Arc::new(RgTool::new()));
131        registry.register(Arc::new(CodeSearchTool::new()));
132        registry.register(Arc::new(BashTool::new()));
133        registry.register(Arc::new(ShellTool::new()));
134        if let Some(provider) = code_model_provider {
135            registry.register(Arc::new(GenerateCodeTool::new(provider)));
136        }
137
138        // Register web search if api feature is enabled
139        #[cfg(feature = "api")]
140        registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
141
142        // Register web scraper if feature is enabled
143        #[cfg(feature = "web-scraping")]
144        registry.register(Arc::new(WebScraperTool::new()));
145
146        if let Some(persistence) = persistence {
147            registry.register(Arc::new(GraphTool::new(persistence.clone())));
148            registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
149                persistence,
150            )));
151        } else {
152            registry.register(Arc::new(AudioTranscriptionTool::new()));
153        }
154
155        tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
156        for name in registry.tools.keys() {
157            tracing::debug!("  - Tool: {}", name);
158        }
159
160        registry
161    }
162
163    /// Register a tool in the registry
164    pub fn register(&mut self, tool: Arc<dyn Tool>) {
165        let name = tool.name().to_string();
166        self.tools.insert(name, tool);
167    }
168
169    /// Get a tool by name
170    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
171        self.tools.get(name).cloned()
172    }
173
174    /// List all registered tool names
175    pub fn list(&self) -> Vec<&str> {
176        self.tools.keys().map(|s| s.as_str()).collect()
177    }
178
179    /// Check if a tool is registered
180    pub fn has(&self, name: &str) -> bool {
181        self.tools.contains_key(name)
182    }
183
184    /// Execute a tool by name with the given arguments
185    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
186        self.execute_with_context(name, args, ToolExecutionContext::default())
187            .await
188    }
189
190    /// Execute a tool by name with run context.
191    pub async fn execute_with_context(
192        &self,
193        name: &str,
194        args: Value,
195        context: ToolExecutionContext,
196    ) -> Result<ToolResult> {
197        let tool = self
198            .get(name)
199            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
200
201        debug!("Executing tool '{}'", name);
202        let result = tool.execute_with_context(args, context).await;
203        match &result {
204            Ok(res) => {
205                debug!(
206                    "Tool '{}' completed: success={}, error={:?}",
207                    name, res.success, res.error
208                );
209            }
210            Err(err) => {
211                debug!("Tool '{}' failed to execute: {}", name, err);
212            }
213        }
214        result
215    }
216
217    /// Get the number of registered tools
218    pub fn len(&self) -> usize {
219        self.tools.len()
220    }
221
222    /// Check if the registry is empty
223    pub fn is_empty(&self) -> bool {
224        self.tools.is_empty()
225    }
226
227    /// Load plugins from a directory and register their tools
228    ///
229    /// # Arguments
230    /// * `dir` - Directory containing plugin libraries
231    /// * `allow_override` - Whether plugins can override built-in tools
232    ///
233    /// # Returns
234    /// Statistics about the loading process
235    pub fn load_plugins(
236        &mut self,
237        dir: &std::path::Path,
238        allow_override: bool,
239    ) -> anyhow::Result<crate::spec_ai_plugin::LoadStats> {
240        use crate::spec_ai_plugin::{PluginLoader, expand_tilde};
241
242        let expanded_dir = expand_tilde(dir);
243
244        let mut loader = PluginLoader::new();
245        let stats = loader.load_directory(&expanded_dir)?;
246
247        // Register tools from plugins
248        for (tool_ref, plugin_name) in loader.all_tools() {
249            let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
250                Ok(a) => a,
251                Err(e) => {
252                    tracing::warn!(
253                        "Failed to create adapter for tool from {}: {}",
254                        plugin_name,
255                        e
256                    );
257                    continue;
258                }
259            };
260
261            let tool_name = adapter.name().to_string();
262
263            // Check for conflicts with built-in tools
264            if self.has(&tool_name) {
265                if allow_override {
266                    tracing::info!(
267                        "Plugin tool '{}' from '{}' overriding built-in tool",
268                        tool_name,
269                        plugin_name
270                    );
271                } else {
272                    tracing::warn!(
273                        "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
274                        tool_name,
275                        plugin_name
276                    );
277                    continue;
278                }
279            }
280
281            tracing::debug!(
282                "Registering plugin tool '{}' from '{}'",
283                tool_name,
284                plugin_name
285            );
286            self.register(Arc::new(adapter));
287        }
288
289        Ok(stats)
290    }
291
292    /// Convert all tools in the registry to OpenAI ChatCompletionTool format.
293    ///
294    /// Used by providers that support native function calling (OpenAI-compatible,
295    /// including MLX and LM Studio when enabled).
296    #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
297    pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
298        use crate::spec_ai_core::agent::function_calling::tool_to_openai_function;
299
300        self.tools
301            .values()
302            .map(|tool| {
303                tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
304            })
305            .collect()
306    }
307}
308
309impl Default for ToolRegistry {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that ignores its arguments and always succeeds.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert!(!registry.has("missing"));
        assert!(registry.get("missing").is_none());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_new_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));
        assert_eq!(registry.len(), 1);
        assert!(registry.has("dummy"));
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_with_context_defaults_to_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry
            .execute_with_context("dummy", Value::Null, ToolExecutionContext::default())
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}