spec_ai_core/tools/
mod.rs1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13 AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14 FileWriteTool, GraphTool, GrepTool, MathTool, PromptUserTool, SearchTool, ShellTool,
15};
16
17#[cfg(feature = "api")]
18use self::builtin::WebSearchTool;
19
20#[cfg(feature = "web-scraping")]
21use self::builtin::WebScraperTool;
22use crate::embeddings::EmbeddingsClient;
23use crate::persistence::Persistence;
24
25pub use plugin_adapter::PluginToolAdapter;
26
27#[cfg(feature = "openai")]
28use async_openai::types::ChatCompletionTool;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ToolResult {
33 pub success: bool,
35 pub output: String,
37 pub error: Option<String>,
39}
40
41impl ToolResult {
42 pub fn success(output: impl Into<String>) -> Self {
44 Self {
45 success: true,
46 output: output.into(),
47 error: None,
48 }
49 }
50
51 pub fn failure(error: impl Into<String>) -> Self {
53 Self {
54 success: false,
55 output: String::new(),
56 error: Some(error.into()),
57 }
58 }
59}
60
61#[async_trait]
63pub trait Tool: Send + Sync {
64 fn name(&self) -> &str;
66
67 fn description(&self) -> &str;
69
70 fn parameters(&self) -> Value;
72
73 async fn execute(&self, args: Value) -> Result<ToolResult>;
75}
76
77pub struct ToolRegistry {
79 tools: HashMap<String, Arc<dyn Tool>>,
80}
81
82impl ToolRegistry {
83 pub fn new() -> Self {
85 Self {
86 tools: HashMap::new(),
87 }
88 }
89
90 #[allow(unused_variables)]
95 pub fn with_builtin_tools(
96 persistence: Option<Arc<Persistence>>,
97 embeddings: Option<EmbeddingsClient>,
98 ) -> Self {
99 let mut registry = Self::new();
100
101 registry.register(Arc::new(EchoTool::new()));
103 registry.register(Arc::new(MathTool::new()));
104 registry.register(Arc::new(FileReadTool::new()));
105 registry.register(Arc::new(FileExtractTool::new()));
106 registry.register(Arc::new(FileWriteTool::new()));
107 registry.register(Arc::new(PromptUserTool::new()));
108 registry.register(Arc::new(SearchTool::new()));
109 registry.register(Arc::new(GrepTool::new()));
110 registry.register(Arc::new(CodeSearchTool::new()));
111 registry.register(Arc::new(BashTool::new()));
112 registry.register(Arc::new(ShellTool::new()));
113
114 #[cfg(feature = "api")]
116 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
117
118 #[cfg(feature = "web-scraping")]
120 registry.register(Arc::new(WebScraperTool::new()));
121
122 if let Some(persistence) = persistence {
123 registry.register(Arc::new(GraphTool::new(persistence.clone())));
124 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
125 persistence,
126 )));
127 } else {
128 registry.register(Arc::new(AudioTranscriptionTool::new()));
129 }
130
131 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
132 for name in registry.tools.keys() {
133 tracing::debug!(" - Tool: {}", name);
134 }
135
136 registry
137 }
138
139 pub fn register(&mut self, tool: Arc<dyn Tool>) {
141 let name = tool.name().to_string();
142 self.tools.insert(name, tool);
143 }
144
145 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
147 self.tools.get(name).cloned()
148 }
149
150 pub fn list(&self) -> Vec<&str> {
152 self.tools.keys().map(|s| s.as_str()).collect()
153 }
154
155 pub fn has(&self, name: &str) -> bool {
157 self.tools.contains_key(name)
158 }
159
160 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
162 let tool = self
163 .get(name)
164 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
165
166 debug!("Executing tool '{}'", name);
167 let result = tool.execute(args).await;
168 match &result {
169 Ok(res) => {
170 debug!(
171 "Tool '{}' completed: success={}, error={:?}",
172 name, res.success, res.error
173 );
174 }
175 Err(err) => {
176 debug!("Tool '{}' failed to execute: {}", name, err);
177 }
178 }
179 result
180 }
181
182 pub fn len(&self) -> usize {
184 self.tools.len()
185 }
186
187 pub fn is_empty(&self) -> bool {
189 self.tools.is_empty()
190 }
191
192 pub fn load_plugins(
201 &mut self,
202 dir: &std::path::Path,
203 allow_override: bool,
204 ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
205 use spec_ai_plugin::{expand_tilde, PluginLoader};
206
207 let expanded_dir = expand_tilde(dir);
208
209 let mut loader = PluginLoader::new();
210 let stats = loader.load_directory(&expanded_dir)?;
211
212 for (tool_ref, plugin_name) in loader.all_tools() {
214 let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
215 Ok(a) => a,
216 Err(e) => {
217 tracing::warn!(
218 "Failed to create adapter for tool from {}: {}",
219 plugin_name,
220 e
221 );
222 continue;
223 }
224 };
225
226 let tool_name = adapter.name().to_string();
227
228 if self.has(&tool_name) {
230 if allow_override {
231 tracing::info!(
232 "Plugin tool '{}' from '{}' overriding built-in tool",
233 tool_name,
234 plugin_name
235 );
236 } else {
237 tracing::warn!(
238 "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
239 tool_name,
240 plugin_name
241 );
242 continue;
243 }
244 }
245
246 tracing::debug!(
247 "Registering plugin tool '{}' from '{}'",
248 tool_name,
249 plugin_name
250 );
251 self.register(Arc::new(adapter));
252 }
253
254 Ok(stats)
255 }
256
257 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
262 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
263 use crate::agent::function_calling::tool_to_openai_function;
264
265 self.tools
266 .values()
267 .map(|tool| {
268 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
269 })
270 .collect()
271 }
272}
273
274impl Default for ToolRegistry {
275 fn default() -> Self {
276 Self::new()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise the registry without any I/O.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Purely synchronous tests use plain `#[test]`. Only the tests that await
    // `ToolRegistry::execute` need a Tokio runtime.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_register_same_name_replaces_existing() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ToolRegistry::default();

        assert!(registry.is_empty());
        assert!(registry.list().is_empty());
        assert!(!registry.has("dummy"));
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .expect_err("executing an unregistered tool must fail");

        assert!(err.to_string().contains("nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}