Skip to main content

vanta_net/
lib.rs

1//! `vanta-net` — HTTP downloads over rustls: resumable, retrying, mirror-aware.
2//!
3//! The installer fetches artifacts through a [`Downloader`]; bytes stream to a
4//! `<dest>.part` file (resumed via HTTP range on retry) and are atomically
5//! renamed into place on success. Verification (checksum/signature) is the
6//! caller's responsibility (`vanta-security` / `vanta-store`); this crate only
7//! moves bytes. Parallelism is provided by the installer running downloads on
8//! worker threads (`docs/08-installation.md`). See `docs/16-performance.md`.
9#![forbid(unsafe_code)]
10
11use reqwest::blocking::Client;
12use reqwest::header::RANGE;
13use reqwest::StatusCode;
14use std::fs;
15use std::io::Read;
16use std::path::{Path, PathBuf};
17use std::time::Duration;
18use vanta_core::{Area, VtaError, VtaResult};
19
/// Upper bound on the length of a redirect chain the HTTP client will follow
/// before the request is failed (audit M6).
const MAX_REDIRECTS: usize = 10;
22
/// Progress callback type: called with the size, in bytes, of each chunk as it
/// is read from the response while a download streams in. The values are
/// increments (not running totals), so a caller can feed them straight into a
/// progress bar without this crate depending on a UI crate.
pub type ProgressFn<'a> = dyn Fn(u64) + 'a;
27
28/// A reusable HTTP downloader.
29pub struct Downloader {
30    client: Client,
31    retries: u32,
32    /// When `true`, plaintext `http://` URLs to non-loopback hosts are permitted
33    /// (audit M6 insecure opt-in). Default `false`: only `https://` (and
34    /// loopback `http://`, for local dev/test servers) is allowed.
35    allow_http: bool,
36}
37
38impl Downloader {
39    /// Build a secure downloader (TLS via rustls, connect timeout). Plaintext
40    /// `http://` to non-loopback hosts is rejected and `https→http` downgrade
41    /// redirects are refused.
42    pub fn new() -> VtaResult<Downloader> {
43        Self::build(false)
44    }
45
46    /// Build a downloader that additionally permits plaintext `http://` to any
47    /// host. This is the **dangerous** insecure opt-in (audit M6/C1): callers
48    /// must surface it to the operator. `https→http` downgrade redirects are
49    /// still refused.
50    pub fn insecure() -> VtaResult<Downloader> {
51        Self::build(true)
52    }
53
54    fn build(allow_http: bool) -> VtaResult<Downloader> {
55        // Redirect policy (M6): cap the chain and never follow an https→http
56        // downgrade, regardless of `allow_http`.
57        let redirect = reqwest::redirect::Policy::custom(|attempt| {
58            if attempt.previous().len() >= MAX_REDIRECTS {
59                attempt.error("too many redirects")
60            } else if attempt.url().scheme() == "http"
61                && attempt
62                    .previous()
63                    .last()
64                    .map(|u| u.scheme() == "https")
65                    .unwrap_or(false)
66            {
67                attempt.stop()
68            } else {
69                attempt.follow()
70            }
71        });
72        // NOTE: we intentionally do NOT call reqwest's `.https_only(true)` here.
73        // That would reject *all* `http://`, including the loopback dev/test
74        // servers Vanta's own integration tests (and local mirrors) rely on.
75        // Instead TLS is enforced per-request in `scheme_ok` (reject plaintext to
76        // non-loopback hosts) and the custom redirect policy above refuses any
77        // https→http downgrade — together giving the same guarantee for real
78        // hosts without breaking loopback.
79        let client = Client::builder()
80            .user_agent(concat!("vanta/", env!("CARGO_PKG_VERSION")))
81            .connect_timeout(Duration::from_secs(30))
82            .redirect(redirect)
83            .build()
84            .map_err(|e| VtaError::new(Area::Net, 4, format!("building HTTP client: {e}")))?;
85        Ok(Downloader {
86            client,
87            retries: 3,
88            allow_http,
89        })
90    }
91
92    /// Override the per-URL retry count (default 3).
93    pub fn with_retries(mut self, retries: u32) -> Self {
94        self.retries = retries;
95        self
96    }
97
98    /// Download `url` into `dest`, resuming a partial `<dest>.part` if present and
99    /// retrying transient failures with backoff. No size ceiling.
100    pub fn download(&self, url: &str, dest: &Path) -> VtaResult<()> {
101        self.download_capped(url, dest, None)
102    }
103
104    /// Like [`Downloader::download`] but aborts with an error if more than
105    /// `max` bytes would be written (audit M8). `None` means no ceiling.
106    pub fn download_capped(&self, url: &str, dest: &Path, max: Option<u64>) -> VtaResult<()> {
107        self.download_capped_with_progress(url, dest, max, None)
108    }
109
110    /// Like [`Downloader::download_capped`], but reports streamed bytes through
111    /// `progress` (e.g. to advance a progress bar). Used for the registry index
112    /// download where the total length is not known in advance.
113    pub fn download_capped_with_progress(
114        &self,
115        url: &str,
116        dest: &Path,
117        max: Option<u64>,
118        progress: Option<&ProgressFn>,
119    ) -> VtaResult<()> {
120        self.scheme_ok(url)?;
121        let mut last: Option<VtaError> = None;
122        for attempt in 0..=self.retries {
123            match self.fetch_one(url, dest, max, progress) {
124                Ok(()) => return Ok(()),
125                Err(e) => {
126                    last = Some(e);
127                    if attempt < self.retries {
128                        std::thread::sleep(backoff(attempt));
129                    }
130                }
131            }
132        }
133        Err(last.unwrap_or_else(|| VtaError::new(Area::Net, 1, format!("download failed: {url}"))))
134    }
135
136    /// Try a primary URL then mirrors/alternates in order, returning on the first
137    /// success. A mirror that serves wrong bytes is caught by the caller's hash
138    /// verification, so falling through mirrors is safe (`docs/13-offline.md`).
139    /// `max`, when set, caps the bytes accepted from any single URL (M8).
140    pub fn download_any(&self, urls: &[String], dest: &Path, max: Option<u64>) -> VtaResult<()> {
141        self.download_any_with_progress(urls, dest, max, None)
142    }
143
144    /// Like [`Downloader::download_any`], but reports streamed bytes through
145    /// `progress` so the caller can render a download bar (bytes/total/ETA).
146    pub fn download_any_with_progress(
147        &self,
148        urls: &[String],
149        dest: &Path,
150        max: Option<u64>,
151        progress: Option<&ProgressFn>,
152    ) -> VtaResult<()> {
153        let mut last: Option<VtaError> = None;
154        for url in urls {
155            // L11: never resume across a mirror switch — a stale `.part` from a
156            // previous host would otherwise be concatenated with this host's
157            // bytes. Start each URL from a clean slate.
158            let _ = fs::remove_file(part_path(dest));
159            match self.download_capped_with_progress(url, dest, max, progress) {
160                Ok(()) => return Ok(()),
161                Err(e) => last = Some(e),
162            }
163        }
164        Err(last.unwrap_or_else(|| {
165            VtaError::new(Area::Net, 1, "no URLs supplied to download_any".to_string())
166        }))
167    }
168
169    /// Reject plaintext `http://` to non-loopback hosts unless the insecure
170    /// opt-in is set (audit M6).
171    fn scheme_ok(&self, url: &str) -> VtaResult<()> {
172        if let Some(rest) = url.strip_prefix("http://") {
173            if !self.allow_http && !is_loopback_authority(rest) {
174                return Err(VtaError::new(
175                    Area::Net,
176                    5,
177                    format!(
178                        "refusing plaintext http:// download of {url} \
179                         (https required; set the insecure opt-in to override)"
180                    ),
181                ));
182            }
183        }
184        Ok(())
185    }
186
187    fn fetch_one(
188        &self,
189        url: &str,
190        dest: &Path,
191        max: Option<u64>,
192        progress: Option<&ProgressFn>,
193    ) -> VtaResult<()> {
194        let part = part_path(dest);
195        let have = fs::metadata(&part).map(|m| m.len()).unwrap_or(0);
196
197        let mut req = self.client.get(url);
198        if have > 0 {
199            req = req.header(RANGE, format!("bytes={have}-"));
200        }
201        let mut resp = req
202            .send()
203            .map_err(|e| VtaError::new(Area::Net, 1, format!("requesting {url}: {e}")))?;
204
205        let status = resp.status();
206        let resuming = have > 0 && status == StatusCode::PARTIAL_CONTENT;
207        if !(status.is_success() || resuming) {
208            return Err(VtaError::new(
209                Area::Net,
210                1,
211                format!("HTTP {status} for {url}"),
212            ));
213        }
214
215        // M8: enforce the declared size as a hard ceiling on total bytes.
216        let remaining =
217            match max {
218                Some(m) => Some(m.checked_sub(if resuming { have } else { 0 }).ok_or_else(
219                    || VtaError::new(Area::Net, 6, format!("download of {url} exceeds size cap")),
220                )?),
221                None => None,
222            };
223
224        if let Some(parent) = part.parent() {
225            fs::create_dir_all(parent).map_err(|e| io(parent, e))?;
226        }
227        let mut file = if resuming {
228            fs::OpenOptions::new()
229                .append(true)
230                .open(&part)
231                .map_err(|e| io(&part, e))?
232        } else {
233            let _ = fs::remove_file(&part);
234            fs::File::create(&part).map_err(|e| io(&part, e))?
235        };
236
237        // Wrap the response so each read chunk drives the progress callback.
238        let mut src = ProgressReader::new(&mut resp, progress);
239        let written = match remaining {
240            // Read one byte past the limit so we can detect an oversize body.
241            Some(limit) => {
242                let mut limited = (&mut src).take(limit.saturating_add(1));
243                let n = std::io::copy(&mut limited, &mut file).map_err(|e| {
244                    VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display()))
245                })?;
246                if n > limit {
247                    let _ = fs::remove_file(&part);
248                    return Err(VtaError::new(
249                        Area::Net,
250                        6,
251                        format!("download of {url} exceeds declared size {limit} bytes"),
252                    ));
253                }
254                n
255            }
256            None => std::io::copy(&mut src, &mut file).map_err(|e| {
257                VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display()))
258            })?,
259        };
260        let _ = written;
261        file.sync_all().ok();
262        fs::rename(&part, dest).map_err(|e| io(dest, e))?;
263        Ok(())
264    }
265}
266
267/// A `Read` adapter that reports each chunk's byte count through a progress
268/// callback as bytes stream through it. The size-cap and verification logic is
269/// unaffected — this only observes the bytes already being read.
270struct ProgressReader<'a, R> {
271    inner: R,
272    progress: Option<&'a ProgressFn<'a>>,
273}
274
275impl<'a, R> ProgressReader<'a, R> {
276    fn new(inner: R, progress: Option<&'a ProgressFn<'a>>) -> Self {
277        ProgressReader { inner, progress }
278    }
279}
280
281impl<R: Read> Read for ProgressReader<'_, R> {
282    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
283        let n = self.inner.read(buf)?;
284        if n > 0 {
285            if let Some(cb) = self.progress {
286                cb(n as u64);
287            }
288        }
289        Ok(n)
290    }
291}
292
/// Whether the authority part of an `http://` URL (everything after the scheme)
/// names a loopback host. Used to keep local dev/test servers usable while still
/// rejecting plaintext to public hosts.
///
/// `rest` may carry userinfo, a port and a trailing path/query/fragment; a bare
/// host (including a bracketed IPv6 literal) is accepted too. A host counts as
/// loopback only when it is `localhost` (any ASCII case, optional trailing
/// dot), an IPv4 literal in `127.0.0.0/8`, or the IPv6 literal `::1`. IP hosts
/// are parsed rather than prefix-matched, so look-alike domain names such as
/// `127.0.0.1.evil.com` are not mistaken for loopback.
fn is_loopback_authority(rest: &str) -> bool {
    // The authority ends at the first '/', '?' or '#'. URL parsers also treat
    // '\' as '/' for http(s), so it must end the authority here as well —
    // otherwise `evil.com\@127.0.0.1` would be read as host `127.0.0.1`.
    let authority = rest.split(['/', '\\', '?', '#']).next().unwrap_or(rest);
    // Strip userinfo: the host follows the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    if let Some(bracketed) = host_port.strip_prefix('[') {
        // IPv6 literal: [::1]:port
        return bracketed
            .split(']')
            .next()
            .and_then(|literal| literal.parse::<std::net::Ipv6Addr>().ok())
            .map_or(false, |ip| ip.is_loopback());
    }

    let host = host_port.split(':').next().unwrap_or(host_port);
    // A single trailing dot marks a fully-qualified name; ignore it.
    let host = host.strip_suffix('.').unwrap_or(host);
    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::Ipv4Addr>()
            .map_or(false, |ip| ip.is_loopback())
}
313
/// Path of the in-progress file for `dest`: the same path with `.part`
/// appended to its final component (any existing extension is kept).
fn part_path(dest: &Path) -> PathBuf {
    let mut name = std::ffi::OsString::from(dest.as_os_str());
    name.push(".part");
    name.into()
}
319
/// Delay to wait after failed attempt number `attempt` (zero-based): half a
/// second doubled per attempt — 0.5s, 1s, 2s, 4s — and capped at 8s from the
/// fifth attempt onwards.
fn backoff(attempt: u32) -> Duration {
    let doublings = attempt.min(4);
    Duration::from_millis(500u64 << doublings)
}
325
326fn io(path: &Path, e: std::io::Error) -> VtaError {
327    VtaError::new(Area::Net, 1, format!("{}: {e}", path.display()))
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// The default (secure) downloader can be constructed.
    #[test]
    fn client_builds() {
        let built = Downloader::new();
        assert!(built.is_ok());
    }

    /// `.part` is appended after the existing extension.
    #[test]
    fn part_path_appends_suffix() {
        let expected = PathBuf::from("/tmp/a.bin.part");
        assert_eq!(part_path(Path::new("/tmp/a.bin")), expected);
    }

    /// An empty URL list is an error rather than a silent success.
    #[test]
    fn download_any_empty_errors() {
        let downloader = Downloader::new().unwrap();
        let outcome = downloader.download_any(&[], Path::new("/tmp/none"), None);
        assert!(outcome.is_err());
    }

    /// M6: a secure downloader refuses http:// to a non-loopback host before
    /// any network I/O.
    #[test]
    fn rejects_plaintext_http_scheme() {
        let downloader = Downloader::new().unwrap();
        let target = Path::new("/tmp/should-not-write");
        let err = downloader
            .download("http://example.org/x", target)
            .unwrap_err();
        assert_eq!(err.area, Area::Net);
        assert_eq!(err.number, 5);
        // https passes the scheme gate (it would fail later on the network,
        // but not on scheme).
        assert!(downloader.scheme_ok("https://example.org/x").is_ok());
    }

    /// Local dev/test servers serve plaintext on loopback — permitted.
    #[test]
    fn loopback_http_is_allowed_scheme() {
        for allowed in ["127.0.0.1:8080/x", "localhost/x", "[::1]:9/x"] {
            assert!(is_loopback_authority(allowed));
        }
        for refused in ["example.org/x", "127x.evil.com/x"] {
            assert!(!is_loopback_authority(refused));
        }
    }

    /// The insecure opt-in lets plaintext to a public host through the gate.
    #[test]
    fn insecure_allows_http() {
        let downloader = Downloader::insecure().unwrap();
        assert!(downloader.scheme_ok("http://example.org/x").is_ok());
    }

    /// M8: a body larger than the declared ceiling is rejected.
    #[test]
    fn size_cap_aborts_oversize_download() {
        use std::collections::HashMap;
        let mut served = HashMap::new();
        served.insert("/big".to_string(), vec![0u8; 10_000]);
        let port = vanta_test::serve(served);
        let url = format!("http://127.0.0.1:{port}/big");

        let file_name = format!("vanta-net-cap-{}.bin", std::process::id());
        let dest = std::env::temp_dir().join(file_name);
        let _ = fs::remove_file(&dest);

        let downloader = Downloader::new().unwrap();
        // Cap below the body size → error, no file published.
        let too_small = downloader.download_capped(&url, &dest, Some(1000));
        assert_eq!(too_small.unwrap_err().number, 6);
        assert!(!dest.exists());
        // Cap at/above the body size → succeeds.
        let exact = downloader.download_capped(&url, &dest, Some(10_000));
        assert!(exact.is_ok());
        let _ = fs::remove_file(&dest);
    }
}