vtcode_core/tools/
curl_tool.rs

1//! Sandboxed curl-like tool with strict safety guarantees
2
3use super::traits::Tool;
4use crate::config::constants::tools;
5use anyhow::{Context, Result, anyhow};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rand::{Rng, distributions::Alphanumeric};
9use reqwest::{Client, Method, Url};
10use serde::Deserialize;
11use serde_json::{Value, json};
12use std::fs;
13use std::net::IpAddr;
14use std::path::PathBuf;
15use std::time::Duration;
16use tracing::warn;
17
/// Request timeout (seconds) used when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 10;
/// Upper bound (seconds) applied to any caller-supplied `timeout_secs`.
const MAX_TIMEOUT_SECS: u64 = 30;
/// Default number of body bytes returned; callers may lower but never raise it.
const DEFAULT_MAX_BYTES: usize = 64 * 1024;
/// Directory name, inside the platform temp directory, that holds saved responses.
const TEMP_SUBDIR: &str = "vtcode-curl";
/// Notice attached to every successful result. Saved files live under the platform temp
/// directory (not necessarily `/tmp`), so the notice points at the per-result cleanup hint.
const SECURITY_NOTICE: &str = "Sandboxed HTTPS-only curl wrapper executed. Verify the target URL and delete any temporary file this tool saved (see cleanup_hint) when you finish reviewing the response.";
23
24#[derive(Debug, Deserialize)]
25struct CurlToolArgs {
26    url: String,
27    #[serde(default)]
28    method: Option<String>,
29    #[serde(default)]
30    max_bytes: Option<usize>,
31    #[serde(default)]
32    timeout_secs: Option<u64>,
33    #[serde(default)]
34    save_response: Option<bool>,
35}
36
37/// Secure HTTP fetch tool with aggressive validation
38#[derive(Clone)]
39pub struct CurlTool {
40    client: Client,
41    temp_root: PathBuf,
42}
43
44impl CurlTool {
45    pub fn new() -> Self {
46        let client = Client::builder()
47            .redirect(reqwest::redirect::Policy::none())
48            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
49            .user_agent("vtcode-sandboxed-curl/0.1")
50            .build()
51            .unwrap_or_else(|error| {
52                warn!(
53                    ?error,
54                    "Failed to build dedicated curl client; falling back to default"
55                );
56                Client::new()
57            });
58        let temp_root = std::env::temp_dir().join(TEMP_SUBDIR);
59        Self { client, temp_root }
60    }
61
62    async fn run(&self, raw_args: Value) -> Result<Value> {
63        let args: CurlToolArgs = serde_json::from_value(raw_args)
64            .context("Invalid arguments for curl tool. Provide an object with at least a 'url'.")?;
65
66        let method = self.normalize_method(args.method)?;
67        if method == Method::HEAD && args.save_response.unwrap_or(false) {
68            return Err(anyhow!(
69                "Cannot save a response body when performing a HEAD request. Set save_response=false or use GET."
70            ));
71        }
72
73        let url = Url::parse(&args.url).context("Invalid URL provided to curl tool")?;
74        self.validate_url(&url)?;
75
76        let timeout = args
77            .timeout_secs
78            .unwrap_or(DEFAULT_TIMEOUT_SECS)
79            .min(MAX_TIMEOUT_SECS);
80        let max_bytes = args
81            .max_bytes
82            .unwrap_or(DEFAULT_MAX_BYTES)
83            .min(DEFAULT_MAX_BYTES);
84
85        if max_bytes == 0 {
86            return Err(anyhow!("max_bytes must be greater than zero"));
87        }
88
89        let request = self
90            .client
91            .request(method.clone(), url.clone())
92            .timeout(Duration::from_secs(timeout))
93            .header(
94                reqwest::header::ACCEPT,
95                "text/plain, text/*, application/json, application/xml, application/yaml",
96            );
97
98        let response = request
99            .send()
100            .await
101            .with_context(|| format!("Failed to execute HTTPS request to {}", url))?;
102
103        let status = response.status();
104        if !status.is_success() {
105            return Err(anyhow!("Request returned non-success status: {}", status));
106        }
107
108        if let Some(length) = response.content_length() {
109            if length > max_bytes as u64 {
110                return Err(anyhow!(
111                    "Remote response is {} bytes which exceeds the policy limit of {} bytes",
112                    length,
113                    max_bytes
114                ));
115            }
116        }
117
118        let content_type = response
119            .headers()
120            .get(reqwest::header::CONTENT_TYPE)
121            .and_then(|value| value.to_str().ok())
122            .unwrap_or("")
123            .to_string();
124        self.validate_content_type(&content_type)?;
125
126        if method == Method::HEAD {
127            return Ok(json!({
128                "success": true,
129                "url": url.to_string(),
130                "status": status.as_u16(),
131                "content_type": content_type,
132                "content_length": response.content_length(),
133                "security_notice": SECURITY_NOTICE,
134            }));
135        }
136
137        let mut total_bytes: usize = 0;
138        let mut buffer: Vec<u8> = Vec::new();
139        let mut truncated = false;
140
141        let mut stream = response.bytes_stream();
142        while let Some(chunk) = stream.next().await {
143            let bytes =
144                chunk.with_context(|| format!("Failed to read response chunk from {}", url))?;
145            total_bytes = total_bytes.saturating_add(bytes.len());
146            if buffer.len() < max_bytes {
147                let remaining = max_bytes - buffer.len();
148                if bytes.len() > remaining {
149                    buffer.extend_from_slice(&bytes[..remaining]);
150                    truncated = true;
151                } else {
152                    buffer.extend_from_slice(&bytes);
153                }
154            } else {
155                truncated = true;
156            }
157            if buffer.len() >= max_bytes {
158                truncated = true;
159                break;
160            }
161        }
162
163        let body_text = String::from_utf8_lossy(&buffer).to_string();
164        let saved_path = if args.save_response.unwrap_or(false) && !buffer.is_empty() {
165            Some(self.write_temp_file(&buffer)?)
166        } else {
167            None
168        };
169
170        let saved_path_str = saved_path.as_ref().map(|path| path.display().to_string());
171        let cleanup_hint = saved_path
172            .as_ref()
173            .map(|path| format!("rm {}", path.display()));
174
175        Ok(json!({
176            "success": true,
177            "url": url.to_string(),
178            "status": status.as_u16(),
179            "content_type": content_type,
180            "bytes_read": total_bytes,
181            "body": body_text,
182            "truncated": truncated,
183            "saved_path": saved_path_str,
184            "cleanup_hint": cleanup_hint,
185            "security_notice": SECURITY_NOTICE,
186        }))
187    }
188
189    fn normalize_method(&self, method: Option<String>) -> Result<Method> {
190        let requested = method.unwrap_or_else(|| "GET".to_string());
191        let normalized = requested.trim().to_uppercase();
192        match normalized.as_str() {
193            "GET" => Ok(Method::GET),
194            "HEAD" => Ok(Method::HEAD),
195            other => Err(anyhow!(
196                "HTTP method '{}' is not permitted. Only GET or HEAD are allowed.",
197                other
198            )),
199        }
200    }
201
202    fn validate_url(&self, url: &Url) -> Result<()> {
203        if url.scheme() != "https" {
204            return Err(anyhow!("Only HTTPS URLs are allowed"));
205        }
206
207        if !url.username().is_empty() || url.password().is_some() {
208            return Err(anyhow!("Credentials in URLs are not supported"));
209        }
210
211        let host = url
212            .host_str()
213            .ok_or_else(|| anyhow!("URL must include a host"))?
214            .to_lowercase();
215
216        if host.parse::<IpAddr>().is_ok() {
217            return Err(anyhow!("IP address targets are blocked for security"));
218        }
219
220        let forbidden_hosts = ["localhost", "127.0.0.1", "0.0.0.0", "::1"];
221
222        if forbidden_hosts
223            .iter()
224            .any(|blocked| host == *blocked || host.ends_with(&format!(".{}", blocked)))
225        {
226            return Err(anyhow!("Access to local or loopback hosts is blocked"));
227        }
228
229        let forbidden_suffixes = [".localhost", ".local", ".internal", ".lan"];
230        if forbidden_suffixes
231            .iter()
232            .any(|suffix| host.ends_with(suffix))
233        {
234            return Err(anyhow!("Private network hosts are not permitted"));
235        }
236
237        if let Some(port) = url.port() {
238            if port != 443 {
239                return Err(anyhow!("Custom HTTPS ports are blocked by policy"));
240            }
241        }
242
243        Ok(())
244    }
245
246    fn validate_content_type(&self, content_type: &str) -> Result<()> {
247        if content_type.is_empty() {
248            return Ok(());
249        }
250        let lowered = content_type.to_lowercase();
251        let allowed = lowered.starts_with("text/")
252            || lowered.contains("json")
253            || lowered.contains("xml")
254            || lowered.contains("yaml")
255            || lowered.contains("toml")
256            || lowered.contains("javascript");
257        if allowed {
258            Ok(())
259        } else {
260            Err(anyhow!(
261                "Content type '{}' is not allowed. Only text or structured text responses are supported.",
262                content_type
263            ))
264        }
265    }
266
267    fn write_temp_file(&self, data: &[u8]) -> Result<PathBuf> {
268        if !self.temp_root.exists() {
269            fs::create_dir_all(&self.temp_root)
270                .context("Failed to create temporary directory for curl tool")?;
271        }
272
273        let mut rng = rand::thread_rng();
274        let suffix: String = (&mut rng)
275            .sample_iter(&Alphanumeric)
276            .take(10)
277            .map(char::from)
278            .collect();
279
280        let path = self
281            .temp_root
282            .join(format!("response-{}.txt", suffix.to_lowercase()));
283        fs::write(&path, data)
284            .with_context(|| format!("Failed to write temporary file at {}", path.display()))?;
285        Ok(path)
286    }
287}
288
289#[async_trait]
290impl Tool for CurlTool {
291    async fn execute(&self, args: Value) -> Result<Value> {
292        self.run(args).await
293    }
294
295    fn name(&self) -> &'static str {
296        tools::CURL
297    }
298
299    fn description(&self) -> &'static str {
300        "Fetches HTTPS text content with strict validation and security notices."
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the tool with `args` and asserts it fails with an error containing `needle`.
    ///
    /// Checking the message (not just `is_err()`) proves the request was rejected by local
    /// validation rather than by an incidental network failure.
    async fn assert_rejected(args: Value, needle: &str) {
        let tool = CurlTool::new();
        let error = tool
            .execute(args)
            .await
            .expect_err("request should have been rejected");
        let message = error.to_string();
        assert!(
            message.contains(needle),
            "expected error containing {:?}, got {:?}",
            needle,
            message
        );
    }

    #[tokio::test]
    async fn rejects_non_https_urls() {
        assert_rejected(
            json!({ "url": "http://example.com" }),
            "Only HTTPS URLs are allowed",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_local_targets() {
        assert_rejected(
            json!({ "url": "https://localhost/resource" }),
            "Access to local or loopback hosts is blocked",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_disallowed_methods() {
        assert_rejected(
            json!({
                "url": "https://example.com/resource",
                "method": "POST"
            }),
            "HTTP method 'POST' is not permitted",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_ip_address_targets() {
        assert_rejected(
            json!({ "url": "https://127.0.0.1/resource" }),
            "IP address targets are blocked",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_custom_ports() {
        assert_rejected(
            json!({ "url": "https://example.com:8443/resource" }),
            "Custom HTTPS ports are blocked",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_credentials_in_url() {
        assert_rejected(
            json!({ "url": "https://user:secret@example.com/" }),
            "Credentials in URLs are not supported",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_private_network_suffixes() {
        assert_rejected(
            json!({ "url": "https://printer.local/status" }),
            "Private network hosts are not permitted",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_saving_head_responses() {
        assert_rejected(
            json!({
                "url": "https://example.com/resource",
                "method": "HEAD",
                "save_response": true
            }),
            "Cannot save a response body",
        )
        .await;
    }

    #[tokio::test]
    async fn rejects_zero_max_bytes() {
        assert_rejected(
            json!({
                "url": "https://example.com/resource",
                "max_bytes": 0
            }),
            "max_bytes must be greater than zero",
        )
        .await;
    }

    #[tokio::test]
    async fn content_type_policy() {
        let tool = CurlTool::new();
        assert!(tool.validate_content_type("").is_ok());
        assert!(tool.validate_content_type("text/html; charset=utf-8").is_ok());
        assert!(tool.validate_content_type("application/json").is_ok());
        assert!(tool.validate_content_type("image/png").is_err());
    }
}