spec_ai/spec_ai_core/tools/
mod.rs1pub mod builtin;
2pub mod plugin_adapter;
3
4use anyhow::Result;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tracing::debug;
11
12use self::builtin::{
13 AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
14 FileWriteTool, GenerateCodeTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool,
15 SearchTool, ShellTool,
16};
17
18#[cfg(feature = "api")]
19use self::builtin::WebSearchTool;
20
21#[cfg(feature = "web-scraping")]
22use self::builtin::WebScraperTool;
23use crate::spec_ai_core::agent::model::ModelProvider;
24use crate::spec_ai_core::agent::safety::RunSafetyBudget;
25use crate::spec_ai_core::embeddings::EmbeddingsClient;
26use crate::spec_ai_core::persistence::Persistence;
27
28pub use plugin_adapter::PluginToolAdapter;
29
30#[cfg(feature = "openai")]
31use async_openai::types::chat::ChatCompletionTool;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ToolResult {
36 pub success: bool,
38 pub output: String,
40 pub error: Option<String>,
42}
43
44#[derive(Clone, Default)]
46pub struct ToolExecutionContext {
47 pub safety: Option<RunSafetyBudget>,
48 pub delegation_depth: usize,
49}
50
51impl ToolResult {
52 pub fn success(output: impl Into<String>) -> Self {
54 Self {
55 success: true,
56 output: output.into(),
57 error: None,
58 }
59 }
60
61 pub fn failure(error: impl Into<String>) -> Self {
63 Self {
64 success: false,
65 output: String::new(),
66 error: Some(error.into()),
67 }
68 }
69}
70
71#[async_trait]
73pub trait Tool: Send + Sync {
74 fn name(&self) -> &str;
76
77 fn description(&self) -> &str;
79
80 fn parameters(&self) -> Value;
82
83 async fn execute(&self, args: Value) -> Result<ToolResult>;
85
86 async fn execute_with_context(
88 &self,
89 args: Value,
90 _context: ToolExecutionContext,
91 ) -> Result<ToolResult> {
92 self.execute(args).await
93 }
94}
95
96pub struct ToolRegistry {
98 tools: HashMap<String, Arc<dyn Tool>>,
99}
100
101impl ToolRegistry {
102 pub fn new() -> Self {
104 Self {
105 tools: HashMap::new(),
106 }
107 }
108
109 #[allow(unused_variables)]
114 pub fn with_builtin_tools(
115 persistence: Option<Arc<Persistence>>,
116 embeddings: Option<EmbeddingsClient>,
117 code_model_provider: Option<Arc<dyn ModelProvider>>,
118 ) -> Self {
119 let mut registry = Self::new();
120
121 registry.register(Arc::new(EchoTool::new()));
123 registry.register(Arc::new(MathTool::new()));
124 registry.register(Arc::new(FileReadTool::new()));
125 registry.register(Arc::new(FileExtractTool::new()));
126 registry.register(Arc::new(FileWriteTool::new()));
127 registry.register(Arc::new(PromptUserTool::new()));
128 registry.register(Arc::new(SearchTool::new()));
129 registry.register(Arc::new(GrepTool::new()));
130 registry.register(Arc::new(RgTool::new()));
131 registry.register(Arc::new(CodeSearchTool::new()));
132 registry.register(Arc::new(BashTool::new()));
133 registry.register(Arc::new(ShellTool::new()));
134 if let Some(provider) = code_model_provider {
135 registry.register(Arc::new(GenerateCodeTool::new(provider)));
136 }
137
138 #[cfg(feature = "api")]
140 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
141
142 #[cfg(feature = "web-scraping")]
144 registry.register(Arc::new(WebScraperTool::new()));
145
146 if let Some(persistence) = persistence {
147 registry.register(Arc::new(GraphTool::new(persistence.clone())));
148 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
149 persistence,
150 )));
151 } else {
152 registry.register(Arc::new(AudioTranscriptionTool::new()));
153 }
154
155 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
156 for name in registry.tools.keys() {
157 tracing::debug!(" - Tool: {}", name);
158 }
159
160 registry
161 }
162
163 pub fn register(&mut self, tool: Arc<dyn Tool>) {
165 let name = tool.name().to_string();
166 self.tools.insert(name, tool);
167 }
168
169 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
171 self.tools.get(name).cloned()
172 }
173
174 pub fn list(&self) -> Vec<&str> {
176 self.tools.keys().map(|s| s.as_str()).collect()
177 }
178
179 pub fn has(&self, name: &str) -> bool {
181 self.tools.contains_key(name)
182 }
183
184 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
186 self.execute_with_context(name, args, ToolExecutionContext::default())
187 .await
188 }
189
190 pub async fn execute_with_context(
192 &self,
193 name: &str,
194 args: Value,
195 context: ToolExecutionContext,
196 ) -> Result<ToolResult> {
197 let tool = self
198 .get(name)
199 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
200
201 debug!("Executing tool '{}'", name);
202 let result = tool.execute_with_context(args, context).await;
203 match &result {
204 Ok(res) => {
205 debug!(
206 "Tool '{}' completed: success={}, error={:?}",
207 name, res.success, res.error
208 );
209 }
210 Err(err) => {
211 debug!("Tool '{}' failed to execute: {}", name, err);
212 }
213 }
214 result
215 }
216
217 pub fn len(&self) -> usize {
219 self.tools.len()
220 }
221
222 pub fn is_empty(&self) -> bool {
224 self.tools.is_empty()
225 }
226
227 pub fn load_plugins(
236 &mut self,
237 dir: &std::path::Path,
238 allow_override: bool,
239 ) -> anyhow::Result<crate::spec_ai_plugin::LoadStats> {
240 use crate::spec_ai_plugin::{PluginLoader, expand_tilde};
241
242 let expanded_dir = expand_tilde(dir);
243
244 let mut loader = PluginLoader::new();
245 let stats = loader.load_directory(&expanded_dir)?;
246
247 for (tool_ref, plugin_name) in loader.all_tools() {
249 let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
250 Ok(a) => a,
251 Err(e) => {
252 tracing::warn!(
253 "Failed to create adapter for tool from {}: {}",
254 plugin_name,
255 e
256 );
257 continue;
258 }
259 };
260
261 let tool_name = adapter.name().to_string();
262
263 if self.has(&tool_name) {
265 if allow_override {
266 tracing::info!(
267 "Plugin tool '{}' from '{}' overriding built-in tool",
268 tool_name,
269 plugin_name
270 );
271 } else {
272 tracing::warn!(
273 "Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
274 tool_name,
275 plugin_name
276 );
277 continue;
278 }
279 }
280
281 tracing::debug!(
282 "Registering plugin tool '{}' from '{}'",
283 tool_name,
284 plugin_name
285 );
286 self.register(Arc::new(adapter));
287 }
288
289 Ok(stats)
290 }
291
292 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
297 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
298 use crate::spec_ai_core::agent::function_calling::tool_to_openai_function;
299
300 self.tools
301 .values()
302 .map(|tool| {
303 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
304 })
305 .collect()
306 }
307}
308
309impl Default for ToolRegistry {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool for exercising the registry; always succeeds with a fixed output.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Tests without `.await` use plain `#[test]`; no async runtime is needed.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        let tool = Arc::new(DummyTool);

        registry.register(tool.clone());

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_new_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(!registry.has("dummy"));
        assert!(registry.get("dummy").is_none());
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_register_same_name_replaces_existing() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
    }

    #[tokio::test]
    async fn test_execute_with_context_defaults_to_execute() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry
            .execute_with_context("dummy", Value::Null, ToolExecutionContext::default())
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}