vtcode_core/tools/
curl_tool.rs1use super::traits::Tool;
4use crate::config::constants::tools;
5use anyhow::{Context, Result, anyhow};
6use async_trait::async_trait;
7use futures::StreamExt;
8use rand::{Rng, distributions::Alphanumeric};
9use reqwest::{Client, Method, Url};
10use serde::Deserialize;
11use serde_json::{Value, json};
12use std::fs;
13use std::net::IpAddr;
14use std::path::PathBuf;
15use std::time::Duration;
16use tracing::warn;
17
/// Request timeout, in seconds, used when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 10;
/// Upper bound, in seconds, applied to any caller-supplied `timeout_secs`.
const MAX_TIMEOUT_SECS: u64 = 30;
/// Default response-body budget (64 KiB); also the hard cap for `max_bytes`.
const DEFAULT_MAX_BYTES: usize = 65_536;
/// Name of the directory, inside the system temp dir, that receives saved responses.
const TEMP_SUBDIR: &str = "vtcode-curl";
/// Notice attached to every successful result so the caller reviews the output safely.
const SECURITY_NOTICE: &str = "Sandboxed HTTPS-only curl wrapper executed. Verify the target URL and delete any temporary files under /tmp when you finish reviewing the response.";
23
24#[derive(Debug, Deserialize)]
25struct CurlToolArgs {
26 url: String,
27 #[serde(default)]
28 method: Option<String>,
29 #[serde(default)]
30 max_bytes: Option<usize>,
31 #[serde(default)]
32 timeout_secs: Option<u64>,
33 #[serde(default)]
34 save_response: Option<bool>,
35}
36
37#[derive(Clone)]
39pub struct CurlTool {
40 client: Client,
41 temp_root: PathBuf,
42}
43
44impl CurlTool {
45 pub fn new() -> Self {
46 let client = Client::builder()
47 .redirect(reqwest::redirect::Policy::none())
48 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
49 .user_agent("vtcode-sandboxed-curl/0.1")
50 .build()
51 .unwrap_or_else(|error| {
52 warn!(
53 ?error,
54 "Failed to build dedicated curl client; falling back to default"
55 );
56 Client::new()
57 });
58 let temp_root = std::env::temp_dir().join(TEMP_SUBDIR);
59 Self { client, temp_root }
60 }
61}
62
63impl Default for CurlTool {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl CurlTool {
70 fn write_temp_file(&self, data: &[u8]) -> Result<PathBuf> {
71 if !self.temp_root.exists() {
72 fs::create_dir_all(&self.temp_root)
73 .context("Failed to create temporary directory for curl tool")?;
74 }
75
76 let mut rng = rand::thread_rng();
77 let suffix: String = (&mut rng)
78 .sample_iter(&Alphanumeric)
79 .take(10)
80 .map(char::from)
81 .collect();
82
83 let path = self
84 .temp_root
85 .join(format!("response-{}.txt", suffix.to_lowercase()));
86 fs::write(&path, data)
87 .with_context(|| format!("Failed to write temporary file at {}", path.display()))?;
88 Ok(path)
89 }
90 async fn run(&self, raw_args: Value) -> Result<Value> {
91 let args: CurlToolArgs = serde_json::from_value(raw_args)
92 .context("Invalid arguments for curl tool. Provide an object with at least a 'url'.")?;
93
94 let method = self.normalize_method(args.method)?;
95 if method == Method::HEAD && args.save_response.unwrap_or(false) {
96 return Err(anyhow!(
97 "Cannot save a response body when performing a HEAD request. Set save_response=false or use GET."
98 ));
99 }
100
101 let url = Url::parse(&args.url).context("Invalid URL provided to curl tool")?;
102 self.validate_url(&url)?;
103
104 let timeout = args
105 .timeout_secs
106 .unwrap_or(DEFAULT_TIMEOUT_SECS)
107 .min(MAX_TIMEOUT_SECS);
108 let max_bytes = args
109 .max_bytes
110 .unwrap_or(DEFAULT_MAX_BYTES)
111 .min(DEFAULT_MAX_BYTES);
112
113 if max_bytes == 0 {
114 return Err(anyhow!("max_bytes must be greater than zero"));
115 }
116
117 let request = self
118 .client
119 .request(method.clone(), url.clone())
120 .timeout(Duration::from_secs(timeout))
121 .header(
122 reqwest::header::ACCEPT,
123 "text/plain, text/*, application/json, application/xml, application/yaml",
124 );
125
126 let response = request
127 .send()
128 .await
129 .with_context(|| format!("Failed to execute HTTPS request to {}", url))?;
130
131 let status = response.status();
132 if !status.is_success() {
133 return Err(anyhow!("Request returned non-success status: {}", status));
134 }
135
136 if let Some(length) = response.content_length()
137 && length > max_bytes as u64
138 {
139 return Err(anyhow!(
140 "Remote response is {} bytes which exceeds the policy limit of {} bytes",
141 length,
142 max_bytes
143 ));
144 }
145
146 let content_type = response
147 .headers()
148 .get(reqwest::header::CONTENT_TYPE)
149 .and_then(|value| value.to_str().ok())
150 .unwrap_or("")
151 .to_string();
152 self.validate_content_type(&content_type)?;
153
154 if method == Method::HEAD {
155 return Ok(json!({
156 "success": true,
157 "url": url.to_string(),
158 "status": status.as_u16(),
159 "content_type": content_type,
160 "content_length": response.content_length(),
161 "security_notice": SECURITY_NOTICE,
162 }));
163 }
164
165 let mut total_bytes: usize = 0;
166 let mut buffer: Vec<u8> = Vec::new();
167 let mut truncated = false;
168
169 let mut stream = response.bytes_stream();
170 while let Some(chunk) = stream.next().await {
171 let bytes =
172 chunk.with_context(|| format!("Failed to read response chunk from {}", url))?;
173 total_bytes = total_bytes.saturating_add(bytes.len());
174 if buffer.len() < max_bytes {
175 let remaining = max_bytes - buffer.len();
176 if bytes.len() > remaining {
177 buffer.extend_from_slice(&bytes[..remaining]);
178 truncated = true;
179 } else {
180 buffer.extend_from_slice(&bytes);
181 }
182 } else {
183 truncated = true;
184 }
185 if buffer.len() >= max_bytes {
186 truncated = true;
187 break;
188 }
189 }
190
191 let body_text = String::from_utf8_lossy(&buffer).to_string();
192 let saved_path = if args.save_response.unwrap_or(false) && !buffer.is_empty() {
193 Some(self.write_temp_file(&buffer)?)
194 } else {
195 None
196 };
197
198 let saved_path_str = saved_path.as_ref().map(|path| path.display().to_string());
199 let cleanup_hint = saved_path
200 .as_ref()
201 .map(|path| format!("rm {}", path.display()));
202
203 Ok(json!({
204 "success": true,
205 "url": url.to_string(),
206 "status": status.as_u16(),
207 "content_type": content_type,
208 "bytes_read": total_bytes,
209 "body": body_text,
210 "truncated": truncated,
211 "saved_path": saved_path_str,
212 "cleanup_hint": cleanup_hint,
213 "security_notice": SECURITY_NOTICE,
214 }))
215 }
216
217 fn normalize_method(&self, method: Option<String>) -> Result<Method> {
218 let requested = method.unwrap_or_else(|| "GET".to_string());
219 let normalized = requested.trim().to_uppercase();
220 match normalized.as_str() {
221 "GET" => Ok(Method::GET),
222 "HEAD" => Ok(Method::HEAD),
223 other => Err(anyhow!(
224 "HTTP method '{}' is not permitted. Only GET or HEAD are allowed.",
225 other
226 )),
227 }
228 }
229
230 fn validate_url(&self, url: &Url) -> Result<()> {
231 if url.scheme() != "https" {
232 return Err(anyhow!("Only HTTPS URLs are allowed"));
233 }
234
235 if !url.username().is_empty() || url.password().is_some() {
236 return Err(anyhow!("Credentials in URLs are not supported"));
237 }
238
239 let host = url
240 .host_str()
241 .ok_or_else(|| anyhow!("URL must include a host"))?
242 .to_lowercase();
243
244 if host.parse::<IpAddr>().is_ok() {
245 return Err(anyhow!("IP address targets are blocked for security"));
246 }
247
248 let forbidden_hosts = ["localhost", "127.0.0.1", "0.0.0.0", "::1"];
249
250 if forbidden_hosts
251 .iter()
252 .any(|blocked| host == *blocked || host.ends_with(&format!(".{}", blocked)))
253 {
254 return Err(anyhow!("Access to local or loopback hosts is blocked"));
255 }
256
257 let forbidden_suffixes = [".localhost", ".local", ".internal", ".lan"];
258 if forbidden_suffixes
259 .iter()
260 .any(|suffix| host.ends_with(suffix))
261 {
262 return Err(anyhow!("Private network hosts are not permitted"));
263 }
264
265 if let Some(port) = url.port()
266 && port != 443
267 {
268 return Err(anyhow!("Custom HTTPS ports are blocked by policy"));
269 }
270
271 Ok(())
272 }
273
274 fn validate_content_type(&self, content_type: &str) -> Result<()> {
275 if content_type.is_empty() {
276 return Ok(());
277 }
278 let lowered = content_type.to_lowercase();
279 let allowed = lowered.starts_with("text/")
280 || lowered.contains("json")
281 || lowered.contains("xml")
282 || lowered.contains("yaml")
283 || lowered.contains("toml")
284 || lowered.contains("javascript");
285 if allowed {
286 Ok(())
287 } else {
288 Err(anyhow!(
289 "Content type '{}' is not allowed. Only text or structured text responses are supported.",
290 content_type
291 ))
292 }
293 }
294}
295
296#[async_trait]
297impl Tool for CurlTool {
298 async fn execute(&self, args: Value) -> Result<Value> {
299 self.run(args).await
300 }
301
302 fn name(&self) -> &'static str {
303 tools::CURL
304 }
305
306 fn description(&self) -> &'static str {
307 "Fetches HTTPS text content with strict validation and security notices."
308 }
309}
310
#[cfg(test)]
mod tests {
    //! Hermetic policy tests: every case is rejected or accepted before any network I/O.

    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn rejects_non_https_urls() {
        let tool = CurlTool::new();
        let result = tool
            .execute(json!({
                "url": "http://example.com"
            }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rejects_local_targets() {
        let tool = CurlTool::new();
        let result = tool
            .execute(json!({
                "url": "https://localhost/resource"
            }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rejects_disallowed_methods() {
        let tool = CurlTool::new();
        let result = tool
            .execute(json!({
                "url": "https://example.com/resource",
                "method": "POST"
            }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rejects_ip_literal_targets() {
        let tool = CurlTool::new();
        let url = Url::parse("https://127.0.0.1/resource").unwrap();
        let error = tool.validate_url(&url).unwrap_err();
        assert!(error.to_string().contains("IP address targets are blocked"));
    }

    #[tokio::test]
    async fn rejects_embedded_credentials() {
        let tool = CurlTool::new();
        let url = Url::parse("https://user:secret@example.com/").unwrap();
        let error = tool.validate_url(&url).unwrap_err();
        assert!(error.to_string().contains("Credentials"));
    }

    #[tokio::test]
    async fn enforces_default_https_port() {
        let tool = CurlTool::new();
        let custom_port = Url::parse("https://example.com:8443/").unwrap();
        assert!(tool.validate_url(&custom_port).is_err());
        let explicit_default = Url::parse("https://example.com:443/").unwrap();
        assert!(tool.validate_url(&explicit_default).is_ok());
    }

    #[tokio::test]
    async fn rejects_private_network_suffixes() {
        let tool = CurlTool::new();
        for host in ["printer.local", "db.internal", "nas.lan", "app.localhost"] {
            let url = Url::parse(&format!("https://{host}/")).unwrap();
            assert!(tool.validate_url(&url).is_err(), "{host} should be blocked");
        }
    }

    #[tokio::test]
    async fn accepts_public_https_host() {
        let tool = CurlTool::new();
        let url = Url::parse("https://example.com/docs/index.html").unwrap();
        assert!(tool.validate_url(&url).is_ok());
    }

    #[tokio::test]
    async fn rejects_head_with_save_response() {
        let tool = CurlTool::new();
        let result = tool
            .execute(json!({
                "url": "https://example.com/",
                "method": "HEAD",
                "save_response": true
            }))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn rejects_zero_max_bytes() {
        let tool = CurlTool::new();
        let error = tool
            .execute(json!({
                "url": "https://example.com/",
                "max_bytes": 0
            }))
            .await
            .unwrap_err();
        assert!(error.to_string().contains("max_bytes must be greater than zero"));
    }

    #[tokio::test]
    async fn normalizes_allowed_methods() {
        let tool = CurlTool::new();
        assert_eq!(tool.normalize_method(None).unwrap(), Method::GET);
        assert_eq!(
            tool.normalize_method(Some(" head ".to_string())).unwrap(),
            Method::HEAD
        );
        assert!(tool.normalize_method(Some("DELETE".to_string())).is_err());
    }

    #[tokio::test]
    async fn applies_content_type_allowlist() {
        let tool = CurlTool::new();
        for allowed in [
            "",
            "text/plain",
            "text/html; charset=utf-8",
            "application/json",
            "application/vnd.api+json",
            "application/xml",
            "application/javascript",
        ] {
            assert!(
                tool.validate_content_type(allowed).is_ok(),
                "{allowed:?} should be allowed"
            );
        }
        for denied in ["image/png", "application/octet-stream"] {
            assert!(
                tool.validate_content_type(denied).is_err(),
                "{denied:?} should be denied"
            );
        }
    }
}