spec_ai_core/tools/
mod.rs1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13 AudioTranscriptionTool, BashTool, EchoTool, FileExtractTool, FileReadTool, FileWriteTool,
14 GraphTool, MathTool, PromptUserTool, SearchTool, ShellTool,
15};
16
17#[cfg(feature = "api")]
18use self::builtin::WebSearchTool;
19
20#[cfg(feature = "web-scraping")]
21use self::builtin::WebScraperTool;
22use crate::embeddings::EmbeddingsClient;
23use crate::persistence::Persistence;
24
25pub use plugin_adapter::PluginToolAdapter;
26
27#[cfg(feature = "openai")]
28use async_openai::types::ChatCompletionTool;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ToolResult {
33 pub success: bool,
35 pub output: String,
37 pub error: Option<String>,
39}
40
41impl ToolResult {
42 pub fn success(output: impl Into<String>) -> Self {
44 Self {
45 success: true,
46 output: output.into(),
47 error: None,
48 }
49 }
50
51 pub fn failure(error: impl Into<String>) -> Self {
53 Self {
54 success: false,
55 output: String::new(),
56 error: Some(error.into()),
57 }
58 }
59}
60
/// A capability that can be registered in a [`ToolRegistry`] and invoked by name.
///
/// Implementors must be `Send + Sync`; the registry stores them as `Arc<dyn Tool>`.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name the tool is registered and looked up under in a [`ToolRegistry`].
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does; exported alongside the name
    /// when building OpenAI-style tool definitions.
    fn description(&self) -> &str;

    /// JSON value describing the accepted arguments; passed as the function parameters
    /// when building OpenAI-style tool definitions.
    fn parameters(&self) -> Value;

    /// Runs the tool with the given JSON arguments.
    ///
    /// Returns a [`ToolResult`] describing the outcome; an `Err` is logged by
    /// [`ToolRegistry::execute`] as a failure to execute the tool.
    async fn execute(&self, args: Value) -> Result<ToolResult>;
}
76
77pub struct ToolRegistry {
79 tools: HashMap<String, Arc<dyn Tool>>,
80}
81
82impl ToolRegistry {
83 pub fn new() -> Self {
85 Self {
86 tools: HashMap::new(),
87 }
88 }
89
90 #[allow(unused_variables)]
95 pub fn with_builtin_tools(
96 persistence: Option<Arc<Persistence>>,
97 embeddings: Option<EmbeddingsClient>,
98 ) -> Self {
99 let mut registry = Self::new();
100
101 registry.register(Arc::new(EchoTool::new()));
103 registry.register(Arc::new(MathTool::new()));
104 registry.register(Arc::new(FileReadTool::new()));
105 registry.register(Arc::new(FileExtractTool::new()));
106 registry.register(Arc::new(FileWriteTool::new()));
107 registry.register(Arc::new(PromptUserTool::new()));
108 registry.register(Arc::new(SearchTool::new()));
109 registry.register(Arc::new(BashTool::new()));
110 registry.register(Arc::new(ShellTool::new()));
111
112 #[cfg(feature = "api")]
114 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
115
116 #[cfg(feature = "web-scraping")]
118 registry.register(Arc::new(WebScraperTool::new()));
119
120 if let Some(persistence) = persistence {
121 registry.register(Arc::new(GraphTool::new(persistence.clone())));
122 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
123 persistence,
124 )));
125 } else {
126 registry.register(Arc::new(AudioTranscriptionTool::new()));
127 }
128
129 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
130 for name in registry.tools.keys() {
131 tracing::debug!(" - Tool: {}", name);
132 }
133
134 registry
135 }
136
137 pub fn register(&mut self, tool: Arc<dyn Tool>) {
139 let name = tool.name().to_string();
140 self.tools.insert(name, tool);
141 }
142
143 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
145 self.tools.get(name).cloned()
146 }
147
148 pub fn list(&self) -> Vec<&str> {
150 self.tools.keys().map(|s| s.as_str()).collect()
151 }
152
153 pub fn has(&self, name: &str) -> bool {
155 self.tools.contains_key(name)
156 }
157
158 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
160 let tool = self
161 .get(name)
162 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
163
164 debug!("Executing tool '{}'", name);
165 let result = tool.execute(args).await;
166 match &result {
167 Ok(res) => {
168 debug!(
169 "Tool '{}' completed: success={}, error={:?}",
170 name, res.success, res.error
171 );
172 }
173 Err(err) => {
174 debug!("Tool '{}' failed to execute: {}", name, err);
175 }
176 }
177 result
178 }
179
180 pub fn len(&self) -> usize {
182 self.tools.len()
183 }
184
185 pub fn is_empty(&self) -> bool {
187 self.tools.is_empty()
188 }
189
190 pub fn load_plugins(
199 &mut self,
200 dir: &std::path::Path,
201 allow_override: bool,
202 ) -> anyhow::Result<spec_ai_plugin::LoadStats> {
203 use spec_ai_plugin::{expand_tilde, PluginLoader};
204
205 let expanded_dir = expand_tilde(dir);
206
207 let mut loader = PluginLoader::new();
208 let stats = loader.load_directory(&expanded_dir)?;
209
210 for (tool_ref, plugin_name) in loader.all_tools() {
212 let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
213 Ok(a) => a,
214 Err(e) => {
215 tracing::warn!("Failed to create adapter for tool from {}: {}", plugin_name, e);
216 continue;
217 }
218 };
219
220 let tool_name = adapter.name().to_string();
221
222 if self.has(&tool_name) {
224 if allow_override {
225 tracing::info!(
226 "Plugin tool '{}' from '{}' overriding built-in tool",
227 tool_name,
228 plugin_name
229 );
230 } else {
231 tracing::warn!(
232 "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
233 tool_name,
234 plugin_name
235 );
236 continue;
237 }
238 }
239
240 tracing::debug!(
241 "Registering plugin tool '{}' from '{}'",
242 tool_name,
243 plugin_name
244 );
245 self.register(Arc::new(adapter));
246 }
247
248 Ok(stats)
249 }
250
251 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
256 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
257 use crate::agent::function_calling::tool_to_openai_function;
258
259 self.tools
260 .values()
261 .map(|tool| {
262 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
263 })
264 .collect()
265 }
266}
267
268impl Default for ToolRegistry {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used to exercise the registry.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Only tests that actually `.await` need an async runtime; the rest use plain `#[test]`.

    #[test]
    fn test_new_registry_is_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
        assert!(ToolRegistry::default().is_empty());
    }

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert!(registry.get("missing").is_none());
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_same_name_replaces() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        // Tools are keyed by name, so a second registration replaces the first.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("Tool not found: nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}