1use crate::builtin_tools::BuiltinTool;
6use crate::types::{Layer3Result, ToolCategory};
7use async_trait::async_trait;
8use futures::{SinkExt, StreamExt};
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
11
12pub struct HttpGetTool;
18
19#[async_trait]
20impl BuiltinTool for HttpGetTool {
21 fn name(&self) -> &str {
22 "http_get"
23 }
24
25 fn description(&self) -> &str {
26 "Make an HTTP GET request and return the response body."
27 }
28
29 fn parameters_schema(&self) -> serde_json::Value {
30 serde_json::json!({
31 "type": "object",
32 "properties": {
33 "url": {
34 "type": "string",
35 "description": "URL to request"
36 },
37 "timeout": {
38 "type": "integer",
39 "description": "Request timeout in seconds (default: 30)"
40 },
41 "headers": {
42 "type": "object",
43 "description": "Optional headers to include"
44 }
45 },
46 "required": ["url"]
47 })
48 }
49
50 fn category(&self) -> ToolCategory {
51 ToolCategory::Network
52 }
53
54 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
55 let url = args["url"]
56 .as_str()
57 .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
58
59 let timeout_secs = args["timeout"].as_u64().unwrap_or(30);
60
61 let mut request = reqwest::Client::new()
62 .get(url)
63 .timeout(std::time::Duration::from_secs(timeout_secs));
64
65 if let Some(headers) = args["headers"].as_object() {
66 for (key, value) in headers {
67 if let Some(v) = value.as_str() {
68 request = request.header(key, v);
69 }
70 }
71 }
72
73 let response = request
74 .send()
75 .await
76 .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
77
78 let status = response.status();
79 let body = response
80 .text()
81 .await
82 .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
83
84 Ok(format!("Status: {}\n\n{}", status, body))
85 }
86}
87
88pub struct HttpPostTool;
94
95#[async_trait]
96impl BuiltinTool for HttpPostTool {
97 fn name(&self) -> &str {
98 "http_post"
99 }
100
101 fn description(&self) -> &str {
102 "Make an HTTP POST request with optional body."
103 }
104
105 fn parameters_schema(&self) -> serde_json::Value {
106 serde_json::json!({
107 "type": "object",
108 "properties": {
109 "url": {
110 "type": "string",
111 "description": "URL to request"
112 },
113 "body": {
114 "type": "string",
115 "description": "Request body (JSON string or plain text)"
116 },
117 "content_type": {
118 "type": "string",
119 "description": "Content-Type header (default: application/json)"
120 },
121 "timeout": {
122 "type": "integer",
123 "description": "Request timeout in seconds (default: 30)"
124 },
125 "headers": {
126 "type": "object",
127 "description": "Optional headers to include"
128 }
129 },
130 "required": ["url"]
131 })
132 }
133
134 fn category(&self) -> ToolCategory {
135 ToolCategory::Network
136 }
137
138 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
139 let url = args["url"]
140 .as_str()
141 .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
142
143 let timeout_secs = args["timeout"].as_u64().unwrap_or(30);
144 let content_type = args["content_type"].as_str().unwrap_or("application/json");
145 let body = args["body"].as_str().unwrap_or("");
146
147 let mut request = reqwest::Client::new()
148 .post(url)
149 .timeout(std::time::Duration::from_secs(timeout_secs))
150 .header("Content-Type", content_type)
151 .body(body.to_string());
152
153 if let Some(headers) = args["headers"].as_object() {
154 for (key, value) in headers {
155 if let Some(v) = value.as_str() {
156 request = request.header(key, v);
157 }
158 }
159 }
160
161 let response = request
162 .send()
163 .await
164 .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
165
166 let status = response.status();
167 let response_body = response
168 .text()
169 .await
170 .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
171
172 Ok(format!("Status: {}\n\n{}", status, response_body))
173 }
174}
175
176pub struct DownloadFileTool;
182
183#[async_trait]
184impl BuiltinTool for DownloadFileTool {
185 fn name(&self) -> &str {
186 "download_file"
187 }
188
189 fn description(&self) -> &str {
190 "Download a file from URL and save to local path."
191 }
192
193 fn parameters_schema(&self) -> serde_json::Value {
194 serde_json::json!({
195 "type": "object",
196 "properties": {
197 "url": {
198 "type": "string",
199 "description": "URL to download from"
200 },
201 "path": {
202 "type": "string",
203 "description": "Local path to save the file"
204 },
205 "timeout": {
206 "type": "integer",
207 "description": "Download timeout in seconds (default: 60)"
208 }
209 },
210 "required": ["url", "path"]
211 })
212 }
213
214 fn category(&self) -> ToolCategory {
215 ToolCategory::Network
216 }
217
218 fn requires_confirmation(&self) -> bool {
219 true
220 }
221
222 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
223 let url = args["url"]
224 .as_str()
225 .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
226
227 let path = args["path"]
228 .as_str()
229 .ok_or_else(|| anyhow::anyhow!("Missing path parameter"))?;
230
231 let timeout_secs = args["timeout"].as_u64().unwrap_or(60);
232
233 let response = reqwest::Client::new()
234 .get(url)
235 .timeout(std::time::Duration::from_secs(timeout_secs))
236 .send()
237 .await
238 .map_err(|e| anyhow::anyhow!("Download failed: {}", e))?;
239
240 if !response.status().is_success() {
241 return Err(anyhow::anyhow!(
242 "Download failed with status: {}",
243 response.status()
244 ));
245 }
246
247 let bytes = response
248 .bytes()
249 .await
250 .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
251
252 let file_path = std::path::Path::new(path);
254 if let Some(parent) = file_path.parent() {
255 std::fs::create_dir_all(parent)
256 .map_err(|e| anyhow::anyhow!("Failed to create directory: {}", e))?;
257 }
258
259 std::fs::write(file_path, &bytes)
260 .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
261
262 Ok(format!("Downloaded {} bytes to {}", bytes.len(), path))
263 }
264}
265
266pub struct WebSocketConnectTool;
272
273#[async_trait]
274impl BuiltinTool for WebSocketConnectTool {
275 fn name(&self) -> &str {
276 "websocket_connect"
277 }
278
279 fn description(&self) -> &str {
280 "Connect to a WebSocket server and send/receive messages. Returns initial messages."
281 }
282
283 fn parameters_schema(&self) -> serde_json::Value {
284 serde_json::json!({
285 "type": "object",
286 "properties": {
287 "url": {
288 "type": "string",
289 "description": "WebSocket URL (ws:// or wss://)"
290 },
291 "message": {
292 "type": "string",
293 "description": "Initial message to send"
294 },
295 "receive_count": {
296 "type": "integer",
297 "description": "Number of messages to receive (default: 1)"
298 },
299 "timeout": {
300 "type": "integer",
301 "description": "Timeout in seconds (default: 10)"
302 }
303 },
304 "required": ["url"]
305 })
306 }
307
308 fn category(&self) -> ToolCategory {
309 ToolCategory::Network
310 }
311
312 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
313 let url = args["url"]
314 .as_str()
315 .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
316
317 let message = args["message"].as_str().unwrap_or("");
318 let receive_count = args["receive_count"].as_u64().unwrap_or(1).min(100) as usize;
319 let timeout_secs = args["timeout"].as_u64().unwrap_or(10);
320
321 let connect_future = connect_async(url);
323 let connection = tokio::time::timeout(Duration::from_secs(timeout_secs), connect_future)
324 .await
325 .map_err(|_| anyhow::anyhow!("WebSocket connection timeout after {}s", timeout_secs))?
326 .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
327
328 let (mut ws_stream, _response) = connection;
329
330 if !message.is_empty() {
332 ws_stream
333 .send(WsMessage::Text(message.into()))
334 .await
335 .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
336 }
337
338 let mut received_messages: Vec<String> = Vec::new();
340 let receive_timeout = Duration::from_secs(timeout_secs);
341
342 for _ in 0..receive_count {
343 match tokio::time::timeout(receive_timeout, ws_stream.next()).await {
344 Ok(Some(Ok(msg))) => {
345 match msg {
346 WsMessage::Text(text) => received_messages.push(text.to_string()),
347 WsMessage::Binary(data) => {
348 received_messages.push(format!("<binary: {} bytes>", data.len()));
349 }
350 WsMessage::Ping(ping) => {
351 let _ = ws_stream.send(WsMessage::Pong(ping)).await;
353 }
354 WsMessage::Close(_) => {
355 received_messages.push("<connection closed>".to_string());
356 break;
357 }
358 _ => {}
359 }
360 }
361 Ok(Some(Err(e))) => {
362 return Err(anyhow::anyhow!("WebSocket error: {}", e));
363 }
364 Ok(None) => {
365 received_messages.push("<stream ended>".to_string());
366 break;
367 }
368 Err(_) => {
369 if received_messages.is_empty() {
370 return Err(anyhow::anyhow!(
371 "No message received within {}s",
372 timeout_secs
373 ));
374 }
375 break;
376 }
377 }
378 }
379
380 let _ = ws_stream.close(None).await;
382
383 if received_messages.is_empty() {
384 Ok(format!(
385 "WebSocket connected to {} (no messages received)",
386 url
387 ))
388 } else {
389 Ok(format!(
390 "WebSocket connected to {}:\n{}",
391 url,
392 received_messages.join("\n")
393 ))
394 }
395 }
396}
397
398pub struct PingTool;
404
405#[async_trait]
406impl BuiltinTool for PingTool {
407 fn name(&self) -> &str {
408 "ping"
409 }
410
411 fn description(&self) -> &str {
412 "Check if a host is reachable via HTTP HEAD request."
413 }
414
415 fn parameters_schema(&self) -> serde_json::Value {
416 serde_json::json!({
417 "type": "object",
418 "properties": {
419 "url": {
420 "type": "string",
421 "description": "URL to ping"
422 },
423 "timeout": {
424 "type": "integer",
425 "description": "Timeout in seconds (default: 5)"
426 }
427 },
428 "required": ["url"]
429 })
430 }
431
432 fn category(&self) -> ToolCategory {
433 ToolCategory::Network
434 }
435
436 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
437 let url = args["url"]
438 .as_str()
439 .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
440
441 let timeout_secs = args["timeout"].as_u64().unwrap_or(5);
442
443 let start = std::time::Instant::now();
444
445 let response = reqwest::Client::new()
446 .head(url)
447 .timeout(std::time::Duration::from_secs(timeout_secs))
448 .send()
449 .await;
450
451 let elapsed_ms = start.elapsed().as_millis();
452
453 match response {
454 Ok(resp) => {
455 let status = resp.status();
456 Ok(format!(
457 "Ping successful: {} (status: {}, {}ms)",
458 url, status, elapsed_ms
459 ))
460 }
461 Err(e) => Ok(format!(
462 "Ping failed: {} (error: {}, {}ms)",
463 url, e, elapsed_ms
464 )),
465 }
466 }
467}
468
469pub struct DnsLookupTool;
475
476#[async_trait]
477impl BuiltinTool for DnsLookupTool {
478 fn name(&self) -> &str {
479 "dns_lookup"
480 }
481
482 fn description(&self) -> &str {
483 "Resolve DNS for a hostname. Returns IP addresses."
484 }
485
486 fn parameters_schema(&self) -> serde_json::Value {
487 serde_json::json!({
488 "type": "object",
489 "properties": {
490 "hostname": {
491 "type": "string",
492 "description": "Hostname to resolve"
493 }
494 },
495 "required": ["hostname"]
496 })
497 }
498
499 fn category(&self) -> ToolCategory {
500 ToolCategory::Network
501 }
502
503 async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
504 let hostname = args["hostname"]
505 .as_str()
506 .ok_or_else(|| anyhow::anyhow!("Missing hostname parameter"))?;
507
508 use tokio::net::lookup_host;
509
510 let addresses = lookup_host(hostname)
511 .await
512 .map_err(|e| anyhow::anyhow!("DNS lookup failed: {}", e))?;
513
514 let results: Vec<String> = addresses.map(|addr| addr.ip().to_string()).collect();
515
516 Ok(format!("Resolved {} to:\n{}", hostname, results.join("\n")))
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_http_get_category() {
        let tool = HttpGetTool;
        assert_eq!(tool.category(), ToolCategory::Network);
    }

    #[test]
    fn test_http_post_category() {
        let tool = HttpPostTool;
        assert_eq!(tool.category(), ToolCategory::Network);
    }

    #[test]
    fn test_download_file_requires_confirmation() {
        let tool = DownloadFileTool;
        assert!(tool.requires_confirmation());
    }

    // Targets a loopback port that is normally closed, so the test needs no
    // outside network access; the 1s timeout bounds the wait either way.
    // PingTool reports both success and failure inside `Ok`, so only the
    // report format is asserted.
    #[tokio::test]
    async fn test_ping_format() {
        let tool = PingTool;
        let url = "http://127.0.0.1:1";
        let output = tool
            .execute(json!({"url": url, "timeout": 1}))
            .await
            .expect("PingTool reports request failures inside Ok");
        assert!(output.starts_with("Ping "), "unexpected output: {}", output);
        assert!(output.contains(url), "unexpected output: {}", output);
    }

    // Every tool validates its required arguments before any network I/O.
    #[tokio::test]
    async fn test_missing_required_arguments_are_rejected() {
        let get = HttpGetTool;
        assert!(get.execute(json!({})).await.is_err());

        let post = HttpPostTool;
        assert!(post.execute(json!({})).await.is_err());

        let ws = WebSocketConnectTool;
        assert!(ws.execute(json!({})).await.is_err());

        let ping = PingTool;
        assert!(ping.execute(json!({})).await.is_err());

        let dns = DnsLookupTool;
        assert!(dns.execute(json!({})).await.is_err());

        let download = DownloadFileTool;
        assert!(download.execute(json!({})).await.is_err());
        assert!(download
            .execute(json!({"url": "http://127.0.0.1:1"}))
            .await
            .is_err());
    }
}