1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use std::time::Duration;
5
6use futures_util::StreamExt;
7use reqwest::{Client, Response};
8use serde::{Deserialize, Serialize};
9use url::Url;
10
11use crate::error::{Error, Result};
12
/// Redirect limit used by `HttpOptions::default()`.
const DEFAULT_MAX_REDIRECTS: usize = 10;
/// Request timeout, in seconds, used by `HttpOptions::default()`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Response-body cap, in bytes, used by `HttpOptions::default()` (10 MiB).
const DEFAULT_MAX_BODY_SIZE: usize = 10 * 1024 * 1024;

/// Snapshot of an HTTP response: where it came from, its status and headers,
/// and its (possibly truncated) body decoded as text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpInfo {
    /// URL reported by the response itself, rendered as a string.
    pub url: String,

    /// Numeric HTTP status code of the response.
    pub status_code: u16,

    /// Response headers as `(name, value)` pairs. Headers whose value cannot be
    /// represented as a string are left out.
    pub headers: Vec<(String, String)>,

    /// The `Content-Type` value up to the first `;`, trimmed; `None` when the
    /// header is absent or not representable as a string.
    pub content_type: Option<String>,

    /// Number of redirects followed.
    // NOTE(review): the constructor in this file always stores 0 here; confirm
    // whether another code path is expected to fill this in.
    pub redirect_count: u32,

    /// Response body decoded as UTF-8, with invalid sequences replaced. The
    /// body is cut off once the configured size limit is reached.
    pub body: String,
}
41
/// Settings that control how a page is fetched: TLS strictness, redirect
/// handling, limits, SSRF protection and request headers.
#[derive(Debug, Clone)]
pub struct HttpOptions {
    /// When `true`, invalid TLS certificates are accepted.
    pub allow_insecure: bool,

    /// Whether HTTP redirects are followed.
    pub follow_redirects: bool,

    /// Redirect limit applied when `follow_redirects` is `true`.
    pub max_redirects: usize,

    /// Timeout handed to the HTTP client.
    pub timeout: Duration,

    /// Maximum number of response-body bytes that are kept; anything beyond
    /// this is discarded.
    pub max_body_size: usize,

    /// When `true`, the requested URL is checked against internal host names
    /// and private/reserved IP addresses before the request is sent.
    pub block_private_ips: bool,

    /// Value sent as the `User-Agent` header.
    pub user_agent: String,

    /// Extra `(name, value)` headers installed as client default headers.
    pub headers: Vec<(String, String)>,
}
79
80impl Default for HttpOptions {
81 fn default() -> Self {
82 Self {
83 allow_insecure: false,
84 follow_redirects: true,
85 max_redirects: DEFAULT_MAX_REDIRECTS,
86 timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
87 max_body_size: DEFAULT_MAX_BODY_SIZE,
88 block_private_ips: true,
89 user_agent: format!(
90 "webpage-info/{} (https://crates.io/crates/webpage-info)",
91 env!("CARGO_PKG_VERSION")
92 ),
93 headers: Vec::new(),
94 }
95 }
96}
97
98impl HttpOptions {
99 pub fn new() -> Self {
101 Self::default()
102 }
103
104 pub fn allow_insecure(mut self, allow: bool) -> Self {
106 self.allow_insecure = allow;
107 self
108 }
109
110 pub fn follow_redirects(mut self, follow: bool) -> Self {
112 self.follow_redirects = follow;
113 self
114 }
115
116 pub fn max_redirects(mut self, max: usize) -> Self {
118 self.max_redirects = max;
119 self
120 }
121
122 pub fn timeout(mut self, timeout: Duration) -> Self {
124 self.timeout = timeout;
125 self
126 }
127
128 pub fn max_body_size(mut self, size: usize) -> Self {
132 self.max_body_size = size;
133 self
134 }
135
136 pub fn block_private_ips(mut self, block: bool) -> Self {
141 self.block_private_ips = block;
142 self
143 }
144
145 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
147 self.user_agent = user_agent.into();
148 self
149 }
150
151 pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
153 self.headers.push((name.into(), value.into()));
154 self
155 }
156
157 fn build_client(&self) -> Result<Client> {
159 let redirect_policy = if self.follow_redirects {
160 reqwest::redirect::Policy::limited(self.max_redirects)
161 } else {
162 reqwest::redirect::Policy::none()
163 };
164
165 let mut builder = Client::builder()
166 .danger_accept_invalid_certs(self.allow_insecure)
167 .redirect(redirect_policy)
168 .timeout(self.timeout)
169 .user_agent(&self.user_agent);
170
171 let mut headers = reqwest::header::HeaderMap::new();
173 for (name, value) in &self.headers {
174 if let (Ok(name), Ok(value)) = (
175 name.parse::<reqwest::header::HeaderName>(),
176 value.parse::<reqwest::header::HeaderValue>(),
177 ) {
178 headers.insert(name, value);
179 }
180 }
181 builder = builder.default_headers(headers);
182
183 Ok(builder.build()?)
184 }
185}
186
/// Returns `true` when `ip` is an IPv4 address that must not be contacted on
/// behalf of an untrusted caller.
///
/// Covers loopback, RFC 1918 private space, link-local, broadcast, the
/// unspecified address, documentation ranges, `0.0.0.0/8`, multicast and
/// reserved space (`224.0.0.0` and above), the shared address space
/// `100.64.0.0/10` (carrier-grade NAT; also home to some cloud metadata
/// endpoints), and the benchmarking range `198.18.0.0/15`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();

    ip.is_loopback()
        || ip.is_private()
        || ip.is_link_local()
        || ip.is_broadcast()
        || ip.is_unspecified()
        || ip.is_documentation()
        || first == 0
        || first >= 224
        // 100.64.0.0/10 (shared address space) and 198.18.0.0/15 (benchmarking).
        || matches!((first, second), (100, 64..=127) | (198, 18..=19))
}
198
199fn is_private_ipv6(ip: Ipv6Addr) -> bool {
201 ip.is_loopback() || ip.is_unspecified() || ip.is_multicast() || ip.to_ipv4_mapped().is_some_and(is_private_ipv4)
206 || (ip.segments()[0] & 0xfe00) == 0xfc00
208 || (ip.segments()[0] & 0xffc0) == 0xfe80
210}
211
212fn is_private_ip(ip: IpAddr) -> bool {
214 match ip {
215 IpAddr::V4(v4) => is_private_ipv4(v4),
216 IpAddr::V6(v6) => is_private_ipv6(v6),
217 }
218}
219
220async fn validate_url_for_ssrf(url: &str) -> Result<()> {
222 let parsed = Url::parse(url).map_err(|e| Error::InvalidUrl(e.to_string()))?;
223
224 match parsed.scheme() {
226 "http" | "https" => {}
227 scheme => {
228 return Err(Error::InvalidUrl(format!(
229 "unsupported scheme '{}', only http/https allowed",
230 scheme
231 )));
232 }
233 }
234
235 let host = parsed
236 .host_str()
237 .ok_or_else(|| Error::InvalidUrl("missing host".to_string()))?;
238
239 let host_lower = host.to_lowercase();
241 if host_lower == "localhost"
242 || host_lower.ends_with(".local")
243 || host_lower.ends_with(".internal")
244 || host_lower == "metadata.google.internal"
245 {
246 return Err(Error::SsrfBlocked(format!(
247 "blocked request to internal host: {}",
248 host
249 )));
250 }
251
252 let port = parsed.port().unwrap_or(match parsed.scheme() {
254 "https" => 443,
255 _ => 80,
256 });
257
258 let addr_str = format!("{}:{}", host, port);
259 if let Ok(addrs) = tokio::net::lookup_host(&addr_str).await {
260 for addr in addrs {
261 if is_private_ip(addr.ip()) {
262 return Err(Error::SsrfBlocked(format!(
263 "blocked request to private IP: {} (resolved from {})",
264 addr.ip(),
265 host
266 )));
267 }
268 }
269 }
270 Ok(())
273}
274
275pub async fn fetch(url: &str, options: &HttpOptions) -> Result<HttpInfo> {
277 if options.block_private_ips {
279 validate_url_for_ssrf(url).await?;
280 }
281
282 let client = options.build_client()?;
283 let response = client.get(url).send().await?;
284
285 response_to_info(response, options.max_body_size).await
286}
287
288async fn response_to_info(response: Response, max_body_size: usize) -> Result<HttpInfo> {
290 let url = response.url().to_string();
291 let status_code = response.status().as_u16();
292
293 let content_type = response
294 .headers()
295 .get(reqwest::header::CONTENT_TYPE)
296 .and_then(|v| v.to_str().ok())
297 .map(|s| {
298 s.split(';').next().unwrap_or(s).trim().to_string()
300 });
301
302 let headers: Vec<(String, String)> = response
303 .headers()
304 .iter()
305 .filter_map(|(name, value)| {
306 value
307 .to_str()
308 .ok()
309 .map(|v| (name.to_string(), v.to_string()))
310 })
311 .collect();
312
313 let content_length = response.content_length().unwrap_or(0) as usize;
315 let capacity = content_length.min(max_body_size).min(1024 * 1024); let mut bytes = Vec::with_capacity(capacity);
317 let mut stream = response.bytes_stream();
318
319 while let Some(chunk) = stream.next().await {
320 let chunk = chunk?;
321 let remaining = max_body_size.saturating_sub(bytes.len());
322 if remaining == 0 {
323 break;
324 }
325 let to_take = chunk.len().min(remaining);
326 bytes.extend_from_slice(&chunk[..to_take]);
327 if to_take < chunk.len() {
328 break; }
330 }
331
332 let body = String::from_utf8_lossy(&bytes).into_owned();
333
334 Ok(HttpInfo {
335 url,
336 status_code,
337 headers,
338 content_type,
339 redirect_count: 0,
340 body,
341 })
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the SSRF validator and yields the rendered error when the URL is
    /// rejected, or `None` when it is accepted.
    async fn rejection(url: &str) -> Option<String> {
        validate_url_for_ssrf(url)
            .await
            .err()
            .map(|err| err.to_string())
    }

    #[test]
    fn test_default_options() {
        let HttpOptions {
            allow_insecure,
            follow_redirects,
            max_redirects,
            timeout,
            max_body_size,
            block_private_ips,
            user_agent,
            ..
        } = HttpOptions::default();

        assert!(!allow_insecure);
        assert!(follow_redirects);
        assert_eq!(max_redirects, DEFAULT_MAX_REDIRECTS);
        assert_eq!(timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
        assert_eq!(max_body_size, DEFAULT_MAX_BODY_SIZE);
        assert!(block_private_ips);
        assert!(user_agent.contains("webpage-info"));
    }

    #[test]
    fn test_builder_pattern() {
        let configured = HttpOptions::new()
            .allow_insecure(true)
            .follow_redirects(false)
            .max_redirects(5)
            .timeout(Duration::from_secs(60))
            .max_body_size(1024)
            .block_private_ips(false)
            .user_agent("Custom Agent")
            .header("X-Custom", "Value");

        assert!(configured.allow_insecure);
        assert!(!configured.follow_redirects);
        assert_eq!(configured.max_redirects, 5);
        assert_eq!(configured.timeout, Duration::from_secs(60));
        assert_eq!(configured.max_body_size, 1024);
        assert!(!configured.block_private_ips);
        assert_eq!(configured.user_agent, "Custom Agent");
        assert_eq!(configured.headers.len(), 1);
    }

    #[tokio::test]
    async fn test_ssrf_blocks_localhost() {
        let message = rejection("http://localhost/")
            .await
            .expect("localhost must be rejected");
        assert!(message.contains("internal host"));
    }

    #[tokio::test]
    async fn test_ssrf_blocks_private_ip() {
        let message = rejection("http://192.168.1.1/")
            .await
            .expect("private IP must be rejected");
        assert!(message.contains("private IP"));
    }

    #[tokio::test]
    async fn test_ssrf_blocks_loopback() {
        assert!(rejection("http://127.0.0.1/").await.is_some());
    }

    #[tokio::test]
    async fn test_ssrf_blocks_metadata_endpoint() {
        assert!(rejection("http://169.254.169.254/").await.is_some());
    }

    #[tokio::test]
    async fn test_ssrf_blocks_internal_domain() {
        assert!(rejection("http://server.local/").await.is_some());
    }

    #[tokio::test]
    async fn test_ssrf_blocks_file_scheme() {
        let message = rejection("file:///etc/passwd")
            .await
            .expect("file scheme must be rejected");
        assert!(message.contains("unsupported scheme"));
    }

    #[tokio::test]
    async fn test_ssrf_allows_public_urls() {
        assert!(rejection("https://example.com/").await.is_none());
    }

    #[test]
    fn test_private_ipv4_detection() {
        let blocked = [
            Ipv4Addr::new(127, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(172, 16, 0, 1),
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(169, 254, 169, 254),
            Ipv4Addr::new(0, 0, 0, 0),
        ];
        for ip in blocked {
            assert!(is_private_ipv4(ip), "{} should be blocked", ip);
        }

        let allowed = [Ipv4Addr::new(8, 8, 8, 8), Ipv4Addr::new(93, 184, 216, 34)];
        for ip in allowed {
            assert!(!is_private_ipv4(ip), "{} should be allowed", ip);
        }
    }

    #[test]
    fn test_private_ipv6_detection() {
        let blocked: [Ipv6Addr; 4] = [
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::UNSPECIFIED,
            "fe80::1".parse().unwrap(),
            "fc00::1".parse().unwrap(),
        ];
        for ip in blocked {
            assert!(is_private_ipv6(ip), "{} should be blocked", ip);
        }

        let public: Ipv6Addr = "2607:f8b0:4004:800::200e".parse().unwrap();
        assert!(!is_private_ipv6(public));
    }
}