1use crate::core::error::{ExecutionError, RustChainError};
2use async_trait::async_trait;
3use serde_json::Value;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone)]
7pub enum ToolResult {
8 Success(String),
9 StructuredJson(Value),
10 Error(String),
11}
12
/// Capability tags a [`Tool`] advertises through [`Tool::capabilities`].
///
/// [`ToolRegistry::tools_by_capability`] filters registered tools by these tags.
/// The meanings below are inferred from the variant names; no code in this module
/// enforces them.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ToolCapability {
    /// Baseline capability.
    Basic,
    /// Presumably marks tools backed by a WebAssembly plugin.
    WasmPlugin,
    /// Presumably marks tools that touch the local system (processes, files, ...).
    SystemAccess,
    /// Presumably marks tools that perform network I/O (e.g. the `"http"` bridge).
    NetworkAccess,
}
20
21#[async_trait]
22pub trait Tool: Send + Sync {
23 fn name(&self) -> &'static str;
24 fn capabilities(&self) -> Vec<ToolCapability>;
25 async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError>;
26}
27
28pub struct ToolRegistry {
29 tools: HashMap<String, Box<dyn Tool>>,
30}
31
32impl ToolRegistry {
33 pub fn new() -> Self {
34 Self {
35 tools: HashMap::new(),
36 }
37 }
38
39 pub fn with_defaults() -> Self {
41 let mut registry = Self::new();
42 registry.register_defaults();
43 registry
44 }
45
46 pub fn register_defaults(&mut self) {
48 #[cfg(feature = "tools")]
49 {
50 tracing::info!("Registering tools feature components...");
51
52 crate::core::web_search_tools::register_web_search_tools(self);
54
55 tracing::info!("About to register document loaders...");
57 crate::core::document_loaders::register_document_loaders(self);
58
59 crate::core::python_interpreter::register_python_interpreter(self);
61
62 crate::core::github_toolkit::register_github_client(self);
64
65 self.register_http_tool();
67 }
68
69 #[cfg(feature = "rag")]
70 {
71 crate::core::pinecone_vector_store::register_pinecone_vector_store(self);
73 crate::core::chroma_vector_store::register_chroma_vector_store(self);
74 }
75 }
76
77 pub fn register(&mut self, tool: Box<dyn Tool>) {
78 self.tools.insert(tool.name().to_string(), tool);
79 }
80
81 pub fn get(&self, name: &str) -> Option<&Box<dyn Tool>> {
82 self.tools.get(name)
83 }
84
85 pub fn list(&self) -> Vec<String> {
86 self.tools.keys().cloned().collect()
87 }
88
89 pub fn tools_by_capability(&self, cap: ToolCapability) -> Vec<&Box<dyn Tool>> {
90 self.tools
91 .values()
92 .filter(|tool| tool.capabilities().contains(&cap))
93 .collect()
94 }
95
96 pub fn count(&self) -> usize {
97 self.tools.len()
98 }
99
100 pub fn clear(&mut self) {
101 self.tools.clear();
102 }
103
104 pub fn remove(&mut self, name: &str) -> Option<Box<dyn Tool>> {
105 self.tools.remove(name)
106 }
107
108 pub fn contains(&self, name: &str) -> bool {
109 self.tools.contains_key(name)
110 }
111
112 pub fn get_tool(&self, name: &str) -> Option<&Box<dyn Tool>> {
113 self.tools.get(name)
114 }
115
116 pub fn get_capabilities(&self, name: &str) -> Option<Vec<ToolCapability>> {
117 self.tools.get(name).map(|tool| tool.capabilities())
118 }
119
120 #[allow(dead_code)]
122 fn register_http_tool(&mut self) {
123 self.register(Box::new(HttpToolBridge::new()));
124 tracing::info!("Registered HTTP tool for ToolRegistry");
125 }
126}
127
/// Stateless adapter that exposes the `tools` feature's HTTP executor through the
/// [`Tool`] trait under the name `"http"`.
///
/// Derives `Default` alongside the existing `new()` constructor (clippy
/// `new_without_default`), and `Debug` so it can be logged like other public types.
#[derive(Debug, Default, Clone, Copy)]
pub struct HttpToolBridge;
130
131impl HttpToolBridge {
132 pub fn new() -> Self {
133 Self
134 }
135}
136
137#[async_trait::async_trait]
138impl Tool for HttpToolBridge {
139 fn name(&self) -> &'static str {
140 "http"
141 }
142
143 fn capabilities(&self) -> Vec<ToolCapability> {
144 vec![ToolCapability::Basic, ToolCapability::NetworkAccess]
145 }
146
147 async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
148 #[cfg(feature = "tools")]
149 {
150 use crate::tools::{ToolCall, ToolExecutor, HttpTool};
151 use crate::core::RuntimeContext;
152
153 let params: serde_json::Value = serde_json::from_str(input)
155 .map_err(|e| RustChainError::Execution(ExecutionError::step_failed("http", "parse_input", format!("Invalid JSON input: {}", e))))?;
156
157 let tool_call = ToolCall::new(
159 "http".to_string(),
160 params,
161 );
162
163 let context = RuntimeContext::new();
165
166 let http_tool = HttpTool;
168 let result = http_tool.execute(tool_call, &context).await
169 .map_err(|e| RustChainError::Execution(ExecutionError::step_failed("http", "http_request", format!("HTTP request failed: {}", e))))?;
170
171 if result.success {
173 Ok(ToolResult::StructuredJson(result.output))
174 } else {
175 Ok(ToolResult::Error(result.error.unwrap_or_else(|| "HTTP request failed".to_string())))
176 }
177 }
178
179 #[cfg(not(feature = "tools"))]
180 {
181 let _ = input; Err(RustChainError::Execution(ExecutionError::step_failed("http", "feature_disabled", "Tools feature not enabled".to_string())))
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    //! Unit tests for `ToolResult`, `ToolCapability`, the `Tool` trait and `ToolRegistry`.

    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    /// Configurable mock advertising only `ToolCapability::Basic`.
    struct MockBasicTool {
        name: &'static str,
        result: ToolResult,
        should_fail: bool,
    }

    impl MockBasicTool {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                result: ToolResult::Success("mock success".to_string()),
                should_fail: false,
            }
        }

        fn with_result(mut self, result: ToolResult) -> Self {
            self.result = result;
            self
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }
    }

    #[async_trait]
    impl Tool for MockBasicTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::Basic]
        }

        async fn invoke(&self, _input: &str) -> Result<ToolResult, RustChainError> {
            if self.should_fail {
                Err(RustChainError::Tool(
                    crate::core::error::ToolError::execution_failed(
                        self.name,
                        "Mock tool failure".to_string(),
                    ),
                ))
            } else {
                Ok(self.result.clone())
            }
        }
    }

    /// Mock with network + basic capabilities; fails (as data) when input contains "fail".
    struct MockNetworkTool;

    #[async_trait]
    impl Tool for MockNetworkTool {
        fn name(&self) -> &'static str {
            "network_tool"
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::NetworkAccess, ToolCapability::Basic]
        }

        async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
            if input.contains("fail") {
                Ok(ToolResult::Error("Network operation failed".to_string()))
            } else {
                Ok(ToolResult::StructuredJson(json!({
                    "status": "success",
                    "data": "network response"
                })))
            }
        }
    }

    /// Mock with system + wasm capabilities that echoes its input.
    struct MockSystemTool;

    #[async_trait]
    impl Tool for MockSystemTool {
        fn name(&self) -> &'static str {
            "system_tool"
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::SystemAccess, ToolCapability::WasmPlugin]
        }

        async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
            Ok(ToolResult::Success(format!("System executed: {}", input)))
        }
    }

    /// Mock whose name is chosen at runtime, used to build large registries.
    struct IndexedTool {
        name: &'static str,
        index: usize,
    }

    #[async_trait]
    impl Tool for IndexedTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::Basic]
        }

        async fn invoke(&self, _input: &str) -> Result<ToolResult, RustChainError> {
            Ok(ToolResult::Success(format!("Tool {}", self.index)))
        }
    }

    #[test]
    fn test_tool_result_variants() {
        let success = ToolResult::Success("success message".to_string());
        let structured = ToolResult::StructuredJson(json!({"key": "value"}));
        let error = ToolResult::Error("error message".to_string());

        assert!(format!("{:?}", success).contains("Success"));
        assert!(format!("{:?}", structured).contains("StructuredJson"));
        assert!(format!("{:?}", error).contains("Error"));

        match success {
            ToolResult::Success(msg) => assert_eq!(msg, "success message"),
            _ => panic!("Expected Success variant"),
        }

        match structured {
            ToolResult::StructuredJson(val) => assert_eq!(val["key"], "value"),
            _ => panic!("Expected StructuredJson variant"),
        }

        match error {
            ToolResult::Error(msg) => assert_eq!(msg, "error message"),
            _ => panic!("Expected Error variant"),
        }
    }

    #[test]
    fn test_tool_capability_variants() {
        let basic = ToolCapability::Basic;
        let wasm = ToolCapability::WasmPlugin;
        let system = ToolCapability::SystemAccess;
        let network = ToolCapability::NetworkAccess;

        assert_eq!(basic.clone(), ToolCapability::Basic);
        assert_ne!(basic, wasm);
        assert_ne!(system, network);

        let mut cap_map = HashMap::new();
        cap_map.insert(basic.clone(), "basic");
        cap_map.insert(wasm.clone(), "wasm");
        cap_map.insert(system.clone(), "system");
        cap_map.insert(network.clone(), "network");

        assert_eq!(cap_map.get(&basic), Some(&"basic"));
        assert_eq!(cap_map.get(&wasm), Some(&"wasm"));
        assert_eq!(cap_map.len(), 4);
    }

    #[tokio::test]
    async fn test_mock_basic_tool() {
        let tool = MockBasicTool::new("test_basic");

        assert_eq!(tool.name(), "test_basic");
        assert_eq!(tool.capabilities(), vec![ToolCapability::Basic]);

        match tool.invoke("test input").await.unwrap() {
            ToolResult::Success(msg) => assert_eq!(msg, "mock success"),
            _ => panic!("Expected Success result"),
        }
    }

    #[tokio::test]
    async fn test_mock_basic_tool_with_custom_result() {
        let tool = MockBasicTool::new("custom_tool")
            .with_result(ToolResult::StructuredJson(json!({"custom": "data"})));

        match tool.invoke("input").await.unwrap() {
            ToolResult::StructuredJson(val) => assert_eq!(val["custom"], "data"),
            _ => panic!("Expected StructuredJson result"),
        }
    }

    #[tokio::test]
    async fn test_mock_basic_tool_failure() {
        let tool = MockBasicTool::new("failing_tool").with_failure();

        let result = tool.invoke("input").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Tool(e)) => assert!(e.to_string().contains("Mock tool failure")),
            _ => panic!("Expected Tool error"),
        }
    }

    #[tokio::test]
    async fn test_mock_network_tool() {
        let tool = MockNetworkTool;

        assert_eq!(tool.name(), "network_tool");
        assert_eq!(
            tool.capabilities(),
            vec![ToolCapability::NetworkAccess, ToolCapability::Basic]
        );

        match tool.invoke("success").await.unwrap() {
            ToolResult::StructuredJson(val) => {
                assert_eq!(val["status"], "success");
                assert_eq!(val["data"], "network response");
            }
            _ => panic!("Expected StructuredJson result"),
        }

        match tool.invoke("fail").await.unwrap() {
            ToolResult::Error(msg) => assert_eq!(msg, "Network operation failed"),
            _ => panic!("Expected Error result"),
        }
    }

    #[tokio::test]
    async fn test_mock_system_tool() {
        let tool = MockSystemTool;

        assert_eq!(tool.name(), "system_tool");
        assert_eq!(
            tool.capabilities(),
            vec![ToolCapability::SystemAccess, ToolCapability::WasmPlugin]
        );

        match tool.invoke("system command").await.unwrap() {
            ToolResult::Success(msg) => assert_eq!(msg, "System executed: system command"),
            _ => panic!("Expected Success result"),
        }
    }

    #[test]
    fn test_tool_registry_basic_operations() {
        let mut registry = ToolRegistry::new();

        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.contains("nonexistent"));

        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);
        assert!(registry.contains("tool1"));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"tool1".to_string()));

        let tool = registry.get("tool1").expect("tool1 was just registered");
        assert_eq!(tool.name(), "tool1");

        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_tool_registry_multiple_tools() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("basic1")));
        registry.register(Box::new(MockBasicTool::new("basic2")));
        registry.register(Box::new(MockNetworkTool));
        registry.register(Box::new(MockSystemTool));

        assert_eq!(registry.count(), 4);

        let tools = registry.list();
        assert_eq!(tools.len(), 4);
        for expected in ["basic1", "basic2", "network_tool", "system_tool"] {
            assert!(tools.contains(&expected.to_string()), "missing {}", expected);
        }
    }

    #[test]
    fn test_tool_registry_tools_by_capability() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("basic1")));
        registry.register(Box::new(MockBasicTool::new("basic2")));
        registry.register(Box::new(MockNetworkTool));
        registry.register(Box::new(MockSystemTool));

        assert_eq!(registry.tools_by_capability(ToolCapability::Basic).len(), 3);

        let network_tools = registry.tools_by_capability(ToolCapability::NetworkAccess);
        assert_eq!(network_tools.len(), 1);
        assert_eq!(network_tools[0].name(), "network_tool");

        let system_tools = registry.tools_by_capability(ToolCapability::SystemAccess);
        assert_eq!(system_tools.len(), 1);
        assert_eq!(system_tools[0].name(), "system_tool");

        let wasm_tools = registry.tools_by_capability(ToolCapability::WasmPlugin);
        assert_eq!(wasm_tools.len(), 1);
        assert_eq!(wasm_tools[0].name(), "system_tool");
    }

    #[test]
    fn test_tool_registry_get_capabilities() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("basic_tool")));
        registry.register(Box::new(MockNetworkTool));

        assert_eq!(
            registry.get_capabilities("basic_tool"),
            Some(vec![ToolCapability::Basic])
        );
        assert_eq!(
            registry.get_capabilities("network_tool"),
            Some(vec![ToolCapability::NetworkAccess, ToolCapability::Basic])
        );
        assert!(registry.get_capabilities("nonexistent").is_none());
    }

    #[test]
    fn test_tool_registry_remove() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("removable_tool")));
        registry.register(Box::new(MockBasicTool::new("permanent_tool")));

        assert_eq!(registry.count(), 2);
        assert!(registry.contains("removable_tool"));

        let removed = registry.remove("removable_tool").expect("tool was registered");
        assert_eq!(removed.name(), "removable_tool");

        assert_eq!(registry.count(), 1);
        assert!(!registry.contains("removable_tool"));
        assert!(registry.contains("permanent_tool"));

        assert!(registry.remove("nonexistent").is_none());
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_tool_registry_clear() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("tool1")));
        registry.register(Box::new(MockBasicTool::new("tool2")));
        registry.register(Box::new(MockNetworkTool));

        assert_eq!(registry.count(), 3);

        registry.clear();

        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.contains("tool1"));
        assert!(!registry.contains("tool2"));
        assert!(!registry.contains("network_tool"));
    }

    #[test]
    fn test_tool_registry_overwrite() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);

        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);

        let tool = registry.get("tool1").unwrap();
        assert_eq!(tool.name(), "tool1");
    }

    #[tokio::test]
    async fn test_tool_trait_object_usage() {
        let tool: Box<dyn Tool> = Box::new(MockBasicTool::new("trait_object_tool"));

        assert_eq!(tool.name(), "trait_object_tool");
        assert_eq!(tool.capabilities(), vec![ToolCapability::Basic]);

        match tool.invoke("test").await.unwrap() {
            ToolResult::Success(msg) => assert_eq!(msg, "mock success"),
            _ => panic!("Expected Success result"),
        }
    }

    #[test]
    fn test_multiple_capability_tool() {
        let capabilities = MockNetworkTool.capabilities();

        assert_eq!(capabilities.len(), 2);
        assert!(capabilities.contains(&ToolCapability::NetworkAccess));
        assert!(capabilities.contains(&ToolCapability::Basic));

        let mut registry = ToolRegistry::new();
        registry.register(Box::new(MockNetworkTool));

        let basic_tools = registry.tools_by_capability(ToolCapability::Basic);
        assert_eq!(basic_tools.len(), 1);

        let network_tools = registry.tools_by_capability(ToolCapability::NetworkAccess);
        assert_eq!(network_tools.len(), 1);

        assert_eq!(basic_tools[0].name(), network_tools[0].name());
    }

    #[test]
    fn test_tool_result_cloning() {
        let original = ToolResult::Success("cloneable".to_string());
        let cloned = original.clone();
        match (original, cloned) {
            (ToolResult::Success(orig), ToolResult::Success(clone)) => assert_eq!(orig, clone),
            _ => panic!("Clone failed"),
        }

        let json_original = ToolResult::StructuredJson(json!({"clone": "test"}));
        let json_cloned = json_original.clone();
        match (json_original, json_cloned) {
            (ToolResult::StructuredJson(orig), ToolResult::StructuredJson(clone)) => {
                assert_eq!(orig, clone)
            }
            _ => panic!("JSON clone failed"),
        }

        let error_original = ToolResult::Error("cloneable error".to_string());
        let error_cloned = error_original.clone();
        match (error_original, error_cloned) {
            (ToolResult::Error(orig), ToolResult::Error(clone)) => assert_eq!(orig, clone),
            _ => panic!("Error clone failed"),
        }
    }

    #[test]
    fn test_edge_cases() {
        struct EmptyNameTool;

        #[async_trait]
        impl Tool for EmptyNameTool {
            fn name(&self) -> &'static str {
                ""
            }

            fn capabilities(&self) -> Vec<ToolCapability> {
                vec![]
            }

            async fn invoke(&self, _input: &str) -> Result<ToolResult, RustChainError> {
                Ok(ToolResult::Success("empty name tool".to_string()))
            }
        }

        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EmptyNameTool));
        assert_eq!(registry.count(), 1);
        assert!(registry.contains(""));

        let tool = registry.get("").unwrap();
        assert_eq!(tool.name(), "");
        assert!(tool.capabilities().is_empty());

        assert!(registry.tools_by_capability(ToolCapability::Basic).is_empty());
    }

    #[tokio::test]
    async fn test_large_registry_performance() {
        let mut registry = ToolRegistry::new();

        // Re-registering under a single name keeps overwriting and never grows the map;
        // the most recent registration must be the one that survives.
        for i in 0..100 {
            registry.register(Box::new(
                MockBasicTool::new("tool").with_result(ToolResult::Success(format!("Tool {}", i))),
            ));
        }
        assert_eq!(registry.count(), 1);
        match registry.get("tool").unwrap().invoke("").await.unwrap() {
            ToolResult::Success(msg) => assert_eq!(msg, "Tool 99"),
            _ => panic!("Expected Success result"),
        }

        registry.clear();

        // Distinct names: every registration must be retained.
        for index in 0..50 {
            // `Tool::name` returns `&'static str`; leaking a few bytes is fine in a test.
            let name: &'static str = Box::leak(format!("tool_{}", index).into_boxed_str());
            registry.register(Box::new(IndexedTool { name, index }));
        }

        assert_eq!(registry.count(), 50);

        let tools = registry.list();
        assert_eq!(tools.len(), 50);
        assert!(tools.contains(&"tool_0".to_string()));
        assert!(tools.contains(&"tool_49".to_string()));

        assert_eq!(registry.tools_by_capability(ToolCapability::Basic).len(), 50);

        match registry.get("tool_7").unwrap().invoke("").await.unwrap() {
            ToolResult::Success(msg) => assert_eq!(msg, "Tool 7"),
            _ => panic!("Expected Success result"),
        }
    }
}