spec_ai_core/tools/
mod.rs1pub mod builtin;
2
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;

use self::builtin::{
    AudioTranscriptionTool, BashTool, EchoTool, FileExtractTool, FileReadTool, FileWriteTool,
    GraphTool, MathTool, PromptUserTool, SearchTool, ShellTool,
};

#[cfg(feature = "api")]
use self::builtin::WebSearchTool;

#[cfg(feature = "web-scraping")]
use self::builtin::WebScraperTool;
use crate::embeddings::EmbeddingsClient;
use crate::persistence::Persistence;

// `ChatCompletionTool` is only referenced by `ToolRegistry::to_openai_tools`, which is compiled
// for any of these features, so the import must be gated on the same feature set.
#[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
use async_openai::types::ChatCompletionTool;
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct ToolResult {
30 pub success: bool,
32 pub output: String,
34 pub error: Option<String>,
36}
37
38impl ToolResult {
39 pub fn success(output: impl Into<String>) -> Self {
41 Self {
42 success: true,
43 output: output.into(),
44 error: None,
45 }
46 }
47
48 pub fn failure(error: impl Into<String>) -> Self {
50 Self {
51 success: false,
52 output: String::new(),
53 error: Some(error.into()),
54 }
55 }
56}
57
58#[async_trait]
60pub trait Tool: Send + Sync {
61 fn name(&self) -> &str;
63
64 fn description(&self) -> &str;
66
67 fn parameters(&self) -> Value;
69
70 async fn execute(&self, args: Value) -> Result<ToolResult>;
72}
73
74pub struct ToolRegistry {
76 tools: HashMap<String, Arc<dyn Tool>>,
77}
78
79impl ToolRegistry {
80 pub fn new() -> Self {
82 Self {
83 tools: HashMap::new(),
84 }
85 }
86
87 #[allow(unused_variables)]
92 pub fn with_builtin_tools(
93 persistence: Option<Arc<Persistence>>,
94 embeddings: Option<EmbeddingsClient>,
95 ) -> Self {
96 let mut registry = Self::new();
97
98 registry.register(Arc::new(EchoTool::new()));
100 registry.register(Arc::new(MathTool::new()));
101 registry.register(Arc::new(FileReadTool::new()));
102 registry.register(Arc::new(FileExtractTool::new()));
103 registry.register(Arc::new(FileWriteTool::new()));
104 registry.register(Arc::new(PromptUserTool::new()));
105 registry.register(Arc::new(SearchTool::new()));
106 registry.register(Arc::new(BashTool::new()));
107 registry.register(Arc::new(ShellTool::new()));
108
109 #[cfg(feature = "api")]
111 registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
112
113 #[cfg(feature = "web-scraping")]
115 registry.register(Arc::new(WebScraperTool::new()));
116
117 if let Some(persistence) = persistence {
118 registry.register(Arc::new(GraphTool::new(persistence.clone())));
119 registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
120 persistence,
121 )));
122 } else {
123 registry.register(Arc::new(AudioTranscriptionTool::new()));
124 }
125
126 tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
127 for name in registry.tools.keys() {
128 tracing::debug!(" - Tool: {}", name);
129 }
130
131 registry
132 }
133
134 pub fn register(&mut self, tool: Arc<dyn Tool>) {
136 let name = tool.name().to_string();
137 self.tools.insert(name, tool);
138 }
139
140 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
142 self.tools.get(name).cloned()
143 }
144
145 pub fn list(&self) -> Vec<&str> {
147 self.tools.keys().map(|s| s.as_str()).collect()
148 }
149
150 pub fn has(&self, name: &str) -> bool {
152 self.tools.contains_key(name)
153 }
154
155 pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
157 let tool = self
158 .get(name)
159 .ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
160
161 debug!("Executing tool '{}'", name);
162 let result = tool.execute(args).await;
163 match &result {
164 Ok(res) => {
165 debug!(
166 "Tool '{}' completed: success={}, error={:?}",
167 name, res.success, res.error
168 );
169 }
170 Err(err) => {
171 debug!("Tool '{}' failed to execute: {}", name, err);
172 }
173 }
174 result
175 }
176
177 pub fn len(&self) -> usize {
179 self.tools.len()
180 }
181
182 pub fn is_empty(&self) -> bool {
184 self.tools.is_empty()
185 }
186
187 #[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
192 pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
193 use crate::agent::function_calling::tool_to_openai_function;
194
195 self.tools
196 .values()
197 .map(|tool| {
198 tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
199 })
200 .collect()
201 }
202}
203
204impl Default for ToolRegistry {
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that ignores its arguments and always succeeds with a fixed output.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // Only tests that actually await need a Tokio runtime; the rest are plain `#[test]`s.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        let fetched = registry.get("dummy").expect("dummy was just registered");
        assert_eq!(fetched.name(), "dummy");
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_register_same_name_keeps_single_entry() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        assert_eq!(registry.len(), 1);
        assert!(registry.has("dummy"));
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ToolRegistry::default();

        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.has("dummy"));
        assert!(registry.get("dummy").is_none());
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();
        let err = registry
            .execute("nonexistent", Value::Null)
            .await
            .expect_err("unknown tool must be rejected");
        assert!(err.to_string().contains("Tool not found: nonexistent"));
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error.as_deref(), Some("test error"));
    }
}
291}