1#![forbid(unsafe_code)]
10
11use reqwest::blocking::Client;
12use reqwest::header::RANGE;
13use reqwest::StatusCode;
14use std::fs;
15use std::io::Read;
16use std::path::{Path, PathBuf};
17use std::time::Duration;
18use vanta_core::{Area, VtaError, VtaResult};
19
/// Maximum number of redirect hops tolerated for a single request; once this
/// many URLs have already been visited the redirect policy fails the request
/// with "too many redirects".
const MAX_REDIRECTS: usize = 10;
22
/// Progress callback type. It is invoked with the number of bytes produced by
/// each non-empty read of the response body (a per-read delta, not a running
/// total).
pub type ProgressFn<'a> = dyn Fn(u64) + 'a;
27
28pub struct Downloader {
30 client: Client,
31 retries: u32,
32 allow_http: bool,
36}
37
38impl Downloader {
39 pub fn new() -> VtaResult<Downloader> {
43 Self::build(false)
44 }
45
46 pub fn insecure() -> VtaResult<Downloader> {
51 Self::build(true)
52 }
53
54 fn build(allow_http: bool) -> VtaResult<Downloader> {
55 let redirect = reqwest::redirect::Policy::custom(|attempt| {
58 if attempt.previous().len() >= MAX_REDIRECTS {
59 attempt.error("too many redirects")
60 } else if attempt.url().scheme() == "http"
61 && attempt
62 .previous()
63 .last()
64 .map(|u| u.scheme() == "https")
65 .unwrap_or(false)
66 {
67 attempt.stop()
68 } else {
69 attempt.follow()
70 }
71 });
72 let client = Client::builder()
80 .user_agent(concat!("vanta/", env!("CARGO_PKG_VERSION")))
81 .connect_timeout(Duration::from_secs(30))
82 .redirect(redirect)
83 .build()
84 .map_err(|e| VtaError::new(Area::Net, 4, format!("building HTTP client: {e}")))?;
85 Ok(Downloader {
86 client,
87 retries: 3,
88 allow_http,
89 })
90 }
91
92 pub fn with_retries(mut self, retries: u32) -> Self {
94 self.retries = retries;
95 self
96 }
97
98 pub fn download(&self, url: &str, dest: &Path) -> VtaResult<()> {
101 self.download_capped(url, dest, None)
102 }
103
104 pub fn download_capped(&self, url: &str, dest: &Path, max: Option<u64>) -> VtaResult<()> {
107 self.download_capped_with_progress(url, dest, max, None)
108 }
109
110 pub fn download_capped_with_progress(
114 &self,
115 url: &str,
116 dest: &Path,
117 max: Option<u64>,
118 progress: Option<&ProgressFn>,
119 ) -> VtaResult<()> {
120 self.scheme_ok(url)?;
121 let mut last: Option<VtaError> = None;
122 for attempt in 0..=self.retries {
123 match self.fetch_one(url, dest, max, progress) {
124 Ok(()) => return Ok(()),
125 Err(e) => {
126 last = Some(e);
127 if attempt < self.retries {
128 std::thread::sleep(backoff(attempt));
129 }
130 }
131 }
132 }
133 Err(last.unwrap_or_else(|| VtaError::new(Area::Net, 1, format!("download failed: {url}"))))
134 }
135
136 pub fn download_any(&self, urls: &[String], dest: &Path, max: Option<u64>) -> VtaResult<()> {
141 self.download_any_with_progress(urls, dest, max, None)
142 }
143
144 pub fn download_any_with_progress(
147 &self,
148 urls: &[String],
149 dest: &Path,
150 max: Option<u64>,
151 progress: Option<&ProgressFn>,
152 ) -> VtaResult<()> {
153 let mut last: Option<VtaError> = None;
154 for url in urls {
155 let _ = fs::remove_file(part_path(dest));
159 match self.download_capped_with_progress(url, dest, max, progress) {
160 Ok(()) => return Ok(()),
161 Err(e) => last = Some(e),
162 }
163 }
164 Err(last.unwrap_or_else(|| {
165 VtaError::new(Area::Net, 1, "no URLs supplied to download_any".to_string())
166 }))
167 }
168
169 fn scheme_ok(&self, url: &str) -> VtaResult<()> {
172 if let Some(rest) = url.strip_prefix("http://") {
173 if !self.allow_http && !is_loopback_authority(rest) {
174 return Err(VtaError::new(
175 Area::Net,
176 5,
177 format!(
178 "refusing plaintext http:// download of {url} \
179 (https required; set the insecure opt-in to override)"
180 ),
181 ));
182 }
183 }
184 Ok(())
185 }
186
187 fn fetch_one(
188 &self,
189 url: &str,
190 dest: &Path,
191 max: Option<u64>,
192 progress: Option<&ProgressFn>,
193 ) -> VtaResult<()> {
194 let part = part_path(dest);
195 let have = fs::metadata(&part).map(|m| m.len()).unwrap_or(0);
196
197 let mut req = self.client.get(url);
198 if have > 0 {
199 req = req.header(RANGE, format!("bytes={have}-"));
200 }
201 let mut resp = req
202 .send()
203 .map_err(|e| VtaError::new(Area::Net, 1, format!("requesting {url}: {e}")))?;
204
205 let status = resp.status();
206 let resuming = have > 0 && status == StatusCode::PARTIAL_CONTENT;
207 if !(status.is_success() || resuming) {
208 return Err(VtaError::new(
209 Area::Net,
210 1,
211 format!("HTTP {status} for {url}"),
212 ));
213 }
214
215 let remaining =
217 match max {
218 Some(m) => Some(m.checked_sub(if resuming { have } else { 0 }).ok_or_else(
219 || VtaError::new(Area::Net, 6, format!("download of {url} exceeds size cap")),
220 )?),
221 None => None,
222 };
223
224 if let Some(parent) = part.parent() {
225 fs::create_dir_all(parent).map_err(|e| io(parent, e))?;
226 }
227 let mut file = if resuming {
228 fs::OpenOptions::new()
229 .append(true)
230 .open(&part)
231 .map_err(|e| io(&part, e))?
232 } else {
233 let _ = fs::remove_file(&part);
234 fs::File::create(&part).map_err(|e| io(&part, e))?
235 };
236
237 let mut src = ProgressReader::new(&mut resp, progress);
239 let written = match remaining {
240 Some(limit) => {
242 let mut limited = (&mut src).take(limit.saturating_add(1));
243 let n = std::io::copy(&mut limited, &mut file).map_err(|e| {
244 VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display()))
245 })?;
246 if n > limit {
247 let _ = fs::remove_file(&part);
248 return Err(VtaError::new(
249 Area::Net,
250 6,
251 format!("download of {url} exceeds declared size {limit} bytes"),
252 ));
253 }
254 n
255 }
256 None => std::io::copy(&mut src, &mut file).map_err(|e| {
257 VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display()))
258 })?,
259 };
260 let _ = written;
261 file.sync_all().ok();
262 fs::rename(&part, dest).map_err(|e| io(dest, e))?;
263 Ok(())
264 }
265}
266
267struct ProgressReader<'a, R> {
271 inner: R,
272 progress: Option<&'a ProgressFn<'a>>,
273}
274
275impl<'a, R> ProgressReader<'a, R> {
276 fn new(inner: R, progress: Option<&'a ProgressFn<'a>>) -> Self {
277 ProgressReader { inner, progress }
278 }
279}
280
281impl<R: Read> Read for ProgressReader<'_, R> {
282 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
283 let n = self.inner.read(buf)?;
284 if n > 0 {
285 if let Some(cb) = self.progress {
286 cb(n as u64);
287 }
288 }
289 Ok(n)
290 }
291}
292
/// Reports whether the authority at the start of `rest` (the part of a URL
/// after `scheme://`) names a loopback host.
///
/// `rest` may carry userinfo, a port and a trailing path/query/fragment, or be
/// a bare host. Recognized loopback hosts are:
///
/// * `localhost` (ASCII case-insensitive, optional trailing dot),
/// * a bracketed IPv6 literal that parses to the loopback address (`[::1]`),
/// * a purely numeric dotted host whose first component is `127`
///   (`127.0.0.1`, and shorthand such as `127.1`).
///
/// DNS names that merely *start* with `127.` (for example `127.evil.com`) are
/// not loopback hosts and are rejected.
fn is_loopback_authority(rest: &str) -> bool {
    // The authority ends at the first path, query or fragment delimiter. A
    // backslash is included because URL parsers treat it like `/` in http(s)
    // URLs; without it `evil.com\@localhost` would be read as host `localhost`.
    let authority = rest.split(['/', '\\', '?', '#']).next().unwrap_or(rest);
    // Drop any `user:pass@` prefix; the host follows the last `@`.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    if let Some(bracketed) = host_port.strip_prefix('[') {
        // Bracketed hosts are IPv6 literals; anything else in brackets is not
        // a valid host and must not be trusted.
        let literal = bracketed.split(']').next().unwrap_or(bracketed);
        return literal
            .parse::<std::net::Ipv6Addr>()
            .map(|ip| ip.is_loopback())
            .unwrap_or(false);
    }

    let host = host_port
        .split(':')
        .next()
        .unwrap_or(host_port)
        .trim_end_matches('.');
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    // Only digits and dots: this is an IPv4-style literal rather than a DNS
    // name, so a leading `127.` component places it in 127.0.0.0/8.
    host.starts_with("127.") && host.bytes().all(|b| b.is_ascii_digit() || b == b'.')
}
313
/// Returns the path of the temporary partial-download file for `dest`:
/// `dest` with `.part` appended to its final component (the existing
/// extension is kept, not replaced).
fn part_path(dest: &Path) -> PathBuf {
    let mut name = dest.as_os_str().to_owned();
    name.push(".part");
    name.into()
}
319
/// Delay to wait after failed attempt number `attempt` (0-based): half a
/// second, doubling with each attempt and capped at eight seconds from the
/// fifth attempt onwards.
fn backoff(attempt: u32) -> Duration {
    let doublings = attempt.min(4);
    Duration::from_millis(500u64 << doublings)
}
325
326fn io(path: &Path, e: std::io::Error) -> VtaError {
327 VtaError::new(Area::Net, 1, format!("{}: {e}", path.display()))
328}
329
// Unit tests for the downloader: construction, part-file naming, scheme
// policy, loopback detection and the size cap.
#[cfg(test)]
mod tests {
    use super::*;

    // The default (https-only) downloader can be constructed.
    #[test]
    fn client_builds() {
        assert!(Downloader::new().is_ok());
    }

    // The part file is the destination path with ".part" appended.
    #[test]
    fn part_path_appends_suffix() {
        assert_eq!(
            part_path(Path::new("/tmp/a.bin")),
            PathBuf::from("/tmp/a.bin.part")
        );
    }

    // An empty URL list is an error rather than a silent success.
    #[test]
    fn download_any_empty_errors() {
        let d = Downloader::new().unwrap();
        assert!(d.download_any(&[], Path::new("/tmp/none"), None).is_err());
    }

    // A plaintext http URL to a non-loopback host is refused with Net/5
    // by the default downloader, while https passes the scheme check.
    #[test]
    fn rejects_plaintext_http_scheme() {
        let d = Downloader::new().unwrap();
        let err = d
            .download("http://example.org/x", Path::new("/tmp/should-not-write"))
            .unwrap_err();
        assert_eq!(err.area, Area::Net);
        assert_eq!(err.number, 5);
        assert!(matches!(d.scheme_ok("https://example.org/x"), Ok(())));
    }

    // Loopback authorities are recognized; look-alike remote hosts are not.
    #[test]
    fn loopback_http_is_allowed_scheme() {
        assert!(is_loopback_authority("127.0.0.1:8080/x"));
        assert!(is_loopback_authority("localhost/x"));
        assert!(is_loopback_authority("[::1]:9/x"));
        assert!(!is_loopback_authority("example.org/x"));
        assert!(!is_loopback_authority("127x.evil.com/x"));
    }

    // The insecure opt-in lets plaintext http through the scheme check.
    #[test]
    fn insecure_allows_http() {
        let d = Downloader::insecure().unwrap();
        assert!(matches!(d.scheme_ok("http://example.org/x"), Ok(())));
    }

    // A 10 000-byte body fetched with a 1 000-byte cap fails with error 6 and
    // leaves no destination file; the same body fits a 10 000-byte cap.
    #[test]
    fn size_cap_aborts_oversize_download() {
        use std::collections::HashMap;
        let mut files = HashMap::new();
        files.insert("/big".to_string(), vec![0u8; 10_000]);
        // `vanta_test::serve` returns the port used to build the loopback URL.
        let port = vanta_test::serve(files);
        let d = Downloader::new().unwrap();
        let dest = std::env::temp_dir().join(format!("vanta-net-cap-{}.bin", std::process::id()));
        let _ = fs::remove_file(&dest);
        let url = format!("http://127.0.0.1:{port}/big");
        let err = d.download_capped(&url, &dest, Some(1000)).unwrap_err();
        assert_eq!(err.number, 6);
        assert!(!dest.exists());
        assert!(d.download_capped(&url, &dest, Some(10_000)).is_ok());
        let _ = fs::remove_file(&dest);
    }
}