Skip to main content

webfetch/
fetch.rs

1use reqwest::header::{CONTENT_TYPE, LOCATION};
2use reqwest::{redirect::Policy, Client, Response};
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::guard;
7
8const USER_AGENT: &str = concat!("webfetch/", env!("CARGO_PKG_VERSION"));
/// Total tries per hop in `fetch_with_retries`: the first request plus up to
/// two retries of transient failures.
const MAX_ATTEMPTS: u32 = 3;
/// Maximum number of redirects `fetch_page` follows before giving up.
const MAX_REDIRECTS: usize = 5;

/// Upper bound on how many response-body bytes are read: 5 MiB.
///
/// The HTML extractor reduces a page to a few KB of text, so a multi-megabyte
/// body rarely pays for its bandwidth, memory and parse time, and an
/// unbounded read would be a DoS lever. Bodies over the cap are *truncated*,
/// not rejected: partial content is still useful to the extractor.
const MAX_BODY_BYTES: usize = 5 << 20;
18
/// Outcome of an HTTP fetch: the body, the URL we actually landed on after
/// following redirects, and the response's `Content-Type` (if any).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchedPage {
    /// Response body decoded to text; at most `MAX_BODY_BYTES` bytes are read.
    pub body: String,
    /// URL of the final hop, after any redirects were followed.
    pub final_url: String,
    /// Raw `Content-Type` header value, when present and valid visible ASCII.
    pub content_type: Option<String>,
}

/// One hop's result: either the final page, or a redirect to a raw `Location`.
#[derive(Debug)]
enum Hop {
    /// A response that was neither a redirect nor an HTTP error status.
    Page(FetchedPage),
    /// The unresolved `Location` value of a 3xx response; it may be relative.
    Redirect(String),
}
32
33/// Build a client for a single validated URL. `pinned` are the public IPs the
34/// host already resolved to; binding them closes the DNS-rebinding window
35/// between validation and connection.
36///
37/// Redirects are **not** followed by reqwest here ([`Policy::none`]): we follow
38/// them manually in [`fetch_page`] so every hop is re-validated *and* pinned to
39/// its own resolved addresses. (Reqwest's `resolve_to_addrs` pins only the
40/// hosts known at build time, so auto-follow would leave redirect hops
41/// unpinned.) A consequence is that connection pooling cannot be shared across
42/// hosts via one long-lived client without weakening per-URL IP pinning, so we
43/// deliberately do not cache clients — SSRF safety wins over pool reuse.
44fn build_client(
45    url: &reqwest::Url,
46    timeout_secs: u64,
47    pinned: &[SocketAddr],
48) -> anyhow::Result<Client> {
49    let mut builder = Client::builder()
50        .timeout(Duration::from_secs(timeout_secs))
51        .redirect(Policy::none())
52        .user_agent(USER_AGENT)
53        .gzip(true)
54        .brotli(true);
55
56    if let Some(host) = url.host_str() {
57        if !pinned.is_empty() {
58            builder = builder.resolve_to_addrs(host, pinned);
59        }
60    }
61    Ok(builder.build()?)
62}
63
/// Append as much of `chunk` to `buf` as fits under `max` bytes.
///
/// Returns `false` while there is still room for more data. Returns `true`
/// once the cap is reached, meaning the chunk filled or overflowed the
/// remaining room, and the caller should stop reading. At that point, an
/// incomplete multi-byte UTF-8 sequence left at the very end of `buf` by the
/// cut is dropped, so lossy decoding does not end the text with a spurious
/// U+FFFD.
fn push_capped(buf: &mut Vec<u8>, chunk: &[u8], max: usize) -> bool {
    let remaining = max.saturating_sub(buf.len());
    if chunk.len() < remaining {
        buf.extend_from_slice(chunk);
        return false;
    }
    buf.extend_from_slice(&chunk[..remaining]);
    trim_incomplete_utf8_tail(buf);
    true
}

/// Drop a trailing UTF-8 lead byte (and its continuation bytes) when fewer
/// bytes follow it than the lead byte announces. Any other trailing bytes are
/// left for the lossy decoder.
fn trim_incomplete_utf8_tail(buf: &mut Vec<u8>) {
    let len = buf.len();
    // A UTF-8 character is at most 4 bytes, so an incomplete one fits in the
    // last 3. Walk back over continuation bytes (0b10xx_xxxx) to a lead byte.
    let window_start = len.saturating_sub(3);
    let lead = (window_start..len).rev().find(|&i| buf[i] & 0xC0 != 0x80);
    if let Some(idx) = lead {
        let width = match buf[idx] {
            0xC2..=0xDF => 2,
            0xE0..=0xEF => 3,
            0xF0..=0xF4 => 4,
            // ASCII, or a byte that can never start a UTF-8 character.
            _ => return,
        };
        if len - idx < width {
            buf.truncate(idx);
        }
    }
}
76
77/// Read a response body, streaming chunks with a running byte cap so an
78/// oversized body is bounded before it is ever DOM-parsed. The `bool` reports
79/// whether a read error is transient (worth retrying).
80async fn read_body_capped(mut resp: Response) -> Result<String, (anyhow::Error, bool)> {
81    let mut buf: Vec<u8> = Vec::new();
82    // Honour Content-Length to pre-size, but never trust it past the cap.
83    if let Some(len) = resp.content_length() {
84        buf.reserve(len.min(MAX_BODY_BYTES as u64) as usize);
85    }
86    loop {
87        match resp.chunk().await {
88            Ok(Some(chunk)) => {
89                if push_capped(&mut buf, &chunk, MAX_BODY_BYTES) {
90                    break;
91                }
92            }
93            Ok(None) => break,
94            Err(e) => {
95                let transient = e.is_timeout();
96                return Err((e.into(), transient));
97            }
98        }
99    }
100    Ok(String::from_utf8_lossy(&buf).into_owned())
101}
102
103/// One request attempt. The bool in the error reports whether the failure is
104/// transient (worth retrying): connection/timeout errors, 5xx, and 429.
105async fn attempt(client: &Client, url: &str) -> Result<Hop, (anyhow::Error, bool)> {
106    let resp = match client.get(url).send().await {
107        Ok(r) => r,
108        Err(e) => {
109            let transient = e.is_timeout() || e.is_connect() || e.is_request();
110            return Err((e.into(), transient));
111        }
112    };
113
114    let status = resp.status();
115
116    // Redirects are surfaced to the caller (which re-validates and pins the
117    // target) rather than followed by reqwest.
118    if status.is_redirection() {
119        return match resp.headers().get(LOCATION).and_then(|v| v.to_str().ok()) {
120            Some(loc) => Ok(Hop::Redirect(loc.to_string())),
121            None => Err((
122                anyhow::anyhow!("redirect ({status}) without a Location header"),
123                false,
124            )),
125        };
126    }
127
128    let resp = match resp.error_for_status() {
129        Ok(r) => r,
130        Err(e) => {
131            let transient = status.is_server_error() || status.as_u16() == 429;
132            return Err((e.into(), transient));
133        }
134    };
135
136    let final_url = resp.url().to_string();
137    let content_type = resp
138        .headers()
139        .get(CONTENT_TYPE)
140        .and_then(|v| v.to_str().ok())
141        .map(|s| s.to_string());
142
143    let body = read_body_capped(resp).await?;
144    Ok(Hop::Page(FetchedPage {
145        body,
146        final_url,
147        content_type,
148    }))
149}
150
151/// Issue one hop's request, retrying transient failures with exponential
152/// backoff (200ms, 400ms).
153async fn fetch_with_retries(client: &Client, url: &str) -> anyhow::Result<Hop> {
154    let mut delay = Duration::from_millis(200);
155    for attempt_no in 1..=MAX_ATTEMPTS {
156        match attempt(client, url).await {
157            Ok(hop) => return Ok(hop),
158            Err((err, transient)) => {
159                if attempt_no == MAX_ATTEMPTS || !transient {
160                    return Err(err);
161                }
162                tokio::time::sleep(delay).await;
163                delay *= 2;
164            }
165        }
166    }
167    unreachable!("loop returns on the final attempt")
168}
169
170/// Fetch a URL, following redirects manually so the SSRF guard re-validates and
171/// re-pins each hop (closing the DNS-rebinding window for redirected hosts too),
172/// retrying transient failures with exponential backoff. Caps the redirect
173/// chain at [`MAX_REDIRECTS`] and the body at [`MAX_BODY_BYTES`].
174pub async fn fetch_page(url: &str, timeout_secs: u64) -> anyhow::Result<FetchedPage> {
175    let mut current = reqwest::Url::parse(url)?;
176    let mut hops = 0usize;
177
178    loop {
179        // Validate + resolve the host for THIS hop, then pin the connection to
180        // exactly those addresses.
181        let pinned = guard::validate_url(&current).await?;
182        let client = build_client(&current, timeout_secs, &pinned)?;
183
184        match fetch_with_retries(&client, current.as_str()).await? {
185            Hop::Page(page) => return Ok(page),
186            Hop::Redirect(location) => {
187                hops += 1;
188                if hops > MAX_REDIRECTS {
189                    anyhow::bail!("too many redirects (>{MAX_REDIRECTS})");
190                }
191                current = current
192                    .join(&location)
193                    .map_err(|e| anyhow::anyhow!("invalid redirect target `{location}`: {e}"))?;
194            }
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_capped_truncates_oversized_chunk() {
        let mut buf = Vec::new();
        // A single chunk larger than the cap is clipped to the cap.
        let stopped = push_capped(&mut buf, &[b'x'; 10], 4);
        assert!(stopped);
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn push_capped_accumulates_until_cap() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"abc", 8));
        assert!(!push_capped(&mut buf, b"de", 8));
        assert_eq!(buf, b"abcde");
        // Next chunk crosses the cap: only the remaining 3 bytes are kept.
        let stopped = push_capped(&mut buf, b"fghij", 8);
        assert!(stopped);
        assert_eq!(buf.len(), 8);
        assert_eq!(buf, b"abcdefgh");
    }

    #[test]
    fn push_capped_small_body_unaffected() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, b"hello", 1024);
        assert!(!stopped);
        assert_eq!(buf, b"hello");
    }

    #[test]
    fn push_capped_exact_fit_reports_cap() {
        // Filling the remaining room exactly counts as reaching the cap.
        let mut buf = Vec::new();
        assert!(push_capped(&mut buf, b"abcd", 4));
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_full_buffer_takes_nothing() {
        let mut buf = b"abcd".to_vec();
        assert!(push_capped(&mut buf, b"e", 4));
        assert_eq!(buf, b"abcd");
    }

    #[test]
    fn push_capped_empty_chunk_is_noop() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"", 4));
        assert!(buf.is_empty());
    }
}
232}