1use reqwest::header::{CONTENT_TYPE, LOCATION};
2use reqwest::{redirect::Policy, Client, Response};
3use std::net::SocketAddr;
4use std::time::Duration;
5
6use crate::guard;
7use crate::tls::TlsConfig;
8
9const USER_AGENT: &str = concat!("webfetch/", env!("CARGO_PKG_VERSION"));
/// Total requests made per hop, counting the first one (so `MAX_ATTEMPTS - 1` retries).
const MAX_ATTEMPTS: u32 = 3;
/// Number of redirects `fetch_page` will follow before giving up.
const MAX_REDIRECTS: usize = 5;

/// Maximum number of body bytes kept from a response (5 MiB); the rest is discarded.
const MAX_BODY_BYTES: usize = 5 << 20;
19
/// A successfully fetched page: the first non-redirect, non-error response of a fetch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchedPage {
    /// Response body decoded as UTF-8 (invalid sequences replaced with U+FFFD), built from
    /// at most `MAX_BODY_BYTES` bytes of the raw body.
    pub body: String,
    /// URL the final response was served from, i.e. after all redirects were followed.
    pub final_url: String,
    /// `Content-Type` header value, when present and representable as a string.
    pub content_type: Option<String>,
}
27
28enum Hop {
30 Page(FetchedPage),
31 Redirect(String),
32}
33
34fn build_client(
46 url: &reqwest::Url,
47 timeout_secs: u64,
48 pinned: &[SocketAddr],
49 tls: &TlsConfig,
50) -> anyhow::Result<Client> {
51 let mut builder = Client::builder()
52 .timeout(Duration::from_secs(timeout_secs))
53 .redirect(Policy::none())
54 .user_agent(USER_AGENT)
55 .gzip(true)
56 .brotli(true);
57
58 builder = tls.apply(builder)?;
61
62 if let Some(host) = url.host_str() {
63 if !pinned.is_empty() {
64 builder = builder.resolve_to_addrs(host, pinned);
65 }
66 }
67 Ok(builder.build()?)
68}
69
/// Appends to `buf` as much of `chunk` as fits within a total of `max` bytes.
///
/// Returns `true` when no room is left afterwards, either because the chunk filled the
/// remaining space exactly or because it had to be cut short. That tells the caller to stop
/// reading. Returns `false` when the whole chunk was stored and space remains.
fn push_capped(buf: &mut Vec<u8>, chunk: &[u8], max: usize) -> bool {
    let room = max.saturating_sub(buf.len());
    let take = room.min(chunk.len());
    buf.extend_from_slice(&chunk[..take]);
    take == room
}
82
83async fn read_body_capped(mut resp: Response) -> Result<String, (anyhow::Error, bool)> {
87 let mut buf: Vec<u8> = Vec::new();
88 if let Some(len) = resp.content_length() {
90 buf.reserve(len.min(MAX_BODY_BYTES as u64) as usize);
91 }
92 loop {
93 match resp.chunk().await {
94 Ok(Some(chunk)) => {
95 if push_capped(&mut buf, &chunk, MAX_BODY_BYTES) {
96 break;
97 }
98 }
99 Ok(None) => break,
100 Err(e) => {
101 let transient = e.is_timeout();
102 return Err((e.into(), transient));
103 }
104 }
105 }
106 Ok(String::from_utf8_lossy(&buf).into_owned())
107}
108
109async fn attempt(client: &Client, url: &str) -> Result<Hop, (anyhow::Error, bool)> {
112 let resp = match client.get(url).send().await {
113 Ok(r) => r,
114 Err(e) => {
115 let transient = e.is_timeout() || e.is_connect() || e.is_request();
116 return Err((e.into(), transient));
117 }
118 };
119
120 let status = resp.status();
121
122 if status.is_redirection() {
125 return match resp.headers().get(LOCATION).and_then(|v| v.to_str().ok()) {
126 Some(loc) => Ok(Hop::Redirect(loc.to_string())),
127 None => Err((
128 anyhow::anyhow!("redirect ({status}) without a Location header"),
129 false,
130 )),
131 };
132 }
133
134 let resp = match resp.error_for_status() {
135 Ok(r) => r,
136 Err(e) => {
137 let transient = status.is_server_error() || status.as_u16() == 429;
138 return Err((e.into(), transient));
139 }
140 };
141
142 let final_url = resp.url().to_string();
143 let content_type = resp
144 .headers()
145 .get(CONTENT_TYPE)
146 .and_then(|v| v.to_str().ok())
147 .map(|s| s.to_string());
148
149 let body = read_body_capped(resp).await?;
150 Ok(Hop::Page(FetchedPage {
151 body,
152 final_url,
153 content_type,
154 }))
155}
156
157async fn fetch_with_retries(client: &Client, url: &str) -> anyhow::Result<Hop> {
160 let mut delay = Duration::from_millis(200);
161 for attempt_no in 1..=MAX_ATTEMPTS {
162 match attempt(client, url).await {
163 Ok(hop) => return Ok(hop),
164 Err((err, transient)) => {
165 if attempt_no == MAX_ATTEMPTS || !transient {
166 return Err(err);
167 }
168 tokio::time::sleep(delay).await;
169 delay *= 2;
170 }
171 }
172 }
173 unreachable!("loop returns on the final attempt")
174}
175
176pub async fn fetch_page(
181 url: &str,
182 timeout_secs: u64,
183 tls: &TlsConfig,
184) -> anyhow::Result<FetchedPage> {
185 let mut current = reqwest::Url::parse(url)?;
186 let mut hops = 0usize;
187
188 loop {
189 let pinned = guard::validate_url(¤t).await?;
192 let client = build_client(¤t, timeout_secs, &pinned, tls)?;
193
194 match fetch_with_retries(&client, current.as_str()).await? {
195 Hop::Page(page) => return Ok(page),
196 Hop::Redirect(location) => {
197 hops += 1;
198 if hops > MAX_REDIRECTS {
199 anyhow::bail!("too many redirects (>{MAX_REDIRECTS})");
200 }
201 current = current
202 .join(&location)
203 .map_err(|e| anyhow::anyhow!("invalid redirect target `{location}`: {e}"))?;
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn push_capped_truncates_oversized_chunk() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, &[b'x'; 10], 4);
        assert!(stopped);
        assert_eq!(buf.len(), 4);
    }

    #[test]
    fn push_capped_accumulates_until_cap() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"abc", 8));
        assert!(!push_capped(&mut buf, b"de", 8));
        assert_eq!(buf, b"abcde");
        let stopped = push_capped(&mut buf, b"fghij", 8);
        assert!(stopped);
        assert_eq!(buf.len(), 8);
        assert_eq!(buf, b"abcdefgh");
    }

    #[test]
    fn push_capped_small_body_unaffected() {
        let mut buf = Vec::new();
        let stopped = push_capped(&mut buf, b"hello", 1024);
        assert!(!stopped);
        assert_eq!(buf, b"hello");
    }

    // A chunk that exactly fills the remaining room is stored whole and still signals stop.
    #[test]
    fn push_capped_exact_fill_signals_stop() {
        let mut buf = Vec::new();
        assert!(push_capped(&mut buf, b"abcd", 4));
        assert_eq!(buf, b"abcd");
    }

    // Once the buffer is at the cap, further chunks (empty or not) add nothing.
    #[test]
    fn push_capped_full_buffer_ignores_further_chunks() {
        let mut buf = b"abcd".to_vec();
        assert!(push_capped(&mut buf, b"efg", 4));
        assert!(push_capped(&mut buf, b"", 4));
        assert_eq!(buf, b"abcd");
    }

    // An empty chunk below the cap must not be mistaken for "cap reached".
    #[test]
    fn push_capped_empty_chunk_below_cap_keeps_reading() {
        let mut buf = Vec::new();
        assert!(!push_capped(&mut buf, b"", 4));
        assert!(buf.is_empty());
    }
}