Skip to main content

sh_layer3/builtin_tools/
network_tools.rs

//! # Network Tools
//!
//! Network tool set: HTTP requests, file downloads, WebSocket, and more.
4
5use crate::builtin_tools::BuiltinTool;
6use crate::types::{Layer3Result, ToolCategory};
7use async_trait::async_trait;
8use futures::{SinkExt, StreamExt};
9use std::time::Duration;
10use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
11
12// ============================================================================
13// HTTP GET Tool
14// ============================================================================
15
16/// HTTP GET 请求工具
17pub struct HttpGetTool;
18
19#[async_trait]
20impl BuiltinTool for HttpGetTool {
21    fn name(&self) -> &str {
22        "http_get"
23    }
24
25    fn description(&self) -> &str {
26        "Make an HTTP GET request and return the response body."
27    }
28
29    fn parameters_schema(&self) -> serde_json::Value {
30        serde_json::json!({
31            "type": "object",
32            "properties": {
33                "url": {
34                    "type": "string",
35                    "description": "URL to request"
36                },
37                "timeout": {
38                    "type": "integer",
39                    "description": "Request timeout in seconds (default: 30)"
40                },
41                "headers": {
42                    "type": "object",
43                    "description": "Optional headers to include"
44                }
45            },
46            "required": ["url"]
47        })
48    }
49
50    fn category(&self) -> ToolCategory {
51        ToolCategory::Network
52    }
53
54    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
55        let url = args["url"]
56            .as_str()
57            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
58
59        let timeout_secs = args["timeout"].as_u64().unwrap_or(30);
60
61        let mut request = reqwest::Client::new()
62            .get(url)
63            .timeout(std::time::Duration::from_secs(timeout_secs));
64
65        if let Some(headers) = args["headers"].as_object() {
66            for (key, value) in headers {
67                if let Some(v) = value.as_str() {
68                    request = request.header(key, v);
69                }
70            }
71        }
72
73        let response = request
74            .send()
75            .await
76            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
77
78        let status = response.status();
79        let body = response
80            .text()
81            .await
82            .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
83
84        Ok(format!("Status: {}\n\n{}", status, body))
85    }
86}
87
88// ============================================================================
89// HTTP POST Tool
90// ============================================================================
91
92/// HTTP POST 请求工具
93pub struct HttpPostTool;
94
95#[async_trait]
96impl BuiltinTool for HttpPostTool {
97    fn name(&self) -> &str {
98        "http_post"
99    }
100
101    fn description(&self) -> &str {
102        "Make an HTTP POST request with optional body."
103    }
104
105    fn parameters_schema(&self) -> serde_json::Value {
106        serde_json::json!({
107            "type": "object",
108            "properties": {
109                "url": {
110                    "type": "string",
111                    "description": "URL to request"
112                },
113                "body": {
114                    "type": "string",
115                    "description": "Request body (JSON string or plain text)"
116                },
117                "content_type": {
118                    "type": "string",
119                    "description": "Content-Type header (default: application/json)"
120                },
121                "timeout": {
122                    "type": "integer",
123                    "description": "Request timeout in seconds (default: 30)"
124                },
125                "headers": {
126                    "type": "object",
127                    "description": "Optional headers to include"
128                }
129            },
130            "required": ["url"]
131        })
132    }
133
134    fn category(&self) -> ToolCategory {
135        ToolCategory::Network
136    }
137
138    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
139        let url = args["url"]
140            .as_str()
141            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
142
143        let timeout_secs = args["timeout"].as_u64().unwrap_or(30);
144        let content_type = args["content_type"].as_str().unwrap_or("application/json");
145        let body = args["body"].as_str().unwrap_or("");
146
147        let mut request = reqwest::Client::new()
148            .post(url)
149            .timeout(std::time::Duration::from_secs(timeout_secs))
150            .header("Content-Type", content_type)
151            .body(body.to_string());
152
153        if let Some(headers) = args["headers"].as_object() {
154            for (key, value) in headers {
155                if let Some(v) = value.as_str() {
156                    request = request.header(key, v);
157                }
158            }
159        }
160
161        let response = request
162            .send()
163            .await
164            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
165
166        let status = response.status();
167        let response_body = response
168            .text()
169            .await
170            .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
171
172        Ok(format!("Status: {}\n\n{}", status, response_body))
173    }
174}
175
176// ============================================================================
177// Download File Tool
178// ============================================================================
179
180/// 文件下载工具
181pub struct DownloadFileTool;
182
183#[async_trait]
184impl BuiltinTool for DownloadFileTool {
185    fn name(&self) -> &str {
186        "download_file"
187    }
188
189    fn description(&self) -> &str {
190        "Download a file from URL and save to local path."
191    }
192
193    fn parameters_schema(&self) -> serde_json::Value {
194        serde_json::json!({
195            "type": "object",
196            "properties": {
197                "url": {
198                    "type": "string",
199                    "description": "URL to download from"
200                },
201                "path": {
202                    "type": "string",
203                    "description": "Local path to save the file"
204                },
205                "timeout": {
206                    "type": "integer",
207                    "description": "Download timeout in seconds (default: 60)"
208                }
209            },
210            "required": ["url", "path"]
211        })
212    }
213
214    fn category(&self) -> ToolCategory {
215        ToolCategory::Network
216    }
217
218    fn requires_confirmation(&self) -> bool {
219        true
220    }
221
222    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
223        let url = args["url"]
224            .as_str()
225            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
226
227        let path = args["path"]
228            .as_str()
229            .ok_or_else(|| anyhow::anyhow!("Missing path parameter"))?;
230
231        let timeout_secs = args["timeout"].as_u64().unwrap_or(60);
232
233        let response = reqwest::Client::new()
234            .get(url)
235            .timeout(std::time::Duration::from_secs(timeout_secs))
236            .send()
237            .await
238            .map_err(|e| anyhow::anyhow!("Download failed: {}", e))?;
239
240        if !response.status().is_success() {
241            return Err(anyhow::anyhow!(
242                "Download failed with status: {}",
243                response.status()
244            ));
245        }
246
247        let bytes = response
248            .bytes()
249            .await
250            .map_err(|e| anyhow::anyhow!("Failed to read response: {}", e))?;
251
252        // Create parent directory if needed
253        let file_path = std::path::Path::new(path);
254        if let Some(parent) = file_path.parent() {
255            std::fs::create_dir_all(parent)
256                .map_err(|e| anyhow::anyhow!("Failed to create directory: {}", e))?;
257        }
258
259        std::fs::write(file_path, &bytes)
260            .map_err(|e| anyhow::anyhow!("Failed to write file: {}", e))?;
261
262        Ok(format!("Downloaded {} bytes to {}", bytes.len(), path))
263    }
264}
265
266// ============================================================================
267// WebSocket Connect Tool
268// ============================================================================
269
270/// WebSocket 连接工具
271pub struct WebSocketConnectTool;
272
273#[async_trait]
274impl BuiltinTool for WebSocketConnectTool {
275    fn name(&self) -> &str {
276        "websocket_connect"
277    }
278
279    fn description(&self) -> &str {
280        "Connect to a WebSocket server and send/receive messages. Returns initial messages."
281    }
282
283    fn parameters_schema(&self) -> serde_json::Value {
284        serde_json::json!({
285            "type": "object",
286            "properties": {
287                "url": {
288                    "type": "string",
289                    "description": "WebSocket URL (ws:// or wss://)"
290                },
291                "message": {
292                    "type": "string",
293                    "description": "Initial message to send"
294                },
295                "receive_count": {
296                    "type": "integer",
297                    "description": "Number of messages to receive (default: 1)"
298                },
299                "timeout": {
300                    "type": "integer",
301                    "description": "Timeout in seconds (default: 10)"
302                }
303            },
304            "required": ["url"]
305        })
306    }
307
308    fn category(&self) -> ToolCategory {
309        ToolCategory::Network
310    }
311
312    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
313        let url = args["url"]
314            .as_str()
315            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
316
317        let message = args["message"].as_str().unwrap_or("");
318        let receive_count = args["receive_count"].as_u64().unwrap_or(1).min(100) as usize;
319        let timeout_secs = args["timeout"].as_u64().unwrap_or(10);
320
321        // Connect to WebSocket with timeout
322        let connect_future = connect_async(url);
323        let connection = tokio::time::timeout(Duration::from_secs(timeout_secs), connect_future)
324            .await
325            .map_err(|_| anyhow::anyhow!("WebSocket connection timeout after {}s", timeout_secs))?
326            .map_err(|e| anyhow::anyhow!("WebSocket connection failed: {}", e))?;
327
328        let (mut ws_stream, _response) = connection;
329
330        // Send initial message if provided
331        if !message.is_empty() {
332            ws_stream
333                .send(WsMessage::Text(message.into()))
334                .await
335                .map_err(|e| anyhow::anyhow!("Failed to send message: {}", e))?;
336        }
337
338        // Receive messages
339        let mut received_messages: Vec<String> = Vec::new();
340        let receive_timeout = Duration::from_secs(timeout_secs);
341
342        for _ in 0..receive_count {
343            match tokio::time::timeout(receive_timeout, ws_stream.next()).await {
344                Ok(Some(Ok(msg))) => {
345                    match msg {
346                        WsMessage::Text(text) => received_messages.push(text.to_string()),
347                        WsMessage::Binary(data) => {
348                            received_messages.push(format!("<binary: {} bytes>", data.len()));
349                        }
350                        WsMessage::Ping(ping) => {
351                            // Respond to ping with pong
352                            let _ = ws_stream.send(WsMessage::Pong(ping)).await;
353                        }
354                        WsMessage::Close(_) => {
355                            received_messages.push("<connection closed>".to_string());
356                            break;
357                        }
358                        _ => {}
359                    }
360                }
361                Ok(Some(Err(e))) => {
362                    return Err(anyhow::anyhow!("WebSocket error: {}", e));
363                }
364                Ok(None) => {
365                    received_messages.push("<stream ended>".to_string());
366                    break;
367                }
368                Err(_) => {
369                    if received_messages.is_empty() {
370                        return Err(anyhow::anyhow!(
371                            "No message received within {}s",
372                            timeout_secs
373                        ));
374                    }
375                    break;
376                }
377            }
378        }
379
380        // Close connection
381        let _ = ws_stream.close(None).await;
382
383        if received_messages.is_empty() {
384            Ok(format!(
385                "WebSocket connected to {} (no messages received)",
386                url
387            ))
388        } else {
389            Ok(format!(
390                "WebSocket connected to {}:\n{}",
391                url,
392                received_messages.join("\n")
393            ))
394        }
395    }
396}
397
398// ============================================================================
399// Ping Tool
400// ============================================================================
401
402/// Ping 工具
403pub struct PingTool;
404
405#[async_trait]
406impl BuiltinTool for PingTool {
407    fn name(&self) -> &str {
408        "ping"
409    }
410
411    fn description(&self) -> &str {
412        "Check if a host is reachable via HTTP HEAD request."
413    }
414
415    fn parameters_schema(&self) -> serde_json::Value {
416        serde_json::json!({
417            "type": "object",
418            "properties": {
419                "url": {
420                    "type": "string",
421                    "description": "URL to ping"
422                },
423                "timeout": {
424                    "type": "integer",
425                    "description": "Timeout in seconds (default: 5)"
426                }
427            },
428            "required": ["url"]
429        })
430    }
431
432    fn category(&self) -> ToolCategory {
433        ToolCategory::Network
434    }
435
436    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
437        let url = args["url"]
438            .as_str()
439            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
440
441        let timeout_secs = args["timeout"].as_u64().unwrap_or(5);
442
443        let start = std::time::Instant::now();
444
445        let response = reqwest::Client::new()
446            .head(url)
447            .timeout(std::time::Duration::from_secs(timeout_secs))
448            .send()
449            .await;
450
451        let elapsed_ms = start.elapsed().as_millis();
452
453        match response {
454            Ok(resp) => {
455                let status = resp.status();
456                Ok(format!(
457                    "Ping successful: {} (status: {}, {}ms)",
458                    url, status, elapsed_ms
459                ))
460            }
461            Err(e) => Ok(format!(
462                "Ping failed: {} (error: {}, {}ms)",
463                url, e, elapsed_ms
464            )),
465        }
466    }
467}
468
469// ============================================================================
470// DNS Lookup Tool
471// ============================================================================
472
473/// DNS 解析工具
474pub struct DnsLookupTool;
475
476#[async_trait]
477impl BuiltinTool for DnsLookupTool {
478    fn name(&self) -> &str {
479        "dns_lookup"
480    }
481
482    fn description(&self) -> &str {
483        "Resolve DNS for a hostname. Returns IP addresses."
484    }
485
486    fn parameters_schema(&self) -> serde_json::Value {
487        serde_json::json!({
488            "type": "object",
489            "properties": {
490                "hostname": {
491                    "type": "string",
492                    "description": "Hostname to resolve"
493                }
494            },
495            "required": ["hostname"]
496        })
497    }
498
499    fn category(&self) -> ToolCategory {
500        ToolCategory::Network
501    }
502
503    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
504        let hostname = args["hostname"]
505            .as_str()
506            .ok_or_else(|| anyhow::anyhow!("Missing hostname parameter"))?;
507
508        use tokio::net::lookup_host;
509
510        let addresses = lookup_host(hostname)
511            .await
512            .map_err(|e| anyhow::anyhow!("DNS lookup failed: {}", e))?;
513
514        let results: Vec<String> = addresses.map(|addr| addr.ip().to_string()).collect();
515
516        Ok(format!("Resolved {} to:\n{}", hostname, results.join("\n")))
517    }
518}
519
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_http_get_category() {
        let tool = HttpGetTool;
        assert_eq!(tool.category(), ToolCategory::Network);
    }

    #[test]
    fn test_http_post_category() {
        let tool = HttpPostTool;
        assert_eq!(tool.category(), ToolCategory::Network);
    }

    #[test]
    fn test_download_file_requires_confirmation() {
        let tool = DownloadFileTool;
        assert!(tool.requires_confirmation());
    }

    /// Tool names are used as registry keys, so they must not collide.
    #[test]
    fn test_tool_names_are_unique() {
        let names: Vec<String> = vec![
            HttpGetTool.name().to_string(),
            HttpPostTool.name().to_string(),
            DownloadFileTool.name().to_string(),
            WebSocketConnectTool.name().to_string(),
            PingTool.name().to_string(),
            DnsLookupTool.name().to_string(),
        ];
        let unique: std::collections::HashSet<&String> = names.iter().collect();
        assert_eq!(unique.len(), names.len());
    }

    #[test]
    fn test_schemas_declare_required_fields() {
        assert_eq!(HttpGetTool.parameters_schema()["required"], json!(["url"]));
        assert_eq!(HttpPostTool.parameters_schema()["required"], json!(["url"]));
        assert_eq!(
            DownloadFileTool.parameters_schema()["required"],
            json!(["url", "path"])
        );
        assert_eq!(
            WebSocketConnectTool.parameters_schema()["required"],
            json!(["url"])
        );
        assert_eq!(PingTool.parameters_schema()["required"], json!(["url"]));
        assert_eq!(
            DnsLookupTool.parameters_schema()["required"],
            json!(["hostname"])
        );
    }

    /// Argument validation happens before any network access, so these run
    /// offline.
    #[tokio::test]
    async fn test_missing_required_parameters_are_rejected() {
        assert!(HttpGetTool.execute(json!({})).await.is_err());
        assert!(HttpPostTool.execute(json!({})).await.is_err());
        assert!(DownloadFileTool
            .execute(json!({"url": "http://127.0.0.1/"}))
            .await
            .is_err());
        assert!(WebSocketConnectTool.execute(json!({})).await.is_err());
        assert!(PingTool.execute(json!({})).await.is_err());
        assert!(DnsLookupTool.execute(json!({})).await.is_err());
    }

    #[tokio::test]
    async fn test_ping_format() {
        let tool = PingTool;
        let result = tool.execute(json!({"url": "https://example.com"})).await;
        // Will succeed or fail based on network, but should be formatted correctly
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.contains("Ping") || output.contains("example.com"));
    }
}