1use crate::{Result, ToolCall, ToolDefinition, ToolResult};
4use async_trait::async_trait;
5use serde_json::Value;
6
7#[async_trait]
9pub trait Tool: Send + Sync {
10 fn definition(&self) -> &ToolDefinition;
12
13 async fn execute(&self, args: Value) -> Result<ToolResult>;
15
16 fn name(&self) -> &str {
18 &self.definition().name
19 }
20
21 fn validate(&self, args: &Value) -> Result<()> {
23 self.definition().validate_args(args)
24 }
25}
26
27#[async_trait]
29pub trait Transport: Send + Sync {
30 async fn connect(&mut self) -> Result<()>;
32
33 async fn disconnect(&mut self) -> Result<()>;
35
36 fn is_connected(&self) -> bool;
38
39 async fn list_tools(&self) -> Result<Vec<ToolDefinition>>;
41
42 async fn call(&self, call: &ToolCall) -> Result<ToolResult>;
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Parameter;
    use serde_json::json;

    /// Minimal [`Tool`] whose `execute` returns a canned result and ignores its arguments.
    struct MockTool {
        definition: ToolDefinition,
        execute_result: ToolResult,
    }

    impl MockTool {
        /// Builds a mock named `name` whose `execute` yields a clone of `result`.
        fn new(name: &str, result: ToolResult) -> Self {
            Self {
                definition: ToolDefinition::new(name),
                execute_result: result,
            }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn definition(&self) -> &ToolDefinition {
            &self.definition
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(self.execute_result.clone())
        }
    }

    /// In-memory [`Transport`] that tracks a connection flag and serves a fixed tool list.
    struct MockTransport {
        // `connect`/`disconnect` take `&mut self`, so a plain flag suffices; no
        // shared ownership or interior mutability is required here.
        connected: bool,
        tools: Vec<ToolDefinition>,
    }

    impl MockTransport {
        /// Creates a disconnected transport that will list `tools`.
        fn new(tools: Vec<ToolDefinition>) -> Self {
            Self {
                connected: false,
                tools,
            }
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn connect(&mut self) -> Result<()> {
            self.connected = true;
            Ok(())
        }

        async fn disconnect(&mut self) -> Result<()> {
            self.connected = false;
            Ok(())
        }

        fn is_connected(&self) -> bool {
            self.connected
        }

        async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
            Ok(self.tools.clone())
        }

        async fn call(&self, call: &ToolCall) -> Result<ToolResult> {
            // Echo the requested tool name so tests can verify the call was routed.
            Ok(ToolResult::success(json!({
                "tool": call.tool,
                "called": true
            })))
        }
    }

    #[tokio::test]
    async fn tool_trait_execute() {
        let tool = MockTool::new("test", ToolResult::success(json!({"result": "ok"})));

        let result = tool.execute(json!({})).await.unwrap();
        assert!(result.is_success());
        assert_eq!(result.data.unwrap()["result"], "ok");
    }

    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn tool_trait_name() {
        let tool = MockTool::new("my_tool", ToolResult::success(json!(null)));
        assert_eq!(tool.name(), "my_tool");
    }

    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn tool_trait_validate() {
        let tool = MockTool {
            definition: ToolDefinition::builder("test")
                .parameter(Parameter::required_string("name"))
                .build(),
            execute_result: ToolResult::success(json!(null)),
        };

        // The required `name` parameter is supplied.
        assert!(tool.validate(&json!({"name": "value"})).is_ok());

        // The required `name` parameter is missing.
        assert!(tool.validate(&json!({})).is_err());
    }

    #[tokio::test]
    async fn transport_trait_connect_disconnect() {
        let mut transport = MockTransport::new(vec![]);

        assert!(!transport.is_connected());

        transport.connect().await.unwrap();
        assert!(transport.is_connected());

        transport.disconnect().await.unwrap();
        assert!(!transport.is_connected());
    }

    #[tokio::test]
    async fn transport_trait_list_tools() {
        // The vector is moved into the mock; it is not needed afterwards.
        let tools = vec![ToolDefinition::new("tool1"), ToolDefinition::new("tool2")];
        let transport = MockTransport::new(tools);

        let listed = transport.list_tools().await.unwrap();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].name, "tool1");
        assert_eq!(listed[1].name, "tool2");
    }

    #[tokio::test]
    async fn transport_trait_call() {
        let transport = MockTransport::new(vec![]);

        let call = ToolCall::new("test_tool");
        let result = transport.call(&call).await.unwrap();

        assert!(result.is_success());
        assert_eq!(result.data.unwrap()["tool"], "test_tool");
    }
}