vtcode_core/tools/
curl_tool.rs

1//! Sandboxed curl-like tool with strict safety guarantees
2
3use super::traits::Tool;
4use crate::config::constants::tools;
5use anyhow::{Context, Result, anyhow};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rand::{Rng, distributions::Alphanumeric};
9use reqwest::{Client, Method, Url};
10use serde::Deserialize;
11use serde_json::{Value, json};
12use std::fs;
13use std::net::IpAddr;
14use std::path::PathBuf;
15use std::time::Duration;
16use tracing::warn;
17
/// Request timeout, in seconds, applied when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 10;

/// Upper bound, in seconds, that any caller-supplied timeout is clamped to.
const MAX_TIMEOUT_SECS: u64 = 30;

/// Default response-size budget in bytes; it also acts as the hard ceiling for `max_bytes`.
const DEFAULT_MAX_BYTES: usize = 64 * 1024;

/// Directory name, created beneath the system temp directory, that holds saved responses.
const TEMP_SUBDIR: &str = "vtcode-curl";

/// Advisory text attached to every successful tool result.
const SECURITY_NOTICE: &str = "Sandboxed HTTPS-only curl wrapper executed. Verify the target URL and delete any temporary files under /tmp when you finish reviewing the response.";
23
24#[derive(Debug, Deserialize)]
25struct CurlToolArgs {
26    url: String,
27    #[serde(default)]
28    method: Option<String>,
29    #[serde(default)]
30    max_bytes: Option<usize>,
31    #[serde(default)]
32    timeout_secs: Option<u64>,
33    #[serde(default)]
34    save_response: Option<bool>,
35}
36
37/// Secure HTTP fetch tool with aggressive validation
38#[derive(Clone)]
39pub struct CurlTool {
40    client: Client,
41    temp_root: PathBuf,
42}
43
44impl CurlTool {
45    pub fn new() -> Self {
46        let client = Client::builder()
47            .redirect(reqwest::redirect::Policy::none())
48            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
49            .user_agent("vtcode-sandboxed-curl/0.1")
50            .build()
51            .unwrap_or_else(|error| {
52                warn!(
53                    ?error,
54                    "Failed to build dedicated curl client; falling back to default"
55                );
56                Client::new()
57            });
58        let temp_root = std::env::temp_dir().join(TEMP_SUBDIR);
59        Self { client, temp_root }
60    }
61}
62
63impl Default for CurlTool {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl CurlTool {
70    fn write_temp_file(&self, data: &[u8]) -> Result<PathBuf> {
71        if !self.temp_root.exists() {
72            fs::create_dir_all(&self.temp_root)
73                .context("Failed to create temporary directory for curl tool")?;
74        }
75
76        let mut rng = rand::thread_rng();
77        let suffix: String = (&mut rng)
78            .sample_iter(&Alphanumeric)
79            .take(10)
80            .map(char::from)
81            .collect();
82
83        let path = self
84            .temp_root
85            .join(format!("response-{}.txt", suffix.to_lowercase()));
86        fs::write(&path, data)
87            .with_context(|| format!("Failed to write temporary file at {}", path.display()))?;
88        Ok(path)
89    }
90    async fn run(&self, raw_args: Value) -> Result<Value> {
91        let args: CurlToolArgs = serde_json::from_value(raw_args)
92            .context("Invalid arguments for curl tool. Provide an object with at least a 'url'.")?;
93
94        let method = self.normalize_method(args.method)?;
95        if method == Method::HEAD && args.save_response.unwrap_or(false) {
96            return Err(anyhow!(
97                "Cannot save a response body when performing a HEAD request. Set save_response=false or use GET."
98            ));
99        }
100
101        let url = Url::parse(&args.url).context("Invalid URL provided to curl tool")?;
102        self.validate_url(&url)?;
103
104        let timeout = args
105            .timeout_secs
106            .unwrap_or(DEFAULT_TIMEOUT_SECS)
107            .min(MAX_TIMEOUT_SECS);
108        let max_bytes = args
109            .max_bytes
110            .unwrap_or(DEFAULT_MAX_BYTES)
111            .min(DEFAULT_MAX_BYTES);
112
113        if max_bytes == 0 {
114            return Err(anyhow!("max_bytes must be greater than zero"));
115        }
116
117        let request = self
118            .client
119            .request(method.clone(), url.clone())
120            .timeout(Duration::from_secs(timeout))
121            .header(
122                reqwest::header::ACCEPT,
123                "text/plain, text/*, application/json, application/xml, application/yaml",
124            );
125
126        let response = request
127            .send()
128            .await
129            .with_context(|| format!("Failed to execute HTTPS request to {}", url))?;
130
131        let status = response.status();
132        if !status.is_success() {
133            return Err(anyhow!("Request returned non-success status: {}", status));
134        }
135
136        if let Some(length) = response.content_length()
137            && length > max_bytes as u64
138        {
139            return Err(anyhow!(
140                "Remote response is {} bytes which exceeds the policy limit of {} bytes",
141                length,
142                max_bytes
143            ));
144        }
145
146        let content_type = response
147            .headers()
148            .get(reqwest::header::CONTENT_TYPE)
149            .and_then(|value| value.to_str().ok())
150            .unwrap_or("")
151            .to_string();
152        self.validate_content_type(&content_type)?;
153
154        if method == Method::HEAD {
155            return Ok(json!({
156                "success": true,
157                "url": url.to_string(),
158                "status": status.as_u16(),
159                "content_type": content_type,
160                "content_length": response.content_length(),
161                "security_notice": SECURITY_NOTICE,
162            }));
163        }
164
165        let mut total_bytes: usize = 0;
166        let mut buffer: Vec<u8> = Vec::new();
167        let mut truncated = false;
168
169        let mut stream = response.bytes_stream();
170        while let Some(chunk) = stream.next().await {
171            let bytes =
172                chunk.with_context(|| format!("Failed to read response chunk from {}", url))?;
173            total_bytes = total_bytes.saturating_add(bytes.len());
174            if buffer.len() < max_bytes {
175                let remaining = max_bytes - buffer.len();
176                if bytes.len() > remaining {
177                    buffer.extend_from_slice(&bytes[..remaining]);
178                    truncated = true;
179                } else {
180                    buffer.extend_from_slice(&bytes);
181                }
182            } else {
183                truncated = true;
184            }
185            if buffer.len() >= max_bytes {
186                truncated = true;
187                break;
188            }
189        }
190
191        let body_text = String::from_utf8_lossy(&buffer).to_string();
192        let saved_path = if args.save_response.unwrap_or(false) && !buffer.is_empty() {
193            Some(self.write_temp_file(&buffer)?)
194        } else {
195            None
196        };
197
198        let saved_path_str = saved_path.as_ref().map(|path| path.display().to_string());
199        let cleanup_hint = saved_path
200            .as_ref()
201            .map(|path| format!("rm {}", path.display()));
202
203        Ok(json!({
204            "success": true,
205            "url": url.to_string(),
206            "status": status.as_u16(),
207            "content_type": content_type,
208            "bytes_read": total_bytes,
209            "body": body_text,
210            "truncated": truncated,
211            "saved_path": saved_path_str,
212            "cleanup_hint": cleanup_hint,
213            "security_notice": SECURITY_NOTICE,
214        }))
215    }
216
217    fn normalize_method(&self, method: Option<String>) -> Result<Method> {
218        let requested = method.unwrap_or_else(|| "GET".to_string());
219        let normalized = requested.trim().to_uppercase();
220        match normalized.as_str() {
221            "GET" => Ok(Method::GET),
222            "HEAD" => Ok(Method::HEAD),
223            other => Err(anyhow!(
224                "HTTP method '{}' is not permitted. Only GET or HEAD are allowed.",
225                other
226            )),
227        }
228    }
229
230    fn validate_url(&self, url: &Url) -> Result<()> {
231        if url.scheme() != "https" {
232            return Err(anyhow!("Only HTTPS URLs are allowed"));
233        }
234
235        if !url.username().is_empty() || url.password().is_some() {
236            return Err(anyhow!("Credentials in URLs are not supported"));
237        }
238
239        let host = url
240            .host_str()
241            .ok_or_else(|| anyhow!("URL must include a host"))?
242            .to_lowercase();
243
244        if host.parse::<IpAddr>().is_ok() {
245            return Err(anyhow!("IP address targets are blocked for security"));
246        }
247
248        let forbidden_hosts = ["localhost", "127.0.0.1", "0.0.0.0", "::1"];
249
250        if forbidden_hosts
251            .iter()
252            .any(|blocked| host == *blocked || host.ends_with(&format!(".{}", blocked)))
253        {
254            return Err(anyhow!("Access to local or loopback hosts is blocked"));
255        }
256
257        let forbidden_suffixes = [".localhost", ".local", ".internal", ".lan"];
258        if forbidden_suffixes
259            .iter()
260            .any(|suffix| host.ends_with(suffix))
261        {
262            return Err(anyhow!("Private network hosts are not permitted"));
263        }
264
265        if let Some(port) = url.port()
266            && port != 443
267        {
268            return Err(anyhow!("Custom HTTPS ports are blocked by policy"));
269        }
270
271        Ok(())
272    }
273
274    fn validate_content_type(&self, content_type: &str) -> Result<()> {
275        if content_type.is_empty() {
276            return Ok(());
277        }
278        let lowered = content_type.to_lowercase();
279        let allowed = lowered.starts_with("text/")
280            || lowered.contains("json")
281            || lowered.contains("xml")
282            || lowered.contains("yaml")
283            || lowered.contains("toml")
284            || lowered.contains("javascript");
285        if allowed {
286            Ok(())
287        } else {
288            Err(anyhow!(
289                "Content type '{}' is not allowed. Only text or structured text responses are supported.",
290                content_type
291            ))
292        }
293    }
294}
295
296#[async_trait]
297impl Tool for CurlTool {
298    async fn execute(&self, args: Value) -> Result<Value> {
299        self.run(args).await
300    }
301
302    fn name(&self) -> &'static str {
303        tools::CURL
304    }
305
306    fn description(&self) -> &'static str {
307        "Fetches HTTPS text content with strict validation and security notices."
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Value, json};

    /// Executes a fresh tool with `args` and reports whether the call was refused.
    async fn is_rejected(args: Value) -> bool {
        CurlTool::new().execute(args).await.is_err()
    }

    /// Plain HTTP must be refused before any request is attempted.
    #[tokio::test]
    async fn rejects_non_https_urls() {
        let args = json!({
            "url": "http://example.com"
        });
        assert!(is_rejected(args).await);
    }

    /// Loopback host names must be refused.
    #[tokio::test]
    async fn rejects_local_targets() {
        let args = json!({
            "url": "https://localhost/resource"
        });
        assert!(is_rejected(args).await);
    }

    /// Methods other than GET and HEAD must be refused.
    #[tokio::test]
    async fn rejects_disallowed_methods() {
        let args = json!({
            "url": "https://example.com/resource",
            "method": "POST"
        });
        assert!(is_rejected(args).await);
    }
}