webpage_info/
http.rs

1//! HTTP client for fetching web pages
2
3use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
4use std::time::Duration;
5
6use futures_util::StreamExt;
7use reqwest::{Client, Response};
8use serde::{Deserialize, Serialize};
9use url::Url;
10
11use crate::error::{Error, Result};
12
/// Redirect hop limit used by `HttpOptions::default()`.
const DEFAULT_MAX_REDIRECTS: usize = 10;

/// Request timeout, in seconds, used by `HttpOptions::default()`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Response body cap, in bytes, used by `HttpOptions::default()` (10 MiB).
const DEFAULT_MAX_BODY_SIZE: usize = 10 << 20;
16
17/// HTTP response information.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct HttpInfo {
20    /// The final URL after following redirects
21    pub url: String,
22
23    /// HTTP status code
24    pub status_code: u16,
25
26    /// Response headers
27    pub headers: Vec<(String, String)>,
28
29    /// Content-Type header value
30    pub content_type: Option<String>,
31
32    /// Number of redirects followed.
33    ///
34    /// Note: This is currently always 0 as reqwest doesn't expose redirect count directly.
35    /// The field is retained for API compatibility and potential future implementation.
36    pub redirect_count: u32,
37
38    /// Response body as string
39    pub body: String,
40}
41
/// Settings that control how a page is fetched.
///
/// Start from [`HttpOptions::default`] and adjust individual fields, either
/// directly or through the builder-style methods on this type.
#[derive(Debug, Clone)]
pub struct HttpOptions {
    /// Accept invalid TLS certificates (for example self-signed ones).
    ///
    /// **Security Warning:** turning this on makes man-in-the-middle attacks
    /// possible. Use it only for testing or for known self-signed services.
    pub allow_insecure: bool,

    /// Whether HTTP redirects are followed at all.
    pub follow_redirects: bool,

    /// Upper bound on the number of redirects followed when
    /// `follow_redirects` is enabled.
    pub max_redirects: usize,

    /// Timeout applied to the request.
    pub timeout: Duration,

    /// Largest response body, in bytes, that will be kept.
    ///
    /// Longer responses are truncated to this size to prevent memory
    /// exhaustion. Default: 10 MB.
    pub max_body_size: usize,

    /// Refuse requests to private/internal addresses (SSRF protection).
    ///
    /// When enabled, requests to localhost, private networks
    /// (10.x, 172.16-31.x, 192.168.x), link-local addresses, and cloud
    /// metadata endpoints (169.254.x) are blocked. Default: true.
    pub block_private_ips: bool,

    /// Value sent in the `User-Agent` header.
    pub user_agent: String,

    /// Extra `(name, value)` headers sent with the request.
    pub headers: Vec<(String, String)>,
}
79
80impl Default for HttpOptions {
81    fn default() -> Self {
82        Self {
83            allow_insecure: false,
84            follow_redirects: true,
85            max_redirects: DEFAULT_MAX_REDIRECTS,
86            timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
87            max_body_size: DEFAULT_MAX_BODY_SIZE,
88            block_private_ips: true,
89            user_agent: format!(
90                "webpage-info/{} (https://crates.io/crates/webpage-info)",
91                env!("CARGO_PKG_VERSION")
92            ),
93            headers: Vec::new(),
94        }
95    }
96}
97
98impl HttpOptions {
99    /// Create a new HttpOptions with default settings.
100    pub fn new() -> Self {
101        Self::default()
102    }
103
104    /// Set whether to allow insecure HTTPS connections.
105    pub fn allow_insecure(mut self, allow: bool) -> Self {
106        self.allow_insecure = allow;
107        self
108    }
109
110    /// Set whether to follow redirects.
111    pub fn follow_redirects(mut self, follow: bool) -> Self {
112        self.follow_redirects = follow;
113        self
114    }
115
116    /// Set the maximum number of redirects to follow.
117    pub fn max_redirects(mut self, max: usize) -> Self {
118        self.max_redirects = max;
119        self
120    }
121
122    /// Set the request timeout.
123    pub fn timeout(mut self, timeout: Duration) -> Self {
124        self.timeout = timeout;
125        self
126    }
127
128    /// Set the maximum response body size in bytes.
129    ///
130    /// Responses larger than this will be truncated.
131    pub fn max_body_size(mut self, size: usize) -> Self {
132        self.max_body_size = size;
133        self
134    }
135
136    /// Set whether to block requests to private/internal IP addresses.
137    ///
138    /// **Security Note:** Disabling this exposes your application to SSRF attacks
139    /// if URLs come from untrusted sources.
140    pub fn block_private_ips(mut self, block: bool) -> Self {
141        self.block_private_ips = block;
142        self
143    }
144
145    /// Set the User-Agent header.
146    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
147        self.user_agent = user_agent.into();
148        self
149    }
150
151    /// Add a custom header.
152    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
153        self.headers.push((name.into(), value.into()));
154        self
155    }
156
157    /// Build a reqwest Client from these options.
158    fn build_client(&self) -> Result<Client> {
159        let redirect_policy = if self.follow_redirects {
160            reqwest::redirect::Policy::limited(self.max_redirects)
161        } else {
162            reqwest::redirect::Policy::none()
163        };
164
165        let mut builder = Client::builder()
166            .danger_accept_invalid_certs(self.allow_insecure)
167            .redirect(redirect_policy)
168            .timeout(self.timeout)
169            .user_agent(&self.user_agent);
170
171        // Add default headers
172        let mut headers = reqwest::header::HeaderMap::new();
173        for (name, value) in &self.headers {
174            if let (Ok(name), Ok(value)) = (
175                name.parse::<reqwest::header::HeaderName>(),
176                value.parse::<reqwest::header::HeaderValue>(),
177            ) {
178                headers.insert(name, value);
179            }
180        }
181        builder = builder.default_headers(headers);
182
183        Ok(builder.build()?)
184    }
185}
186
/// Check if an IPv4 address is private/internal.
///
/// Returns `true` for every address that must not be reachable through a
/// user-supplied URL: loopback, RFC 1918 private ranges, link-local,
/// carrier-grade NAT shared space, benchmarking, documentation, the
/// "this network" block, and everything from multicast upwards.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    let [first, second, ..] = ip.octets();

    ip.is_loopback()                           // 127.0.0.0/8
        || ip.is_private()                     // 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
        || ip.is_link_local()                  // 169.254.0.0/16 (includes cloud metadata)
        || ip.is_broadcast()                   // 255.255.255.255
        || ip.is_unspecified()                 // 0.0.0.0
        || ip.is_documentation()               // 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24
        || first == 0                          // 0.0.0.0/8
        // Shared address space 100.64.0.0/10 (RFC 6598); some cloud providers
        // serve instance metadata from this range (e.g. 100.100.100.200).
        || (first == 100 && (second & 0xc0) == 64)
        // Benchmarking 198.18.0.0/15 (RFC 2544); never a public destination.
        || (first == 198 && (second & 0xfe) == 18)
        || first >= 224 // Multicast and reserved (224.0.0.0+)
}
198
199/// Check if an IPv6 address is private/internal.
200fn is_private_ipv6(ip: Ipv6Addr) -> bool {
201    ip.is_loopback()                           // ::1
202        || ip.is_unspecified()                 // ::
203        || ip.is_multicast()                   // ff00::/8
204        // IPv4-mapped addresses (::ffff:0:0/96)
205        || ip.to_ipv4_mapped().is_some_and(is_private_ipv4)
206        // Unique local (fc00::/7)
207        || (ip.segments()[0] & 0xfe00) == 0xfc00
208        // Link-local (fe80::/10)
209        || (ip.segments()[0] & 0xffc0) == 0xfe80
210}
211
212/// Check if an IP address is private/internal.
213fn is_private_ip(ip: IpAddr) -> bool {
214    match ip {
215        IpAddr::V4(v4) => is_private_ipv4(v4),
216        IpAddr::V6(v6) => is_private_ipv6(v6),
217    }
218}
219
220/// Validate URL for SSRF protection (async DNS resolution).
221async fn validate_url_for_ssrf(url: &str) -> Result<()> {
222    let parsed = Url::parse(url).map_err(|e| Error::InvalidUrl(e.to_string()))?;
223
224    // Only allow http and https schemes
225    match parsed.scheme() {
226        "http" | "https" => {}
227        scheme => {
228            return Err(Error::InvalidUrl(format!(
229                "unsupported scheme '{}', only http/https allowed",
230                scheme
231            )));
232        }
233    }
234
235    let host = parsed
236        .host_str()
237        .ok_or_else(|| Error::InvalidUrl("missing host".to_string()))?;
238
239    // Block obviously dangerous hostnames
240    let host_lower = host.to_lowercase();
241    if host_lower == "localhost"
242        || host_lower.ends_with(".local")
243        || host_lower.ends_with(".internal")
244        || host_lower == "metadata.google.internal"
245    {
246        return Err(Error::SsrfBlocked(format!(
247            "blocked request to internal host: {}",
248            host
249        )));
250    }
251
252    // Resolve hostname and check all IP addresses (async to avoid blocking runtime)
253    let port = parsed.port().unwrap_or(match parsed.scheme() {
254        "https" => 443,
255        _ => 80,
256    });
257
258    let addr_str = format!("{}:{}", host, port);
259    if let Ok(addrs) = tokio::net::lookup_host(&addr_str).await {
260        for addr in addrs {
261            if is_private_ip(addr.ip()) {
262                return Err(Error::SsrfBlocked(format!(
263                    "blocked request to private IP: {} (resolved from {})",
264                    addr.ip(),
265                    host
266                )));
267            }
268        }
269    }
270    // If DNS resolution fails, let reqwest handle it (might be a valid external host)
271
272    Ok(())
273}
274
275/// Fetch a URL and return HTTP information.
276pub async fn fetch(url: &str, options: &HttpOptions) -> Result<HttpInfo> {
277    // SSRF protection: validate URL before making request
278    if options.block_private_ips {
279        validate_url_for_ssrf(url).await?;
280    }
281
282    let client = options.build_client()?;
283    let response = client.get(url).send().await?;
284
285    response_to_info(response, options.max_body_size).await
286}
287
288/// Convert a reqwest Response to HttpInfo with streaming body size limit.
289async fn response_to_info(response: Response, max_body_size: usize) -> Result<HttpInfo> {
290    let url = response.url().to_string();
291    let status_code = response.status().as_u16();
292
293    let content_type = response
294        .headers()
295        .get(reqwest::header::CONTENT_TYPE)
296        .and_then(|v| v.to_str().ok())
297        .map(|s| {
298            // Extract just the mime type, not charset
299            s.split(';').next().unwrap_or(s).trim().to_string()
300        });
301
302    let headers: Vec<(String, String)> = response
303        .headers()
304        .iter()
305        .filter_map(|(name, value)| {
306            value
307                .to_str()
308                .ok()
309                .map(|v| (name.to_string(), v.to_string()))
310        })
311        .collect();
312
313    // Stream body with size limit - stops downloading when limit reached
314    let content_length = response.content_length().unwrap_or(0) as usize;
315    let capacity = content_length.min(max_body_size).min(1024 * 1024); // Cap initial alloc at 1MB
316    let mut bytes = Vec::with_capacity(capacity);
317    let mut stream = response.bytes_stream();
318
319    while let Some(chunk) = stream.next().await {
320        let chunk = chunk?;
321        let remaining = max_body_size.saturating_sub(bytes.len());
322        if remaining == 0 {
323            break;
324        }
325        let to_take = chunk.len().min(remaining);
326        bytes.extend_from_slice(&chunk[..to_take]);
327        if to_take < chunk.len() {
328            break; // Hit the limit
329        }
330    }
331
332    let body = String::from_utf8_lossy(&bytes).into_owned();
333
334    Ok(HttpInfo {
335        url,
336        status_code,
337        headers,
338        content_type,
339        redirect_count: 0,
340        body,
341    })
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Run SSRF validation on `url`, expecting a rejection, and return the
    /// rejection rendered as text.
    async fn rejection_message(url: &str) -> String {
        validate_url_for_ssrf(url)
            .await
            .expect_err("URL should have been rejected")
            .to_string()
    }

    #[test]
    fn test_default_options() {
        let defaults = HttpOptions::default();

        assert!(!defaults.allow_insecure);
        assert!(defaults.follow_redirects);
        assert_eq!(defaults.max_redirects, DEFAULT_MAX_REDIRECTS);
        assert_eq!(defaults.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
        assert_eq!(defaults.max_body_size, DEFAULT_MAX_BODY_SIZE);
        assert!(defaults.block_private_ips);
        assert!(defaults.user_agent.contains("webpage-info"));
    }

    #[test]
    fn test_builder_pattern() {
        let built = HttpOptions::new()
            .allow_insecure(true)
            .follow_redirects(false)
            .max_redirects(5)
            .timeout(Duration::from_secs(60))
            .max_body_size(1024)
            .block_private_ips(false)
            .user_agent("Custom Agent")
            .header("X-Custom", "Value");

        assert!(built.allow_insecure);
        assert!(!built.follow_redirects);
        assert_eq!(built.max_redirects, 5);
        assert_eq!(built.timeout, Duration::from_secs(60));
        assert_eq!(built.max_body_size, 1024);
        assert!(!built.block_private_ips);
        assert_eq!(built.user_agent, "Custom Agent");
        assert_eq!(built.headers.len(), 1);
    }

    #[tokio::test]
    async fn test_ssrf_blocks_localhost() {
        let message = rejection_message("http://localhost/").await;
        assert!(message.contains("internal host"));
    }

    #[tokio::test]
    async fn test_ssrf_blocks_private_ip() {
        let message = rejection_message("http://192.168.1.1/").await;
        assert!(message.contains("private IP"));
    }

    #[tokio::test]
    async fn test_ssrf_blocks_loopback() {
        assert!(validate_url_for_ssrf("http://127.0.0.1/").await.is_err());
    }

    #[tokio::test]
    async fn test_ssrf_blocks_metadata_endpoint() {
        // AWS/GCP metadata endpoint
        assert!(
            validate_url_for_ssrf("http://169.254.169.254/")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn test_ssrf_blocks_internal_domain() {
        assert!(validate_url_for_ssrf("http://server.local/").await.is_err());
    }

    #[tokio::test]
    async fn test_ssrf_blocks_file_scheme() {
        let message = rejection_message("file:///etc/passwd").await;
        assert!(message.contains("unsupported scheme"));
    }

    #[tokio::test]
    async fn test_ssrf_allows_public_urls() {
        // Note: This test does DNS resolution, so it needs network access
        let outcome = validate_url_for_ssrf("https://example.com/").await;
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_private_ipv4_detection() {
        let internal = [
            Ipv4Addr::new(127, 0, 0, 1),
            Ipv4Addr::new(10, 0, 0, 1),
            Ipv4Addr::new(172, 16, 0, 1),
            Ipv4Addr::new(192, 168, 1, 1),
            Ipv4Addr::new(169, 254, 169, 254),
            Ipv4Addr::new(0, 0, 0, 0),
        ];
        for addr in internal {
            assert!(is_private_ipv4(addr), "{addr} should be private");
        }

        let public = [Ipv4Addr::new(8, 8, 8, 8), Ipv4Addr::new(93, 184, 216, 34)];
        for addr in public {
            assert!(!is_private_ipv4(addr), "{addr} should be public");
        }
    }

    #[test]
    fn test_private_ipv6_detection() {
        let link_local: Ipv6Addr = "fe80::1".parse().unwrap();
        let unique_local: Ipv6Addr = "fc00::1".parse().unwrap();
        let public: Ipv6Addr = "2607:f8b0:4004:800::200e".parse().unwrap();

        for addr in [
            Ipv6Addr::LOCALHOST,
            Ipv6Addr::UNSPECIFIED,
            link_local,
            unique_local,
        ] {
            assert!(is_private_ipv6(addr), "{addr} should be private");
        }

        assert!(!is_private_ipv6(public));
    }
}