Skip to main content

sh_layer3/builtin_tools/
network.rs

//! # Network Tools
//!
//! A collection of tools for making network requests.
4
5use crate::builtin_tools::BuiltinTool;
6use crate::types::{Layer3Result, ToolCategory};
7use async_trait::async_trait;
8use std::time::Duration;
9
/// HTTP request tool.
///
/// A stateless unit struct; all behaviour lives in its `BuiltinTool`
/// implementation, which registers the tool under the name `"http_request"`.
pub struct HttpRequestTool;
12
13#[async_trait]
14impl BuiltinTool for HttpRequestTool {
15    fn name(&self) -> &str {
16        "http_request"
17    }
18
19    fn description(&self) -> &str {
20        "Make an HTTP request to a URL."
21    }
22
23    fn parameters_schema(&self) -> serde_json::Value {
24        serde_json::json!({
25            "type": "object",
26            "properties": {
27                "url": {
28                    "type": "string",
29                    "description": "The URL to request"
30                },
31                "method": {
32                    "type": "string",
33                    "enum": ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"],
34                    "description": "HTTP method (default: GET)"
35                },
36                "headers": {
37                    "type": "object",
38                    "description": "Optional: request headers as key-value pairs"
39                },
40                "body": {
41                    "type": "string",
42                    "description": "Optional: request body (for POST/PUT/PATCH)"
43                },
44                "timeout": {
45                    "type": "integer",
46                    "description": "Optional: timeout in seconds (default: 30)"
47                }
48            },
49            "required": ["url"]
50        })
51    }
52
53    fn category(&self) -> ToolCategory {
54        ToolCategory::Network
55    }
56
57    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
58        let url = args["url"]
59            .as_str()
60            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
61
62        let method = args["method"].as_str().unwrap_or("GET").to_uppercase();
63        let timeout_secs = args["timeout"].as_u64().unwrap_or(30);
64
65        // Build client
66        let client = reqwest::Client::builder()
67            .timeout(Duration::from_secs(timeout_secs))
68            .user_agent("Continuum/1.0")
69            .build()
70            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
71
72        // Build request
73        let mut request = match method.as_str() {
74            "GET" => client.get(url),
75            "POST" => client.post(url),
76            "PUT" => client.put(url),
77            "DELETE" => client.delete(url),
78            "HEAD" => client.head(url),
79            "PATCH" => client.patch(url),
80            _ => client.get(url),
81        };
82
83        // Add headers
84        if let Some(headers) = args["headers"].as_object() {
85            for (key, value) in headers {
86                if let Some(val_str) = value.as_str() {
87                    request = request.header(key, val_str);
88                }
89            }
90        }
91
92        // Add body
93        if let Some(body) = args["body"].as_str() {
94            request = request.body(body.to_string());
95        }
96
97        // Execute
98        let response = request
99            .send()
100            .await
101            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
102
103        let status = response.status();
104        let headers = response.headers().clone();
105
106        // Get body (or empty for HEAD)
107        let body = if method == "HEAD" {
108            String::new()
109        } else {
110            response
111                .text()
112                .await
113                .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?
114        };
115
116        // Format result
117        let mut result = format!(
118            "Status: {} {}\n",
119            status.as_u16(),
120            status.canonical_reason().unwrap_or("")
121        );
122        result.push_str("Headers:\n");
123        for (name, value) in headers.iter() {
124            result.push_str(&format!(
125                "  {}: {}\n",
126                name,
127                value.to_str().unwrap_or("<binary>")
128            ));
129        }
130        if !body.is_empty() {
131            result.push_str("\nBody:\n");
132            // Limit body display
133            if body.len() > 5000 {
134                result.push_str(&format!(
135                    "{}...\n(truncated, {} bytes total)",
136                    &body[..5000],
137                    body.len()
138                ));
139            } else {
140                result.push_str(&body);
141            }
142        }
143
144        Ok(result)
145    }
146}
147
/// Web fetch tool: simplified web page fetching.
///
/// A stateless unit struct; all behaviour lives in its `BuiltinTool`
/// implementation, which registers the tool under the name `"web_fetch"`.
pub struct WebFetchTool;
150
151#[async_trait]
152impl BuiltinTool for WebFetchTool {
153    fn name(&self) -> &str {
154        "web_fetch"
155    }
156
157    fn description(&self) -> &str {
158        "Fetch and extract text content from a webpage."
159    }
160
161    fn parameters_schema(&self) -> serde_json::Value {
162        serde_json::json!({
163            "type": "object",
164            "properties": {
165                "url": {
166                    "type": "string",
167                    "description": "The URL to fetch"
168                },
169                "selector": {
170                    "type": "string",
171                    "description": "Optional: CSS selector to extract specific content"
172                }
173            },
174            "required": ["url"]
175        })
176    }
177
178    fn category(&self) -> ToolCategory {
179        ToolCategory::Network
180    }
181
182    async fn execute(&self, args: serde_json::Value) -> Layer3Result<String> {
183        let url = args["url"]
184            .as_str()
185            .ok_or_else(|| anyhow::anyhow!("Missing url parameter"))?;
186
187        let client = reqwest::Client::builder()
188            .timeout(Duration::from_secs(30))
189            .user_agent("Continuum/1.0")
190            .build()
191            .map_err(|e| anyhow::anyhow!("Failed to create HTTP client: {}", e))?;
192
193        let response = client
194            .get(url)
195            .send()
196            .await
197            .map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
198
199        if !response.status().is_success() {
200            return Err(anyhow::anyhow!("HTTP error: {}", response.status()));
201        }
202
203        let body = response
204            .text()
205            .await
206            .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
207
208        // Simple HTML to text extraction (strip tags)
209        let text = extract_text_from_html(&body);
210
211        // Limit output
212        if text.len() > 10000 {
213            Ok(format!(
214                "{}...\n\n(truncated, {} chars total)",
215                &text[..10000],
216                text.len()
217            ))
218        } else {
219            Ok(text)
220        }
221    }
222}
223
/// Extracts the visible text from an HTML document using plain string
/// scanning (this is not a real HTML parser).
///
/// Processing steps:
/// 1. Remove every `<script ...>...</script>` block, then every
///    `<style ...>...</style>` block. Tag names are matched ASCII
///    case-insensitively. If an opening tag has no closing tag after it,
///    removal for that element stops and the remaining input is left as is.
/// 2. Drop everything between `<` and `>`.
/// 3. Collapse all runs of whitespace into single spaces and trim both ends.
///
/// Tags are removed without inserting whitespace, so text from directly
/// adjacent elements is concatenated (`<h1>A</h1><p>B</p>` yields `"AB"`).
fn extract_text_from_html(html: &str) -> String {
    let without_scripts = remove_element_blocks(html, "<script", "</script>");
    let without_styles = remove_element_blocks(&without_scripts, "<style", "</style>");

    // Strip all remaining tags: keep only characters outside `<...>`.
    let mut text = String::with_capacity(without_styles.len());
    let mut in_tag = false;
    for c in without_styles.chars() {
        match c {
            '<' => in_tag = true,
            '>' => in_tag = false,
            _ if !in_tag => text.push(c),
            _ => {}
        }
    }

    // Normalise whitespace.
    text.split_whitespace().collect::<Vec<_>>().join(" ")
}

/// Returns `html` with every span from `open` through the next `close`
/// removed. `open` and `close` must be lowercase ASCII; matching against the
/// input is ASCII case-insensitive.
///
/// The closing marker is searched for *after* the opening marker, so a stray
/// closing tag earlier in the document cannot prevent removal. The input is
/// processed in a single left-to-right pass.
fn remove_element_blocks(html: &str, open: &str, close: &str) -> String {
    // ASCII lowercasing keeps every byte offset and char boundary identical,
    // so indices found in `lower` are valid for slicing `html`.
    let lower = html.to_ascii_lowercase();
    let mut kept = String::with_capacity(html.len());
    let mut pos = 0;

    while let Some(offset) = lower[pos..].find(open) {
        let start = pos + offset;
        let end = match lower[start..].find(close) {
            Some(rel) => start + rel + close.len(),
            // Unclosed element: leave the rest of the input untouched.
            None => break,
        };
        kept.push_str(&html[pos..start]);
        pos = end;
    }

    kept.push_str(&html[pos..]);
    kept
}
272
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The HTTP request tool reports the network category.
    #[test]
    fn test_http_tool_category() {
        assert_eq!(HttpRequestTool.category(), ToolCategory::Network);
    }

    /// The web fetch tool reports the network category.
    #[test]
    fn test_web_fetch_tool_category() {
        assert_eq!(WebFetchTool.category(), ToolCategory::Network);
    }

    /// Text inside nested elements survives tag stripping.
    #[test]
    fn test_extract_text_from_html() {
        let extracted = extract_text_from_html(
            "<html><body><h1>Title</h1><p>Content here</p></body></html>",
        );
        assert!(extracted.contains("Title"));
        assert!(extracted.contains("Content"));
    }

    /// Calling the HTTP tool without a `url` argument fails with a clear message.
    #[tokio::test]
    async fn test_http_request_missing_url() {
        let outcome = HttpRequestTool.execute(json!({})).await;
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Missing url"));
    }
}