spec_ai_core/tools/mod.rs

1pub mod builtin;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9use tracing::debug;
10
11use self::builtin::{
12    AudioTranscriptionTool, BashTool, EchoTool, FileExtractTool, FileReadTool, FileWriteTool,
13    GraphTool, MathTool, PromptUserTool, SearchTool, ShellTool,
14};
15
16#[cfg(feature = "api")]
17use self::builtin::WebSearchTool;
18
19#[cfg(feature = "web-scraping")]
20use self::builtin::WebScraperTool;
21use crate::embeddings::EmbeddingsClient;
22use crate::persistence::Persistence;
23
24#[cfg(feature = "openai")]
25use async_openai::types::ChatCompletionTool;
26
27/// Result of tool execution
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ToolResult {
30    /// Whether execution succeeded
31    pub success: bool,
32    /// Output from the tool
33    pub output: String,
34    /// Error message if execution failed
35    pub error: Option<String>,
36}
37
38impl ToolResult {
39    /// Create a successful result
40    pub fn success(output: impl Into<String>) -> Self {
41        Self {
42            success: true,
43            output: output.into(),
44            error: None,
45        }
46    }
47
48    /// Create a failure result
49    pub fn failure(error: impl Into<String>) -> Self {
50        Self {
51            success: false,
52            output: String::new(),
53            error: Some(error.into()),
54        }
55    }
56}
57
58/// Trait for all tools that can be executed by the agent
59#[async_trait]
60pub trait Tool: Send + Sync {
61    /// Unique name of the tool
62    fn name(&self) -> &str;
63
64    /// Human-readable description of what the tool does
65    fn description(&self) -> &str;
66
67    /// JSON Schema describing the tool's parameters
68    fn parameters(&self) -> Value;
69
70    /// Execute the tool with the given arguments
71    async fn execute(&self, args: Value) -> Result<ToolResult>;
72}
73
74/// Registry for managing and executing tools
75pub struct ToolRegistry {
76    tools: HashMap<String, Arc<dyn Tool>>,
77}
78
79impl ToolRegistry {
80    /// Create a new empty tool registry
81    pub fn new() -> Self {
82        Self {
83            tools: HashMap::new(),
84        }
85    }
86
87    /// Create a registry populated with all built-in tools.
88    ///
89    /// Tools that require persistence (e.g., `graph`) are only registered when
90    /// an [`Arc<Persistence>`] is provided.
91    #[allow(unused_variables)]
92    pub fn with_builtin_tools(
93        persistence: Option<Arc<Persistence>>,
94        embeddings: Option<EmbeddingsClient>,
95    ) -> Self {
96        let mut registry = Self::new();
97
98        // Register all built-in tools
99        registry.register(Arc::new(EchoTool::new()));
100        registry.register(Arc::new(MathTool::new()));
101        registry.register(Arc::new(FileReadTool::new()));
102        registry.register(Arc::new(FileExtractTool::new()));
103        registry.register(Arc::new(FileWriteTool::new()));
104        registry.register(Arc::new(PromptUserTool::new()));
105        registry.register(Arc::new(SearchTool::new()));
106        registry.register(Arc::new(BashTool::new()));
107        registry.register(Arc::new(ShellTool::new()));
108
109        // Register web search if api feature is enabled
110        #[cfg(feature = "api")]
111        registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
112
113        // Register web scraper if feature is enabled
114        #[cfg(feature = "web-scraping")]
115        registry.register(Arc::new(WebScraperTool::new()));
116
117        if let Some(persistence) = persistence {
118            registry.register(Arc::new(GraphTool::new(persistence.clone())));
119            registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
120                persistence,
121            )));
122        } else {
123            registry.register(Arc::new(AudioTranscriptionTool::new()));
124        }
125
126        tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
127        for name in registry.tools.keys() {
128            tracing::debug!("  - Tool: {}", name);
129        }
130
131        registry
132    }
133
134    /// Register a tool in the registry
135    pub fn register(&mut self, tool: Arc<dyn Tool>) {
136        let name = tool.name().to_string();
137        self.tools.insert(name, tool);
138    }
139
140    /// Get a tool by name
141    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
142        self.tools.get(name).cloned()
143    }
144
145    /// List all registered tool names
146    pub fn list(&self) -> Vec<&str> {
147        self.tools.keys().map(|s| s.as_str()).collect()
148    }
149
150    /// Check if a tool is registered
151    pub fn has(&self, name: &str) -> bool {
152        self.tools.contains_key(name)
153    }
154
155    /// Execute a tool by name with the given arguments
156    pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
157        let tool = self
158            .get(name)
159            .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
160
161        debug!("Executing tool '{}'", name);
162        let result = tool.execute(args).await;
163        match &result {
164            Ok(res) => {
165                debug!(
166                    "Tool '{}' completed: success={}, error={:?}",
167                    name, res.success, res.error
168                );
169            }
170            Err(err) => {
171                debug!("Tool '{}' failed to execute: {}", name, err);
172            }
173        }
174        result
175    }
176
177    /// Get the number of registered tools
178    pub fn len(&self) -> usize {
179        self.tools.len()
180    }
181
182    /// Check if the registry is empty
183    pub fn is_empty(&self) -> bool {
184        self.tools.is_empty()
185    }
186
187    /// Convert all tools in the registry to OpenAI ChatCompletionTool format.
188    ///
189    /// Used by providers that support native function calling (OpenAI-compatible,
190    /// including MLX and LM Studio when enabled).
191    #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
192    pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
193        use crate::agent::function_calling::tool_to_openai_function;
194
195        self.tools
196            .values()
197            .map(|tool| {
198                tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
199            })
200            .collect()
201    }
202}
203
204impl Default for ToolRegistry {
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise the registry without touching built-in tools.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Synchronous checks use plain `#[test]`; only tests that `.await` need a Tokio runtime.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        let tool = Arc::new(DummyTool);

        registry.register(tool.clone());

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert!(registry.get("missing").is_none());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_register_same_name_replaces_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        assert_eq!(registry.len(), 1);
        assert!(registry.has("dummy"));
    }

    #[test]
    fn test_new_and_default_are_empty() {
        assert!(ToolRegistry::new().is_empty());
        assert!(ToolRegistry::default().is_empty());
        assert_eq!(ToolRegistry::default().len(), 0);
        assert!(ToolRegistry::new().list().is_empty());
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .expect_err("executing an unregistered tool must fail");
        assert!(err.to_string().contains("Tool not found: nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}