thulp_core/
traits.rs

1//! Core traits for thulp.
2
3use crate::{Result, ToolCall, ToolDefinition, ToolResult};
4use async_trait::async_trait;
5use serde_json::Value;
6
7/// Trait for executable tools.
8#[async_trait]
9pub trait Tool: Send + Sync {
10    /// Get the tool definition.
11    fn definition(&self) -> &ToolDefinition;
12
13    /// Execute the tool with the given arguments.
14    async fn execute(&self, args: Value) -> Result<ToolResult>;
15
16    /// Get the tool name.
17    fn name(&self) -> &str {
18        &self.definition().name
19    }
20
21    /// Validate arguments before execution.
22    fn validate(&self, args: &Value) -> Result<()> {
23        self.definition().validate_args(args)
24    }
25}
26
27/// Trait for communication transports (MCP, HTTP, etc.).
28#[async_trait]
29pub trait Transport: Send + Sync {
30    /// Connect to the transport.
31    async fn connect(&mut self) -> Result<()>;
32
33    /// Disconnect from the transport.
34    async fn disconnect(&mut self) -> Result<()>;
35
36    /// Check if connected.
37    fn is_connected(&self) -> bool;
38
39    /// List available tools.
40    async fn list_tools(&self) -> Result<Vec<ToolDefinition>>;
41
42    /// Execute a tool call.
43    async fn call(&self, call: &ToolCall) -> Result<ToolResult>;
44}
45
46#[cfg(test)]
47mod tests {
48    use super::*;
49    use crate::Parameter;
50    use serde_json::json;
51    use std::sync::atomic::{AtomicBool, Ordering};
52    use std::sync::Arc;
53
54    // Mock tool for testing
55    struct MockTool {
56        definition: ToolDefinition,
57        execute_result: ToolResult,
58    }
59
60    impl MockTool {
61        fn new(name: &str, result: ToolResult) -> Self {
62            Self {
63                definition: ToolDefinition::new(name),
64                execute_result: result,
65            }
66        }
67    }
68
69    #[async_trait]
70    impl Tool for MockTool {
71        fn definition(&self) -> &ToolDefinition {
72            &self.definition
73        }
74
75        async fn execute(&self, _args: Value) -> Result<ToolResult> {
76            Ok(self.execute_result.clone())
77        }
78    }
79
80    // Mock transport for testing
81    struct MockTransport {
82        connected: Arc<AtomicBool>,
83        tools: Vec<ToolDefinition>,
84    }
85
86    impl MockTransport {
87        fn new(tools: Vec<ToolDefinition>) -> Self {
88            Self {
89                connected: Arc::new(AtomicBool::new(false)),
90                tools,
91            }
92        }
93    }
94
95    #[async_trait]
96    impl Transport for MockTransport {
97        async fn connect(&mut self) -> Result<()> {
98            self.connected.store(true, Ordering::SeqCst);
99            Ok(())
100        }
101
102        async fn disconnect(&mut self) -> Result<()> {
103            self.connected.store(false, Ordering::SeqCst);
104            Ok(())
105        }
106
107        fn is_connected(&self) -> bool {
108            self.connected.load(Ordering::SeqCst)
109        }
110
111        async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
112            Ok(self.tools.clone())
113        }
114
115        async fn call(&self, call: &ToolCall) -> Result<ToolResult> {
116            Ok(ToolResult::success(json!({
117                "tool": call.tool,
118                "called": true
119            })))
120        }
121    }
122
123    #[tokio::test]
124    async fn tool_trait_execute() {
125        let tool = MockTool::new("test", ToolResult::success(json!({"result": "ok"})));
126
127        let result = tool.execute(json!({})).await.unwrap();
128        assert!(result.is_success());
129        assert_eq!(result.data.unwrap()["result"], "ok");
130    }
131
132    #[tokio::test]
133    async fn tool_trait_name() {
134        let tool = MockTool::new("my_tool", ToolResult::success(json!(null)));
135        assert_eq!(tool.name(), "my_tool");
136    }
137
138    #[tokio::test]
139    async fn tool_trait_validate() {
140        let mut tool = MockTool::new("test", ToolResult::success(json!(null)));
141        tool.definition = ToolDefinition::builder("test")
142            .parameter(Parameter::required_string("name"))
143            .build();
144
145        // Valid args
146        assert!(tool.validate(&json!({"name": "value"})).is_ok());
147
148        // Missing required
149        assert!(tool.validate(&json!({})).is_err());
150    }
151
152    #[tokio::test]
153    async fn transport_trait_connect_disconnect() {
154        let mut transport = MockTransport::new(vec![]);
155
156        assert!(!transport.is_connected());
157
158        transport.connect().await.unwrap();
159        assert!(transport.is_connected());
160
161        transport.disconnect().await.unwrap();
162        assert!(!transport.is_connected());
163    }
164
165    #[tokio::test]
166    async fn transport_trait_list_tools() {
167        let tools = vec![ToolDefinition::new("tool1"), ToolDefinition::new("tool2")];
168        let transport = MockTransport::new(tools.clone());
169
170        let listed = transport.list_tools().await.unwrap();
171        assert_eq!(listed.len(), 2);
172        assert_eq!(listed[0].name, "tool1");
173        assert_eq!(listed[1].name, "tool2");
174    }
175
176    #[tokio::test]
177    async fn transport_trait_call() {
178        let transport = MockTransport::new(vec![]);
179
180        let call = ToolCall::new("test_tool");
181        let result = transport.call(&call).await.unwrap();
182
183        assert!(result.is_success());
184        assert_eq!(result.data.unwrap()["tool"], "test_tool");
185    }
186}