Skip to main content

webfetch/
fetch.rs

1use reqwest::header::{CONTENT_TYPE, LOCATION};
2use reqwest::{redirect::Policy, Client, Response};
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::guard;
7use crate::tls::TlsConfig;
8
9const USER_AGENT: &str = concat!("webfetch/", env!("CARGO_PKG_VERSION"));
/// Total request attempts per hop: the first try plus retries of transient
/// failures.
const MAX_ATTEMPTS: u32 = 3;
/// Maximum number of redirects followed for a single fetch before giving up.
const MAX_REDIRECTS: usize = 5;
12
/// Upper limit, in bytes, on how much of a response body is read (5 MiB).
///
/// The HTML extractor reduces a page to a few KB of text, so a multi-megabyte
/// body is rarely worth its bandwidth, memory, and parse time — and reading
/// without a bound is a DoS lever. A body larger than the cap is *truncated*
/// rather than rejected: partial content is still useful and the extractor
/// copes with truncated HTML.
const MAX_BODY_BYTES: usize = 5 << 20; // 5 * 1024 * 1024
19
/// Outcome of an HTTP fetch: the body, the URL we actually landed on after
/// following redirects, and the response's `Content-Type` (if any).
///
/// Derives `Debug` so the type can appear in diagnostics and in
/// `Result::unwrap_err`/`expect_err` on results that carry it, and `Clone` so
/// callers can keep a copy without re-fetching.
#[derive(Debug, Clone)]
pub struct FetchedPage {
    /// Response body as text. Bytes that are not valid UTF-8 are replaced
    /// lossily, and the body is cut off at [`MAX_BODY_BYTES`].
    pub body: String,
    /// URL of the hop that produced this page (the last one in the redirect
    /// chain).
    pub final_url: String,
    /// Raw `Content-Type` header value, or `None` when the header is absent
    /// or is not representable as a string.
    pub content_type: Option<String>,
}
27
28/// One hop's result: either the final page, or a redirect to a raw `Location`.
29enum Hop {
30    Page(FetchedPage),
31    Redirect(String),
32}
33
34/// Build a client for a single validated URL. `pinned` are the public IPs the
35/// host already resolved to; binding them closes the DNS-rebinding window
36/// between validation and connection.
37///
38/// Redirects are **not** followed by reqwest here ([`Policy::none`]): we follow
39/// them manually in [`fetch_page`] so every hop is re-validated *and* pinned to
40/// its own resolved addresses. (Reqwest's `resolve_to_addrs` pins only the
41/// hosts known at build time, so auto-follow would leave redirect hops
42/// unpinned.) A consequence is that connection pooling cannot be shared across
43/// hosts via one long-lived client without weakening per-URL IP pinning, so we
44/// deliberately do not cache clients — SSRF safety wins over pool reuse.
45fn build_client(
46    url: &reqwest::Url,
47    timeout_secs: u64,
48    pinned: &[SocketAddr],
49    tls: &TlsConfig,
50) -> anyhow::Result<Client> {
51    let mut builder = Client::builder()
52        .timeout(Duration::from_secs(timeout_secs))
53        .redirect(Policy::none())
54        .user_agent(USER_AGENT)
55        .gzip(true)
56        .brotli(true);
57
58    // Trust the OS store (+ SSL_CERT_FILE / --ca-cert) so org/proxy root CAs
59    // are accepted, instead of only the bundled webpki roots.
60    builder = tls.apply(builder)?;
61
62    if let Some(host) = url.host_str() {
63        if !pinned.is_empty() {
64            builder = builder.resolve_to_addrs(host, pinned);
65        }
66    }
67    Ok(builder.build()?)
68}
69
/// Copy bytes from `chunk` onto the end of `buf` without letting `buf` grow
/// beyond `max` bytes.
///
/// Returns `true` when `chunk` filled (or would have overfilled) the space
/// that was left, meaning the cap is reached and the caller should stop
/// reading. Returns `false` while there is still room after the append.
fn push_capped(buf: &mut Vec<u8>, chunk: &[u8], max: usize) -> bool {
    // Saturating: a buffer already at or past `max` simply has no room left.
    let room = max.saturating_sub(buf.len());
    let take = chunk.len().min(room);
    buf.extend_from_slice(&chunk[..take]);
    chunk.len() >= room
}
82
83/// Read a response body, streaming chunks with a running byte cap so an
84/// oversized body is bounded before it is ever DOM-parsed. The `bool` reports
85/// whether a read error is transient (worth retrying).
86async fn read_body_capped(mut resp: Response) -> Result<String, (anyhow::Error, bool)> {
87    let mut buf: Vec<u8> = Vec::new();
88    // Honour Content-Length to pre-size, but never trust it past the cap.
89    if let Some(len) = resp.content_length() {
90        buf.reserve(len.min(MAX_BODY_BYTES as u64) as usize);
91    }
92    loop {
93        match resp.chunk().await {
94            Ok(Some(chunk)) => {
95                if push_capped(&mut buf, &chunk, MAX_BODY_BYTES) {
96                    break;
97                }
98            }
99            Ok(None) => break,
100            Err(e) => {
101                let transient = e.is_timeout();
102                return Err((e.into(), transient));
103            }
104        }
105    }
106    Ok(String::from_utf8_lossy(&buf).into_owned())
107}
108
109/// One request attempt. The bool in the error reports whether the failure is
110/// transient (worth retrying): connection/timeout errors, 5xx, and 429.
111async fn attempt(client: &Client, url: &str) -> Result<Hop, (anyhow::Error, bool)> {
112    let resp = match client.get(url).send().await {
113        Ok(r) => r,
114        Err(e) => {
115            let transient = e.is_timeout() || e.is_connect() || e.is_request();
116            return Err((e.into(), transient));
117        }
118    };
119
120    let status = resp.status();
121
122    // Redirects are surfaced to the caller (which re-validates and pins the
123    // target) rather than followed by reqwest.
124    if status.is_redirection() {
125        return match resp.headers().get(LOCATION).and_then(|v| v.to_str().ok()) {
126            Some(loc) => Ok(Hop::Redirect(loc.to_string())),
127            None => Err((
128                anyhow::anyhow!("redirect ({status}) without a Location header"),
129                false,
130            )),
131        };
132    }
133
134    let resp = match resp.error_for_status() {
135        Ok(r) => r,
136        Err(e) => {
137            let transient = status.is_server_error() || status.as_u16() == 429;
138            return Err((e.into(), transient));
139        }
140    };
141
142    let final_url = resp.url().to_string();
143    let content_type = resp
144        .headers()
145        .get(CONTENT_TYPE)
146        .and_then(|v| v.to_str().ok())
147        .map(|s| s.to_string());
148
149    let body = read_body_capped(resp).await?;
150    Ok(Hop::Page(FetchedPage {
151        body,
152        final_url,
153        content_type,
154    }))
155}
156
157/// Issue one hop's request, retrying transient failures with exponential
158/// backoff (200ms, 400ms).
159async fn fetch_with_retries(client: &Client, url: &str) -> anyhow::Result<Hop> {
160    let mut delay = Duration::from_millis(200);
161    for attempt_no in 1..=MAX_ATTEMPTS {
162        match attempt(client, url).await {
163            Ok(hop) => return Ok(hop),
164            Err((err, transient)) => {
165                if attempt_no == MAX_ATTEMPTS || !transient {
166                    return Err(err);
167                }
168                tokio::time::sleep(delay).await;
169                delay *= 2;
170            }
171        }
172    }
173    unreachable!("loop returns on the final attempt")
174}
175
176/// Fetch a URL, following redirects manually so the SSRF guard re-validates and
177/// re-pins each hop (closing the DNS-rebinding window for redirected hosts too),
178/// retrying transient failures with exponential backoff. Caps the redirect
179/// chain at [`MAX_REDIRECTS`] and the body at [`MAX_BODY_BYTES`].
180pub async fn fetch_page(
181    url: &str,
182    timeout_secs: u64,
183    tls: &TlsConfig,
184) -> anyhow::Result<FetchedPage> {
185    let mut current = reqwest::Url::parse(url)?;
186    let mut hops = 0usize;
187
188    loop {
189        // Validate + resolve the host for THIS hop, then pin the connection to
190        // exactly those addresses.
191        let pinned = guard::validate_url(&current).await?;
192        let client = build_client(&current, timeout_secs, &pinned, tls)?;
193
194        match fetch_with_retries(&client, current.as_str()).await? {
195            Hop::Page(page) => return Ok(page),
196            Hop::Redirect(location) => {
197                hops += 1;
198                if hops > MAX_REDIRECTS {
199                    anyhow::bail!("too many redirects (>{MAX_REDIRECTS})");
200                }
201                current = current
202                    .join(&location)
203                    .map_err(|e| anyhow::anyhow!("invalid redirect target `{location}`: {e}"))?;
204            }
205        }
206    }
207}
208
/// Unit tests for the pure helper [`push_capped`]; the network paths need a
/// live server and are not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_capped_truncates_oversized_chunk() {
        let mut buf = Vec::new();
        // A single chunk larger than the cap is clipped to the cap.
        let stopped = push_capped(&mut buf, &[b'x'; 10], 4);
        assert!(stopped);
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn push_capped_accumulates_until_cap() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"abc", 8));
        assert!(!push_capped(&mut buf, b"de", 8));
        assert_eq!(buf, b"abcde");
        // Next chunk crosses the cap: only the remaining 3 bytes are kept.
        let stopped = push_capped(&mut buf, b"fghij", 8);
        assert!(stopped);
        assert_eq!(buf.len(), 8);
        assert_eq!(buf, b"abcdefgh");
    }

    #[test]
    fn push_capped_small_body_unaffected() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, b"hello", 1024);
        assert!(!stopped);
        assert_eq!(buf, b"hello");
    }

    #[test]
    fn push_capped_exact_fit_reports_cap_reached() {
        let mut buf = Vec::new();
        // A chunk that lands exactly on the cap is kept whole and stops the read.
        let stopped = push_capped(&mut buf, b"abcd", 4);
        assert!(stopped);
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_full_buffer_takes_nothing() {
        let mut buf = b"abcd".to_vec();
        let stopped = push_capped(&mut buf, b"zz", 4);
        assert!(stopped);
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_overfull_buffer_is_left_alone() {
        // A buffer already past the cap must neither grow nor panic.
        let mut buf = b"abcdef".to_vec();
        let stopped = push_capped(&mut buf, b"zz", 4);
        assert!(stopped);
        assert_eq!(buf, b"abcdef");
    }

    #[test]
    fn push_capped_empty_chunk_with_room_continues() {
        let mut buf: Vec<u8> = Vec::new();
        let stopped = push_capped(&mut buf, b"", 8);
        assert!(!stopped);
        assert!(buf.is_empty());
    }
}