Skip to main content

webfetch/
fetch.rs

1use reqwest::header::CONTENT_TYPE;
2use reqwest::{redirect::Policy, Client};
3use std::time::Duration;
4
5use crate::guard;
6
7const USER_AGENT: &str = concat!("webfetch/", env!("CARGO_PKG_VERSION"));
8const MAX_ATTEMPTS: u32 = 3;
9const MAX_REDIRECTS: usize = 5;
10
/// Outcome of an HTTP fetch: the body, the URL we actually landed on after
/// following redirects, and the response's `Content-Type` (if any).
///
/// Derives the common value traits so callers can log (`{:?}`), clone, and
/// compare fetched pages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchedPage {
    /// Response body decoded to text.
    pub body: String,
    /// URL of the final response, after any redirects were followed.
    pub final_url: String,
    /// Raw `Content-Type` header value; `None` if it was absent or not
    /// representable as a visible-ASCII string.
    pub content_type: Option<String>,
}
18
19/// Redirect policy that re-runs the SSRF guard on every hop, so a public URL
20/// cannot bounce the client to `localhost`, the cloud metadata IP, or an
21/// internal host. Caps the chain at [`MAX_REDIRECTS`].
22fn guarded_redirect_policy() -> Policy {
23    Policy::custom(|attempt| {
24        if attempt.previous().len() >= MAX_REDIRECTS {
25            return attempt.error("too many redirects");
26        }
27        match guard::validate_url(attempt.url()) {
28            Ok(_) => attempt.follow(),
29            Err(e) => attempt.error(e),
30        }
31    })
32}
33
34/// Build a client for a single validated URL. `pinned` are the public IPs the
35/// host already resolved to; binding them closes the DNS-rebinding window
36/// between validation and connection.
37fn build_client(url: &reqwest::Url, timeout_secs: u64) -> anyhow::Result<Client> {
38    let pinned = guard::validate_url(url)?;
39    let mut builder = Client::builder()
40        .timeout(Duration::from_secs(timeout_secs))
41        .redirect(guarded_redirect_policy())
42        .user_agent(USER_AGENT)
43        .gzip(true)
44        .brotli(true);
45
46    if let Some(host) = url.host_str() {
47        if !pinned.is_empty() {
48            builder = builder.resolve_to_addrs(host, &pinned);
49        }
50    }
51    Ok(builder.build()?)
52}
53
54/// One request attempt. The bool in the error reports whether the failure is
55/// transient (worth retrying): connection/timeout errors, 5xx, and 429.
56async fn attempt(client: &Client, url: &str) -> Result<FetchedPage, (anyhow::Error, bool)> {
57    let resp = match client.get(url).send().await {
58        Ok(r) => r,
59        Err(e) => {
60            let transient = e.is_timeout() || e.is_connect() || e.is_request();
61            return Err((e.into(), transient));
62        }
63    };
64
65    let status = resp.status();
66    let resp = match resp.error_for_status() {
67        Ok(r) => r,
68        Err(e) => {
69            let transient = status.is_server_error() || status.as_u16() == 429;
70            return Err((e.into(), transient));
71        }
72    };
73
74    let final_url = resp.url().to_string();
75    let content_type = resp
76        .headers()
77        .get(CONTENT_TYPE)
78        .and_then(|v| v.to_str().ok())
79        .map(|s| s.to_string());
80
81    match resp.text().await {
82        Ok(body) => Ok(FetchedPage {
83            body,
84            final_url,
85            content_type,
86        }),
87        Err(e) => {
88            let transient = e.is_timeout();
89            Err((e.into(), transient))
90        }
91    }
92}
93
94/// Fetch a URL, following redirects, retrying transient failures with
95/// exponential backoff (200ms, 400ms).
96pub async fn fetch_page(url: &str, timeout_secs: u64) -> anyhow::Result<FetchedPage> {
97    let parsed = reqwest::Url::parse(url)?;
98    let client = build_client(&parsed, timeout_secs)?;
99
100    let mut delay = Duration::from_millis(200);
101    for attempt_no in 1..=MAX_ATTEMPTS {
102        match attempt(&client, url).await {
103            Ok(page) => return Ok(page),
104            Err((err, transient)) => {
105                if attempt_no == MAX_ATTEMPTS || !transient {
106                    return Err(err);
107                }
108                tokio::time::sleep(delay).await;
109                delay *= 2;
110            }
111        }
112    }
113    unreachable!("loop returns on the final attempt")
114}