Skip to main content

studio_worker/engine/
download.rs

1//! Shared model-file provisioning used by every real engine.
2//!
3//! The studio attaches a [`ModelSource`](crate::types::ModelSource) to
4//! each real offer listing the files the worker needs (diffusion model,
5//! GGUF, VAE, ...) with a public URL + filename each.  Engines fetch
6//! them on first use and cache them under their per-engine directory, so
7//! a fresh worker provisions itself with no manual model placement.
8//!
9//! The streamed body is checked against the server's `Content-Length`,
10//! so a truncated download is rejected and cleaned up instead of being
11//! renamed into place as a corrupt model that every later job fails to
12//! load.
13
14use anyhow::{bail, Context, Result};
15use std::path::{Component, Path, PathBuf};
16use std::time::Instant;
17use tracing::{info, warn};
18
/// Tracing target for model downloads.  Stable so operators can filter
/// with `RUST_LOG=studio_worker::engine::download=debug`.  Every log
/// event emitted in this file passes it as the explicit `target:`.
const TRACE_TARGET: &str = "studio_worker::engine::download";

/// HTTP client timeout per request, in seconds (handed to the client
/// builder in [`download_file`]) — a GGUF / safetensors file is up to
/// a few GiB so a 30-minute ceiling is generous.
const DOWNLOAD_TIMEOUT_SECS: u64 = 30 * 60;
26
27/// Resolve `filename` to a path inside `dir`, refusing anything that
28/// is not a plain file name (no `/`, `\`, `..`, or absolute paths) so a
29/// malicious or buggy `ModelSource` can't write outside the cache.
30pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
31    let path = Path::new(filename);
32    let mut components = path.components();
33    match (components.next(), components.next()) {
34        (Some(Component::Normal(name)), None)
35            if !filename.contains('/') && !filename.contains('\\') =>
36        {
37            Ok(dir.join(name))
38        }
39        _ => bail!("model filename must be a plain file name: {filename:?}"),
40    }
41}
42
43/// Verify a streamed download wrote exactly the body the server
44/// promised.  `expected` is the response's `Content-Length`; it is
45/// `None` for chunked transfers, where there's nothing to check and we
46/// accept whatever arrived.  A mismatch in either direction means the
47/// download is truncated or corrupt, so we surface a clear error rather
48/// than cache a bad model.
49pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
50    match expected {
51        Some(expected) if copied != expected => bail!(
52            "size mismatch: wrote {copied} bytes but the server declared \
53             Content-Length {expected} (download truncated or corrupt)"
54        ),
55        _ => Ok(()),
56    }
57}
58
59/// Best-effort removal of a partial `.part` download.  A `NotFound` is
60/// the desired end state (something already cleaned it up); any other
61/// failure is surfaced so a stuck temp file can't silently fill the
62/// worker's disk over a long session.
63pub fn remove_partial(path: &Path) {
64    if let Err(e) = std::fs::remove_file(path) {
65        if e.kind() != std::io::ErrorKind::NotFound {
66            warn!(
67                target: TRACE_TARGET,
68                op = "cleanup",
69                path = %path.display(),
70                error = %e,
71                "failed to remove partial download"
72            );
73        }
74    }
75}
76
77/// Ensure `filename` is present under `dir`, downloading it from `url`
78/// when missing.  Returns the resolved local path.
79#[cfg_attr(coverage_nightly, coverage(off))]
80pub fn ensure_file(dir: &Path, filename: &str, url: &str) -> Result<PathBuf> {
81    let local = model_cache_path(dir, filename)?;
82    if local.is_file() {
83        tracing::debug!(
84            target: TRACE_TARGET,
85            op = "ensure_file",
86            filename,
87            path = %local.display(),
88            "cached"
89        );
90        return Ok(local);
91    }
92    download_file(url, &local)
93        .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
94    Ok(local)
95}
96
97/// Stream `url` into `dest` (atomic via a `.part` rename so a killed
98/// download doesn't leave a half-written file on disk).
99///
100/// Excluded from coverage: requires real network + filesystem (and a
101/// multi-GiB download per model on the happy path).  Exercised
102/// end-to-end via the live dev loop; the pure guards
103/// ([`verify_download_len`], [`model_cache_path`]) are unit-tested.
104#[cfg_attr(coverage_nightly, coverage(off))]
105pub fn download_file(url: &str, dest: &Path) -> Result<()> {
106    if let Some(parent) = dest.parent() {
107        std::fs::create_dir_all(parent)
108            .with_context(|| format!("creating {}", parent.display()))?;
109    }
110    let part = dest.with_extension("part");
111    let client = reqwest::blocking::Client::builder()
112        .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
113        .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
114        .build()?;
115    info!(
116        target: TRACE_TARGET,
117        op = "download",
118        url,
119        dest = %dest.display(),
120        "starting"
121    );
122    let started = Instant::now();
123    let mut response = client.get(url).send().context("GET")?;
124    if !response.status().is_success() {
125        bail!("GET {url} -> {}", response.status());
126    }
127    let expected_len = response.content_length();
128    let mut file =
129        std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
130    let copied = std::io::copy(&mut response, &mut file);
131    // Close the handle before any remove / rename so cleanup works on
132    // Windows, where an open file can't be unlinked.
133    drop(file);
134    let bytes = match copied {
135        Ok(bytes) => bytes,
136        Err(e) => {
137            remove_partial(&part);
138            return Err(e).context("streaming body");
139        }
140    };
141    if let Err(e) = verify_download_len(bytes, expected_len) {
142        remove_partial(&part);
143        return Err(e).with_context(|| format!("downloading {url}"));
144    }
145    std::fs::rename(&part, dest)
146        .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
147    let elapsed_ms = started.elapsed().as_millis() as u64;
148    info!(
149        target: TRACE_TARGET,
150        op = "download",
151        url,
152        dest = %dest.display(),
153        bytes,
154        elapsed_ms,
155        "done"
156    );
157    Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Log line `remove_partial` emits when a removal fails for any
    /// reason other than the file being absent.
    const REMOVE_FAILED_MSG: &str = "failed to remove partial download";

    /// A plain name resolves inside the cache dir; anything path-like
    /// (traversal, nesting, absolute, backslash, empty) is refused.
    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let cache = Path::new("/models");

        let resolved = model_cache_path(cache, "model.gguf").unwrap();
        assert_eq!(resolved, PathBuf::from("/models/model.gguf"));

        let rejected = [
            "../outside.gguf",
            "nested/model.gguf",
            "/tmp/model.gguf",
            r"nested\model.gguf",
            "",
        ];
        for &name in &rejected {
            assert!(model_cache_path(cache, name).is_err());
        }
    }

    /// Equal byte counts pass, including sizes beyond 32 bits.
    #[test]
    fn verify_download_len_accepts_exact_match() {
        let size = 2_700_000_000;
        assert!(verify_download_len(size, Some(size)).is_ok());
    }

    /// With no declared length there is nothing to compare against.
    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        let outcome = verify_download_len(123, None);
        assert!(outcome.is_ok());
    }

    /// A short body is rejected and the message quotes both numbers.
    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let message = verify_download_len(40, Some(100)).unwrap_err().to_string();
        for &needle in &["size mismatch", "40", "100"] {
            assert!(message.contains(needle), "got: {message}");
        }
    }

    /// A body longer than declared is rejected too.
    #[test]
    fn verify_download_len_rejects_overlong_download() {
        let outcome = verify_download_len(120, Some(100));
        assert!(outcome.is_err());
    }

    /// When the file is already present, `ensure_file` hands back its
    /// path untouched and makes no request, so an unreachable URL is
    /// fine here.
    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        let cache = tempdir().unwrap();
        let existing = cache.path().join("cached.gguf");
        std::fs::write(&existing, b"already here").unwrap();

        let returned =
            ensure_file(cache.path(), "cached.gguf", "https://example.invalid/x").unwrap();

        assert_eq!(returned, existing);
        assert_eq!(std::fs::read(&returned).unwrap(), b"already here");
    }

    /// A traversal attempt fails on filename validation, before any
    /// request could be made.
    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let cache = tempdir().unwrap();
        let outcome = ensure_file(cache.path(), "../escape.gguf", "https://example.invalid/x");
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("plain file name"), "got: {message}");
    }

    /// Removing a file that was never there must not log a warning.
    #[test]
    fn remove_partial_ignores_a_missing_file() {
        let scratch = tempdir().unwrap();
        let missing = scratch.path().join("never.part");
        let out = crate::test_support::capture(move || remove_partial(&missing));
        assert!(
            !out.contains(REMOVE_FAILED_MSG),
            "a not-found partial is the desired end state: {out:?}"
        );
    }

    /// A removal that genuinely fails must show up in the logs.
    #[test]
    fn remove_partial_surfaces_a_failed_removal() {
        // `remove_file` refuses to unlink a directory on every platform,
        // which gives us a reliable non-`NotFound` failure.
        let scratch = tempdir().unwrap();
        let stubborn = scratch.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_partial(&stubborn));
        assert!(
            out.contains(REMOVE_FAILED_MSG),
            "a failed removal must surface in the logs: {out:?}"
        );
    }
}