// rustchain/core/tools.rs

1use crate::core::error::{ExecutionError, RustChainError};
2use async_trait::async_trait;
3use serde_json::Value;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone)]
7pub enum ToolResult {
8    Success(String),
9    StructuredJson(Value),
10    Error(String),
11}
12
/// Capability tags a [`Tool`] declares through `Tool::capabilities`.
///
/// The registry filters on these in `ToolRegistry::tools_by_capability`;
/// nothing in this module enforces what a tag permits.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub enum ToolCapability {
    /// Baseline tag carried by general-purpose tools.
    Basic,
    /// Presumably: the tool is implemented as a WebAssembly plugin (inferred from the name only).
    WasmPlugin,
    /// Presumably: the tool touches the local system (processes, files) — inferred from the name.
    SystemAccess,
    /// The tool performs network I/O (declared by the HTTP bridge, for example).
    NetworkAccess,
}

21#[async_trait]
22pub trait Tool: Send + Sync {
23    fn name(&self) -> &'static str;
24    fn capabilities(&self) -> Vec<ToolCapability>;
25    async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError>;
26}
27
28pub struct ToolRegistry {
29    tools: HashMap<String, Box<dyn Tool>>,
30}
31
32impl ToolRegistry {
33    pub fn new() -> Self {
34        Self {
35            tools: HashMap::new(),
36        }
37    }
38
39    /// Create a new registry with default tools registered
40    pub fn with_defaults() -> Self {
41        let mut registry = Self::new();
42        registry.register_defaults();
43        registry
44    }
45
46    /// Register default tools based on available features and environment
47    pub fn register_defaults(&mut self) {
48        #[cfg(feature = "tools")]
49        {
50            tracing::info!("Registering tools feature components...");
51            
52            // Register web search tools if environment variables are available
53            crate::core::web_search_tools::register_web_search_tools(self);
54
55            // Register document loaders
56            tracing::info!("About to register document loaders...");
57            crate::core::document_loaders::register_document_loaders(self);
58
59            // Register code interpreters
60            crate::core::python_interpreter::register_python_interpreter(self);
61
62            // Register developer toolkits
63            crate::core::github_toolkit::register_github_client(self);
64            
65            // Register HTTP client for web requests
66            self.register_http_tool();
67        }
68
69        #[cfg(feature = "rag")]
70        {
71            // Register vector stores if environment variables are available
72            crate::core::pinecone_vector_store::register_pinecone_vector_store(self);
73            crate::core::chroma_vector_store::register_chroma_vector_store(self);
74        }
75    }
76
77    pub fn register(&mut self, tool: Box<dyn Tool>) {
78        self.tools.insert(tool.name().to_string(), tool);
79    }
80
81    pub fn get(&self, name: &str) -> Option<&Box<dyn Tool>> {
82        self.tools.get(name)
83    }
84
85    pub fn list(&self) -> Vec<String> {
86        self.tools.keys().cloned().collect()
87    }
88
89    pub fn tools_by_capability(&self, cap: ToolCapability) -> Vec<&Box<dyn Tool>> {
90        self.tools
91            .values()
92            .filter(|tool| tool.capabilities().contains(&cap))
93            .collect()
94    }
95
96    pub fn count(&self) -> usize {
97        self.tools.len()
98    }
99
100    pub fn clear(&mut self) {
101        self.tools.clear();
102    }
103
104    pub fn remove(&mut self, name: &str) -> Option<Box<dyn Tool>> {
105        self.tools.remove(name)
106    }
107
108    pub fn contains(&self, name: &str) -> bool {
109        self.tools.contains_key(name)
110    }
111    
112    pub fn get_tool(&self, name: &str) -> Option<&Box<dyn Tool>> {
113        self.tools.get(name)
114    }
115
116    pub fn get_capabilities(&self, name: &str) -> Option<Vec<ToolCapability>> {
117        self.tools.get(name).map(|tool| tool.capabilities())
118    }
119    
120    /// Register HTTP tool for web requests
121    #[allow(dead_code)]
122    fn register_http_tool(&mut self) {
123        self.register(Box::new(HttpToolBridge::new()));
124        tracing::info!("Registered HTTP tool for ToolRegistry");
125    }
126}
127
/// Bridge adapter that allows ToolManager's HttpTool to work with ToolRegistry.
///
/// The bridge holds no state, so it derives `Default` (satisfying clippy's
/// `new_without_default`), `Debug` for diagnostics, and `Clone`/`Copy`.
#[derive(Debug, Default, Clone, Copy)]
pub struct HttpToolBridge;

impl HttpToolBridge {
    /// Creates the bridge; equivalent to `HttpToolBridge::default()`.
    pub fn new() -> Self {
        Self
    }
}

137#[async_trait::async_trait]
138impl Tool for HttpToolBridge {
139    fn name(&self) -> &'static str {
140        "http"
141    }
142
143    fn capabilities(&self) -> Vec<ToolCapability> {
144        vec![ToolCapability::Basic, ToolCapability::NetworkAccess]
145    }
146
147    async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
148        #[cfg(feature = "tools")]
149        {
150            use crate::tools::{ToolCall, ToolExecutor, HttpTool};
151            use crate::core::RuntimeContext;
152            
153            // Parse input as JSON parameters
154            let params: serde_json::Value = serde_json::from_str(input)
155                .map_err(|e| RustChainError::Execution(ExecutionError::step_failed("http", "parse_input", format!("Invalid JSON input: {}", e))))?;
156            
157            // Create a ToolCall from the parameters  
158            let tool_call = ToolCall::new(
159                "http".to_string(),
160                params,
161            );
162            
163            // Create a minimal RuntimeContext
164            let context = RuntimeContext::new();
165            
166            // Execute using the actual HttpTool
167            let http_tool = HttpTool;
168            let result = http_tool.execute(tool_call, &context).await
169                .map_err(|e| RustChainError::Execution(ExecutionError::step_failed("http", "http_request", format!("HTTP request failed: {}", e))))?;
170            
171            // Convert tools::ToolResult to core::ToolResult
172            if result.success {
173                Ok(ToolResult::StructuredJson(result.output))
174            } else {
175                Ok(ToolResult::Error(result.error.unwrap_or_else(|| "HTTP request failed".to_string())))
176            }
177        }
178        
179        #[cfg(not(feature = "tools"))]
180        {
181            let _ = input; // Suppress unused parameter warning
182            Err(RustChainError::Execution(ExecutionError::step_failed("http", "feature_disabled", "Tools feature not enabled".to_string())))
183        }
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;

    // Configurable mock: returns a canned `ToolResult`, or a `ToolError` when
    // `should_fail` is set. Declares only the `Basic` capability.
    struct MockBasicTool {
        name: &'static str,
        result: ToolResult,
        should_fail: bool,
    }

    impl MockBasicTool {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                result: ToolResult::Success("mock success".to_string()),
                should_fail: false,
            }
        }

        fn with_result(mut self, result: ToolResult) -> Self {
            self.result = result;
            self
        }

        fn with_failure(mut self) -> Self {
            self.should_fail = true;
            self
        }
    }

    #[async_trait]
    impl Tool for MockBasicTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::Basic]
        }

        async fn invoke(&self, _input: &str) -> Result<ToolResult, RustChainError> {
            if self.should_fail {
                Err(RustChainError::Tool(crate::core::error::ToolError::execution_failed(
                    self.name,
                    "Mock tool failure".to_string(),
                )))
            } else {
                Ok(self.result.clone())
            }
        }
    }

    // Mock with two capabilities; reports an in-band error when the input contains "fail".
    struct MockNetworkTool;

    #[async_trait]
    impl Tool for MockNetworkTool {
        fn name(&self) -> &'static str {
            "network_tool"
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::NetworkAccess, ToolCapability::Basic]
        }

        async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
            if input.contains("fail") {
                Ok(ToolResult::Error("Network operation failed".to_string()))
            } else {
                Ok(ToolResult::StructuredJson(json!({
                    "status": "success",
                    "data": "network response"
                })))
            }
        }
    }

    // Mock without the `Basic` capability, used to check capability filtering.
    struct MockSystemTool;

    #[async_trait]
    impl Tool for MockSystemTool {
        fn name(&self) -> &'static str {
            "system_tool"
        }

        fn capabilities(&self) -> Vec<ToolCapability> {
            vec![ToolCapability::SystemAccess, ToolCapability::WasmPlugin]
        }

        async fn invoke(&self, input: &str) -> Result<ToolResult, RustChainError> {
            Ok(ToolResult::Success(format!("System executed: {}", input)))
        }
    }

    #[test]
    fn test_tool_result_variants() {
        // Test all ToolResult variants
        let success = ToolResult::Success("success message".to_string());
        let structured = ToolResult::StructuredJson(json!({"key": "value"}));
        let error = ToolResult::Error("error message".to_string());

        // Test Debug implementation
        assert!(format!("{:?}", success).contains("Success"));
        assert!(format!("{:?}", structured).contains("StructuredJson"));
        assert!(format!("{:?}", error).contains("Error"));

        // Verify content
        match success {
            ToolResult::Success(msg) => assert_eq!(msg, "success message"),
            _ => panic!("Expected Success variant"),
        }

        match structured {
            ToolResult::StructuredJson(val) => {
                assert_eq!(val["key"], "value");
            }
            _ => panic!("Expected StructuredJson variant"),
        }

        match error {
            ToolResult::Error(msg) => assert_eq!(msg, "error message"),
            _ => panic!("Expected Error variant"),
        }
    }

    #[test]
    fn test_tool_capability_variants() {
        // Test all ToolCapability variants
        let basic = ToolCapability::Basic;
        let wasm = ToolCapability::WasmPlugin;
        let system = ToolCapability::SystemAccess;
        let network = ToolCapability::NetworkAccess;

        // Test Debug, Clone, PartialEq, Eq, Hash implementations
        assert_eq!(basic.clone(), ToolCapability::Basic);
        assert_ne!(basic, wasm);
        assert_ne!(system, network);

        // Test in HashMap (Hash trait)
        let mut cap_map = HashMap::new();
        cap_map.insert(basic.clone(), "basic");
        cap_map.insert(wasm.clone(), "wasm");
        cap_map.insert(system.clone(), "system");
        cap_map.insert(network.clone(), "network");

        assert_eq!(cap_map.get(&basic), Some(&"basic"));
        assert_eq!(cap_map.get(&wasm), Some(&"wasm"));
        assert_eq!(cap_map.len(), 4);
    }

    #[tokio::test]
    async fn test_mock_basic_tool() {
        let tool = MockBasicTool::new("test_basic");

        assert_eq!(tool.name(), "test_basic");
        assert_eq!(tool.capabilities(), vec![ToolCapability::Basic]);

        let result = tool.invoke("test input").await.unwrap();
        match result {
            ToolResult::Success(msg) => assert_eq!(msg, "mock success"),
            _ => panic!("Expected Success result"),
        }
    }

    #[tokio::test]
    async fn test_mock_basic_tool_with_custom_result() {
        let tool = MockBasicTool::new("custom_tool")
            .with_result(ToolResult::StructuredJson(json!({"custom": "data"})));

        let result = tool.invoke("input").await.unwrap();
        match result {
            ToolResult::StructuredJson(val) => {
                assert_eq!(val["custom"], "data");
            }
            _ => panic!("Expected StructuredJson result"),
        }
    }

    #[tokio::test]
    async fn test_mock_basic_tool_failure() {
        let tool = MockBasicTool::new("failing_tool").with_failure();

        let result = tool.invoke("input").await;
        assert!(result.is_err());

        match result {
            Err(RustChainError::Tool(e)) => {
                assert!(e.to_string().contains("Mock tool failure"));
            }
            _ => panic!("Expected Tool error"),
        }
    }

    #[tokio::test]
    async fn test_mock_network_tool() {
        let tool = MockNetworkTool;

        assert_eq!(tool.name(), "network_tool");
        assert_eq!(
            tool.capabilities(),
            vec![ToolCapability::NetworkAccess, ToolCapability::Basic]
        );

        // Test success case
        let result = tool.invoke("success").await.unwrap();
        match result {
            ToolResult::StructuredJson(val) => {
                assert_eq!(val["status"], "success");
                assert_eq!(val["data"], "network response");
            }
            _ => panic!("Expected StructuredJson result"),
        }

        // Test error case
        let result = tool.invoke("fail").await.unwrap();
        match result {
            ToolResult::Error(msg) => {
                assert_eq!(msg, "Network operation failed");
            }
            _ => panic!("Expected Error result"),
        }
    }

    #[tokio::test]
    async fn test_mock_system_tool() {
        let tool = MockSystemTool;

        assert_eq!(tool.name(), "system_tool");
        assert_eq!(
            tool.capabilities(),
            vec![ToolCapability::SystemAccess, ToolCapability::WasmPlugin]
        );

        let result = tool.invoke("system command").await.unwrap();
        match result {
            ToolResult::Success(msg) => {
                assert_eq!(msg, "System executed: system command");
            }
            _ => panic!("Expected Success result"),
        }
    }

    #[test]
    fn test_tool_registry_basic_operations() {
        let mut registry = ToolRegistry::new();

        // Test empty registry
        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.contains("nonexistent"));

        // Register a tool
        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);
        assert!(registry.contains("tool1"));

        // Test list
        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"tool1".to_string()));

        // Test get
        let tool = registry.get("tool1");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "tool1");

        // Test get non-existent
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_tool_registry_multiple_tools() {
        let mut registry = ToolRegistry::new();

        // Register multiple tools
        registry.register(Box::new(MockBasicTool::new("basic1")));
        registry.register(Box::new(MockBasicTool::new("basic2")));
        registry.register(Box::new(MockNetworkTool));
        registry.register(Box::new(MockSystemTool));

        assert_eq!(registry.count(), 4);

        let tools = registry.list();
        assert_eq!(tools.len(), 4);
        assert!(tools.contains(&"basic1".to_string()));
        assert!(tools.contains(&"basic2".to_string()));
        assert!(tools.contains(&"network_tool".to_string()));
        assert!(tools.contains(&"system_tool".to_string()));
    }

    #[test]
    fn test_tool_registry_tools_by_capability() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("basic1")));
        registry.register(Box::new(MockBasicTool::new("basic2")));
        registry.register(Box::new(MockNetworkTool));
        registry.register(Box::new(MockSystemTool));

        // Test Basic capability (should include basic1, basic2, network_tool)
        let basic_tools = registry.tools_by_capability(ToolCapability::Basic);
        assert_eq!(basic_tools.len(), 3);

        // Test NetworkAccess capability (should include only network_tool)
        let network_tools = registry.tools_by_capability(ToolCapability::NetworkAccess);
        assert_eq!(network_tools.len(), 1);
        assert_eq!(network_tools[0].name(), "network_tool");

        // Test SystemAccess capability (should include only system_tool)
        let system_tools = registry.tools_by_capability(ToolCapability::SystemAccess);
        assert_eq!(system_tools.len(), 1);
        assert_eq!(system_tools[0].name(), "system_tool");

        // Test WasmPlugin capability (should include only system_tool)
        let wasm_tools = registry.tools_by_capability(ToolCapability::WasmPlugin);
        assert_eq!(wasm_tools.len(), 1);
        assert_eq!(wasm_tools[0].name(), "system_tool");
    }

    #[test]
    fn test_tool_registry_get_capabilities() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("basic_tool")));
        registry.register(Box::new(MockNetworkTool));

        // Test getting capabilities for existing tools
        let basic_caps = registry.get_capabilities("basic_tool");
        assert!(basic_caps.is_some());
        assert_eq!(basic_caps.unwrap(), vec![ToolCapability::Basic]);

        let network_caps = registry.get_capabilities("network_tool");
        assert!(network_caps.is_some());
        assert_eq!(
            network_caps.unwrap(),
            vec![ToolCapability::NetworkAccess, ToolCapability::Basic]
        );

        // Test getting capabilities for non-existent tool
        let nonexistent_caps = registry.get_capabilities("nonexistent");
        assert!(nonexistent_caps.is_none());
    }

    #[test]
    fn test_tool_registry_remove() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("removable_tool")));
        registry.register(Box::new(MockBasicTool::new("permanent_tool")));

        assert_eq!(registry.count(), 2);
        assert!(registry.contains("removable_tool"));

        // Remove existing tool
        let removed = registry.remove("removable_tool");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().name(), "removable_tool");

        assert_eq!(registry.count(), 1);
        assert!(!registry.contains("removable_tool"));
        assert!(registry.contains("permanent_tool"));

        // Try to remove non-existent tool
        let not_removed = registry.remove("nonexistent");
        assert!(not_removed.is_none());
        assert_eq!(registry.count(), 1);
    }

    #[test]
    fn test_tool_registry_clear() {
        let mut registry = ToolRegistry::new();

        registry.register(Box::new(MockBasicTool::new("tool1")));
        registry.register(Box::new(MockBasicTool::new("tool2")));
        registry.register(Box::new(MockNetworkTool));

        assert_eq!(registry.count(), 3);

        registry.clear();

        assert_eq!(registry.count(), 0);
        assert!(registry.list().is_empty());
        assert!(!registry.contains("tool1"));
        assert!(!registry.contains("tool2"));
        assert!(!registry.contains("network_tool"));
    }

    #[test]
    fn test_tool_registry_overwrite() {
        let mut registry = ToolRegistry::new();

        // Register a tool
        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);

        // Register another tool with the same name (should overwrite)
        registry.register(Box::new(MockBasicTool::new("tool1")));
        assert_eq!(registry.count(), 1);

        let tool = registry.get("tool1").unwrap();
        assert_eq!(tool.name(), "tool1");
    }

    #[tokio::test]
    async fn test_tool_trait_object_usage() {
        // Test using Tool as a trait object
        let tool: Box<dyn Tool> = Box::new(MockBasicTool::new("trait_object_tool"));

        assert_eq!(tool.name(), "trait_object_tool");
        assert_eq!(tool.capabilities(), vec![ToolCapability::Basic]);

        let result = tool.invoke("test").await.unwrap();
        match result {
            ToolResult::Success(msg) => assert_eq!(msg, "mock success"),
            _ => panic!("Expected Success result"),
        }
    }

    #[test]
    fn test_multiple_capability_tool() {
        let tool = MockNetworkTool;
        let capabilities = tool.capabilities();

        assert_eq!(capabilities.len(), 2);
        assert!(capabilities.contains(&ToolCapability::NetworkAccess));
        assert!(capabilities.contains(&ToolCapability::Basic));

        // Test that the tool appears in searches for both capabilities
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(MockNetworkTool));

        let basic_tools = registry.tools_by_capability(ToolCapability::Basic);
        assert_eq!(basic_tools.len(), 1);

        let network_tools = registry.tools_by_capability(ToolCapability::NetworkAccess);
        assert_eq!(network_tools.len(), 1);

        assert_eq!(basic_tools[0].name(), network_tools[0].name());
    }

    #[test]
    fn test_tool_result_cloning() {
        // Test ToolResult can be cloned (needed for MockBasicTool)
        let original = ToolResult::Success("cloneable".to_string());
        let cloned = original.clone();

        match (original, cloned) {
            (ToolResult::Success(orig), ToolResult::Success(clone)) => {
                assert_eq!(orig, clone);
            }
            _ => panic!("Clone failed"),
        }

        let json_original = ToolResult::StructuredJson(json!({"clone": "test"}));
        let json_cloned = json_original.clone();

        match (json_original, json_cloned) {
            (ToolResult::StructuredJson(orig), ToolResult::StructuredJson(clone)) => {
                assert_eq!(orig, clone);
            }
            _ => panic!("JSON clone failed"),
        }

        let error_original = ToolResult::Error("cloneable error".to_string());
        let error_cloned = error_original.clone();

        match (error_original, error_cloned) {
            (ToolResult::Error(orig), ToolResult::Error(clone)) => {
                assert_eq!(orig, clone);
            }
            _ => panic!("Error clone failed"),
        }
    }

    #[test]
    fn test_edge_cases() {
        let mut registry = ToolRegistry::new();

        // Test with empty tool name (edge case)
        struct EmptyNameTool;

        #[async_trait]
        impl Tool for EmptyNameTool {
            fn name(&self) -> &'static str {
                ""
            }

            fn capabilities(&self) -> Vec<ToolCapability> {
                vec![]
            }

            async fn invoke(&self, _input: &str) -> Result<ToolResult, RustChainError> {
                Ok(ToolResult::Success("empty name tool".to_string()))
            }
        }

        registry.register(Box::new(EmptyNameTool));
        assert_eq!(registry.count(), 1);
        assert!(registry.contains(""));

        let tool = registry.get("").unwrap();
        assert_eq!(tool.name(), "");
        assert!(tool.capabilities().is_empty());

        // Test tool with no capabilities
        let no_cap_tools = registry.tools_by_capability(ToolCapability::Basic);
        assert_eq!(no_cap_tools.len(), 0);
    }

    #[tokio::test]
    async fn test_large_registry_performance() {
        let mut registry = ToolRegistry::new();

        // Re-registering under a single name keeps only the most recent tool.
        for i in 0..100 {
            registry.register(Box::new(
                MockBasicTool::new("tool").with_result(ToolResult::Success(format!("Tool {}", i))),
            ));
        }
        assert_eq!(registry.count(), 1); // Only one tool named "tool"

        // The surviving tool is the last one registered.
        let survivor = registry.get("tool").unwrap().invoke("").await.unwrap();
        match survivor {
            ToolResult::Success(msg) => assert_eq!(msg, "Tool 99"),
            other => panic!("Expected Success result, got {:?}", other),
        }

        // Register tools with unique names. `Tool::name` returns `&'static str`,
        // so the formatted names are leaked; the leak is bounded (50 short
        // strings) and confined to this test.
        registry.clear();
        for i in 0..50 {
            let tool_name: &'static str = Box::leak(format!("tool_{}", i).into_boxed_str());
            registry.register(Box::new(MockBasicTool::new(tool_name)));
        }

        assert_eq!(registry.count(), 50);
        assert_eq!(registry.list().len(), 50);
        assert!(registry.contains("tool_0"));
        assert!(registry.contains("tool_49"));
        assert!(!registry.contains("tool_50"));

        let basic_tools = registry.tools_by_capability(ToolCapability::Basic);
        assert_eq!(basic_tools.len(), 50);
    }

    #[test]
    fn test_http_tool_bridge_metadata() {
        let bridge = HttpToolBridge::new();
        assert_eq!(bridge.name(), "http");
        assert_eq!(
            bridge.capabilities(),
            vec![ToolCapability::Basic, ToolCapability::NetworkAccess]
        );
    }

    #[tokio::test]
    async fn test_http_tool_bridge_rejects_invalid_input() {
        // No network I/O happens. With the `tools` feature, JSON parsing fails
        // first; without it, the stub always returns an error.
        let result = HttpToolBridge::new().invoke("not valid json").await;
        assert!(result.is_err());
    }
}