Skip to main content

studio_worker/engine/
download.rs

//! Shared model-file provisioning used by every real engine.
//!
//! The studio attaches a [`ModelSource`](crate::types::ModelSource) to
//! each real offer listing the files the worker needs (diffusion model,
//! GGUF, VAE, ...) with a public URL + filename each.  Engines fetch
//! them on first use and cache them under their per-engine directory, so
//! a fresh worker provisions itself with no manual model placement.
//!
//! The streamed body is checked against the server's `Content-Length`
//! (and against the registry's sha256 when one is pinned), so a
//! truncated or corrupt download is rejected and cleaned up instead of
//! being renamed into place as a bad model that every later job fails
//! to load.
//!
//! Every download emits a structured `tracing` breadcrumb at the
//! `studio_worker::engine::download` target: `info` on `starting` and
//! `done`, and a symmetric `warn` on each failure (a request error, a
//! non-success status, a streaming error, or a length / sha256
//! mismatch) so an operator never sees a dangling `starting` with no
//! terminal event explaining what went wrong — mirroring the
//! `ApiClient` HTTP surface.

21use anyhow::{bail, Context, Result};
22use sha2::{Digest, Sha256};
23use std::io::Write;
24use std::path::{Component, Path, PathBuf};
25use std::time::Instant;
26use tracing::{info, warn};
27
28use crate::types::ModelFile;
29
/// `tracing` target that every breadcrumb in this module is emitted
/// under.  Kept stable so operators can filter on it, e.g.
/// `RUST_LOG=studio_worker::engine::download=debug`.
const TRACE_TARGET: &str = "studio_worker::engine::download";

/// Per-request HTTP timeout for the download client, in seconds
/// (30 minutes).  A GGUF / safetensors model can run to a few GiB, so
/// the ceiling is deliberately generous.
const DOWNLOAD_TIMEOUT_SECS: u64 = 1_800;

38/// Resolve `filename` to a path inside `dir`, refusing anything that
39/// is not a plain file name (no `/`, `\`, `..`, or absolute paths) so a
40/// malicious or buggy `ModelSource` can't write outside the cache.
41pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
42    let path = Path::new(filename);
43    let mut components = path.components();
44    match (components.next(), components.next()) {
45        (Some(Component::Normal(name)), None)
46            if !filename.contains('/') && !filename.contains('\\') =>
47        {
48            Ok(dir.join(name))
49        }
50        _ => bail!("model filename must be a plain file name: {filename:?}"),
51    }
52}
53
54/// Verify a streamed download wrote exactly the body the server
55/// promised.  `expected` is the response's `Content-Length`; it is
56/// `None` for chunked transfers, where there's nothing to check and we
57/// accept whatever arrived.  A mismatch in either direction means the
58/// download is truncated or corrupt, so we surface a clear error rather
59/// than cache a bad model.
60pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
61    match expected {
62        Some(expected) if copied != expected => bail!(
63            "size mismatch: wrote {copied} bytes but the server declared \
64             Content-Length {expected} (download truncated or corrupt)"
65        ),
66        _ => Ok(()),
67    }
68}
69
70/// Verify a downloaded body's sha256 against the registry's expected
71/// hex digest (case-insensitive).  `None` means the registry row
72/// predates integrity hashes — nothing to check.  A mismatch means a
73/// corrupted or tampered body that must never be committed to the
74/// cache.
75pub fn verify_sha256(actual_hex: &str, expected: Option<&str>) -> Result<()> {
76    match expected {
77        Some(expected) if !actual_hex.eq_ignore_ascii_case(expected.trim()) => bail!(
78            "sha256 mismatch: downloaded body hashes to {actual_hex} but the registry \
79             expects {expected} (corrupted or tampered download)"
80        ),
81        _ => Ok(()),
82    }
83}
84
85/// Writer adapter that feeds every chunk through a [`Sha256`] hasher
86/// on its way to the underlying file, so verification needs no second
87/// read pass over a multi-GiB model.
88struct HashingWriter<W: Write> {
89    inner: W,
90    hasher: Sha256,
91}
92
93impl<W: Write> Write for HashingWriter<W> {
94    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
95        let written = self.inner.write(buf)?;
96        self.hasher.update(&buf[..written]);
97        Ok(written)
98    }
99
100    fn flush(&mut self) -> std::io::Result<()> {
101        self.inner.flush()
102    }
103}
104
105/// Best-effort removal of a temporary file — a partial `.part`
106/// download, an engine's per-job scratch image, or a downloaded init /
107/// mask.  A `NotFound` is the desired end state (something already
108/// cleaned it up); any other failure is surfaced so a stuck temp file
109/// can't silently fill the worker's disk over a long session.
110pub fn remove_temp_file(path: &Path) {
111    if let Err(e) = std::fs::remove_file(path) {
112        if e.kind() != std::io::ErrorKind::NotFound {
113            warn!(
114                target: TRACE_TARGET,
115                op = "cleanup",
116                path = %path.display(),
117                error = %e,
118                "failed to remove temp file"
119            );
120        }
121    }
122}
123
124/// RAII owner of a job's scratch files.  Registering a job's temp
125/// paths up front means every exit path — the success return, an
126/// engine error, even a panic mid-dispatch — removes them on drop
127/// instead of leaking them into the temp dir and slowly filling the
128/// worker's disk over a long-running session.  Removal is best-effort
129/// via [`remove_temp_file`], so a path that never materialised (the
130/// job failed before the file was written) is silently tolerated.
131#[derive(Default)]
132pub struct TempFileGuard {
133    paths: Vec<PathBuf>,
134}
135
136impl TempFileGuard {
137    pub fn new() -> Self {
138        Self { paths: Vec::new() }
139    }
140
141    /// Register a path to be removed when the guard drops.
142    pub fn push(&mut self, path: PathBuf) {
143        self.paths.push(path);
144    }
145}
146
147impl Drop for TempFileGuard {
148    fn drop(&mut self) {
149        for path in &self.paths {
150            remove_temp_file(path);
151        }
152    }
153}
154
155/// Ensure `file.filename` is present under `dir`, downloading it from
156/// `file.url` when missing (verified against `file.sha256` when the
157/// registry provides one).  Returns the resolved local path.
158#[cfg_attr(coverage_nightly, coverage(off))]
159pub fn ensure_file(dir: &Path, file: &ModelFile) -> Result<PathBuf> {
160    let filename = file.filename.as_str();
161    let url = file.url.as_str();
162    let local = model_cache_path(dir, filename)?;
163    if local.is_file() {
164        tracing::debug!(
165            target: TRACE_TARGET,
166            op = "ensure_file",
167            filename,
168            path = %local.display(),
169            "cached"
170        );
171        return Ok(local);
172    }
173    download_file_verified(url, &local, file.sha256.as_deref())
174        .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
175    Ok(local)
176}
177
178/// Stream `url` into `dest` (atomic via a `.part` rename so a killed
179/// download doesn't leave a half-written file on disk).
180///
181/// Excluded from coverage: requires real network + filesystem (and a
182/// multi-GiB download per model on the happy path).  Exercised
183/// end-to-end via the live dev loop; the pure guards
184/// ([`verify_download_len`], [`model_cache_path`]) are unit-tested.
185#[cfg_attr(coverage_nightly, coverage(off))]
186pub fn download_file(url: &str, dest: &Path) -> Result<()> {
187    download_file_verified(url, dest, None)
188}
189
190/// [`download_file`] with an optional expected sha256 — the body is
191/// hashed while it streams and a mismatch is rejected before the
192/// rename, so a bad body never lands in the cache.
193#[cfg_attr(coverage_nightly, coverage(off))]
194pub fn download_file_verified(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> {
195    if let Some(parent) = dest.parent() {
196        std::fs::create_dir_all(parent)
197            .with_context(|| format!("creating {}", parent.display()))?;
198    }
199    let part = dest.with_extension("part");
200    let client = reqwest::blocking::Client::builder()
201        .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
202        .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
203        .build()?;
204    info!(
205        target: TRACE_TARGET,
206        op = "download",
207        url,
208        dest = %dest.display(),
209        "starting"
210    );
211    let started = Instant::now();
212    let mut response = match client.get(url).send() {
213        Ok(response) => response,
214        Err(e) => {
215            // A connection-level failure (DNS, TLS, timeout, or a
216            // connection closed before the declared body completed)
217            // must leave the same terminal breadcrumb as the other
218            // failure modes below — otherwise an operator filtering
219            // this target sees the "starting" line then silence.
220            warn!(
221                target: TRACE_TARGET,
222                op = "download",
223                url,
224                dest = %dest.display(),
225                elapsed_ms = started.elapsed().as_millis() as u64,
226                error = %e,
227                "download failed: request error"
228            );
229            return Err(e).context("GET");
230        }
231    };
232    let status = response.status();
233    if !status.is_success() {
234        warn!(
235            target: TRACE_TARGET,
236            op = "download",
237            url,
238            dest = %dest.display(),
239            status = status.as_u16(),
240            elapsed_ms = started.elapsed().as_millis() as u64,
241            "download failed: non-success status"
242        );
243        bail!("GET {url} -> {status}");
244    }
245    let expected_len = response.content_length();
246    let file =
247        std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
248    let mut writer = HashingWriter {
249        inner: file,
250        hasher: Sha256::new(),
251    };
252    let copied = std::io::copy(&mut response, &mut writer);
253    let digest = writer.hasher.finalize();
254    // Close the handle before any remove / rename so cleanup works on
255    // Windows, where an open file can't be unlinked.
256    drop(writer.inner);
257    let bytes = match copied {
258        Ok(bytes) => bytes,
259        Err(e) => {
260            remove_temp_file(&part);
261            warn!(
262                target: TRACE_TARGET,
263                op = "download",
264                url,
265                dest = %dest.display(),
266                elapsed_ms = started.elapsed().as_millis() as u64,
267                error = %e,
268                "download failed: streaming body"
269            );
270            return Err(e).context("streaming body");
271        }
272    };
273    if let Err(e) = verify_download_len(bytes, expected_len) {
274        remove_temp_file(&part);
275        warn!(
276            target: TRACE_TARGET,
277            op = "download",
278            url,
279            dest = %dest.display(),
280            bytes,
281            elapsed_ms = started.elapsed().as_millis() as u64,
282            error = %e,
283            "download failed: size mismatch"
284        );
285        return Err(e).with_context(|| format!("downloading {url}"));
286    }
287    let actual_hex: String = digest.iter().map(|b| format!("{b:02x}")).collect();
288    if let Err(e) = verify_sha256(&actual_hex, expected_sha256) {
289        remove_temp_file(&part);
290        warn!(
291            target: TRACE_TARGET,
292            op = "download",
293            url,
294            dest = %dest.display(),
295            bytes,
296            elapsed_ms = started.elapsed().as_millis() as u64,
297            error = %e,
298            "download failed: sha256 mismatch"
299        );
300        return Err(e).with_context(|| format!("downloading {url}"));
301    }
302    std::fs::rename(&part, dest)
303        .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
304    let elapsed_ms = started.elapsed().as_millis() as u64;
305    info!(
306        target: TRACE_TARGET,
307        op = "download",
308        url,
309        dest = %dest.display(),
310        bytes,
311        elapsed_ms,
312        "done"
313    );
314    Ok(())
315}
316
#[cfg(test)]
mod tests {
    // Unit tests for the pure guards (`model_cache_path`,
    // `verify_download_len`, `verify_sha256`), the streaming hasher and
    // the cleanup helpers.  The network path is not driven from here:
    // `ensure_file` is only exercised on paths that never reach the
    // network.
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        assert_eq!(
            model_cache_path(root, "model.gguf").unwrap(),
            PathBuf::from("/models/model.gguf")
        );
        assert!(model_cache_path(root, "../outside.gguf").is_err());
        assert!(model_cache_path(root, "nested/model.gguf").is_err());
        assert!(model_cache_path(root, "/tmp/model.gguf").is_err());
        assert!(model_cache_path(root, r"nested\model.gguf").is_err());
        assert!(model_cache_path(root, "").is_err());
    }

    #[test]
    fn model_cache_path_rejects_bare_dot_components() {
        // `.` and `..` parse as `CurDir` / `ParentDir`, never `Normal`, so
        // neither can resolve to (or above) the cache directory itself.
        let root = Path::new("/models");
        assert!(model_cache_path(root, ".").is_err());
        assert!(model_cache_path(root, "..").is_err());
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2_700_000_000, Some(2_700_000_000)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_an_empty_body_declared_empty() {
        // `Content-Length: 0` with nothing streamed is an exact match, not
        // a truncation.
        assert!(verify_download_len(0, Some(0)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    #[test]
    fn verify_download_len_rejects_overlong_download() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    fn test_file(filename: &str, url: &str) -> ModelFile {
        ModelFile {
            role: crate::types::ModelFileRole::Model,
            url: url.to_string(),
            filename: filename.to_string(),
            approx_bytes: None,
            sha256: None,
        }
    }

    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        // A file already present must be returned as-is — `ensure_file`
        // never touches the network, so an unreachable URL is fine.
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("cached.gguf"), b"already here").unwrap();
        let path = ensure_file(
            dir.path(),
            &test_file("cached.gguf", "https://example.invalid/x"),
        )
        .unwrap();
        assert_eq!(path, dir.path().join("cached.gguf"));
        assert_eq!(std::fs::read(&path).unwrap(), b"already here");
    }

    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let dir = tempdir().unwrap();
        let err = ensure_file(
            dir.path(),
            &test_file("../escape.gguf", "https://example.invalid/x"),
        )
        .unwrap_err()
        .to_string();
        assert!(err.contains("plain file name"), "got: {err}");
    }

    // -----------------------------------------------------------------
    // verify_sha256 — the integrity gate for registry-pinned hashes.
    // -----------------------------------------------------------------

    #[test]
    fn verify_sha256_accepts_match_and_absence() {
        assert!(verify_sha256("abc123", Some("abc123")).is_ok());
        assert!(
            verify_sha256("abc123", Some("ABC123")).is_ok(),
            "case-insensitive"
        );
        assert!(
            verify_sha256("abc123", Some(" abc123 ")).is_ok(),
            "whitespace-tolerant"
        );
        assert!(
            verify_sha256("abc123", None).is_ok(),
            "legacy rows have no hash"
        );
    }

    #[test]
    fn verify_sha256_rejects_mismatch() {
        let err = verify_sha256("abc123", Some("def456"))
            .unwrap_err()
            .to_string();
        assert!(err.contains("sha256 mismatch"), "got: {err}");
        assert!(
            err.contains("abc123") && err.contains("def456"),
            "must name both digests: {err}"
        );
    }

    #[test]
    fn verify_sha256_rejects_an_empty_pinned_digest() {
        // A blank registry hash is not the same as "no hash".  It can
        // never match a real digest, so the gate fails closed instead of
        // silently skipping verification.
        assert!(verify_sha256("abc123", Some("")).is_err());
        assert!(verify_sha256("abc123", Some("   ")).is_err());
    }

    // -----------------------------------------------------------------
    // HashingWriter — streams the body into the cache file while
    // computing the sha256 that `verify_sha256` later checks.  The
    // integrity guarantee hinges on hashing *exactly* the bytes the
    // inner writer accepted: `write` slices `&buf[..written]`, so a
    // short write (inner takes only a prefix) must hash only that
    // prefix — the unwritten tail is re-offered by `io::copy` on the
    // next call.  Hashing the whole `buf` on a short write would
    // silently corrupt every digest and turn the integrity gate into a
    // false-reject.  The download integration test wraps a real `File`,
    // which never short-writes, so this prefix branch is only reachable
    // here.
    // -----------------------------------------------------------------

    /// A writer that accepts at most `max_per_write` bytes per call (to
    /// model a short write) and counts `flush` calls.
    struct ProbeWriter {
        sink: Vec<u8>,
        max_per_write: usize,
        flushes: usize,
    }

    impl Write for ProbeWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let take = buf.len().min(self.max_per_write);
            self.sink.extend_from_slice(&buf[..take]);
            Ok(take)
        }
        fn flush(&mut self) -> std::io::Result<()> {
            self.flushes += 1;
            Ok(())
        }
    }

    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn hashing_writer_hashes_only_the_bytes_the_inner_accepted() {
        // The inner writer takes only 3 of the 8 offered bytes, so the
        // hasher must absorb just "abc" — proving the `&buf[..written]`
        // slice.  If `write` hashed the whole `buf`, the digest would be
        // sha256("abcdefgh") and this assertion would fail.
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: 3,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        let written = writer.write(b"abcdefgh").unwrap();
        assert_eq!(written, 3, "inner accepts at most 3 bytes per write");
        assert_eq!(writer.inner.sink, b"abc", "only the prefix reaches inner");
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(b"abc")),
            "hash covers only the accepted prefix"
        );
    }

    #[test]
    fn hashing_writer_digest_matches_a_short_writing_stream_end_to_end() {
        // Drive the writer the way `download_file_verified` does — via
        // `io::copy`, which re-offers the unwritten tail — through an
        // inner that only takes 4 bytes at a time.  The streamed bytes
        // and the final digest must both equal the full source, with no
        // double-hashing across the re-offered chunks.
        let source = b"the quick brown model weights".to_vec();
        let mut reader = source.as_slice();
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: 4,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        let copied = std::io::copy(&mut reader, &mut writer).unwrap();
        assert_eq!(copied as usize, source.len());
        assert_eq!(
            writer.inner.sink, source,
            "every byte reaches the cache file"
        );
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(&source)),
            "digest matches the full body"
        );
    }

    #[test]
    fn hashing_writer_flush_delegates_to_the_inner_writer() {
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: usize::MAX,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        writer.flush().unwrap();
        writer.flush().unwrap();
        assert_eq!(writer.inner.flushes, 2, "flush is forwarded to inner");
    }

    // -----------------------------------------------------------------
    // remove_temp_file + TempFileGuard — the shared best-effort cleanup
    // primitives every engine routes its per-job scratch files through.
    // Owned here (the shared engine-provisioning module) so the sdcpp
    // output guard and the onnx init/mask cleanup share one tested
    // implementation instead of each rolling its own silent removal.
    // -----------------------------------------------------------------

    #[test]
    fn remove_temp_file_deletes_an_existing_file_quietly() {
        let dir = tempdir().unwrap();
        let f = dir.path().join("artefact.webp");
        std::fs::write(&f, b"bytes").unwrap();
        let out = crate::test_support::capture({
            let f = f.clone();
            move || remove_temp_file(&f)
        });
        assert!(!f.exists(), "file should be gone after cleanup");
        assert!(
            !out.contains("failed to remove temp file"),
            "the success path must not warn: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_ignores_a_missing_file() {
        let dir = tempdir().unwrap();
        let out = crate::test_support::capture({
            let missing = dir.path().join("never.part");
            move || remove_temp_file(&missing)
        });
        assert!(
            !out.contains("failed to remove temp file"),
            "a not-found temp file is the desired end state: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_surfaces_a_failed_removal() {
        // Pointing the helper at a directory makes `remove_file` fail on
        // every platform (it refuses to unlink a dir): the closest
        // portable stand-in for a locked / permission-denied temp file.
        let dir = tempdir().unwrap();
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_temp_file(&stubborn));
        assert!(
            out.contains("failed to remove temp file"),
            "a failed removal must surface in the logs: {out:?}"
        );
        assert!(
            out.contains("subdir"),
            "the warning must name the offending path: {out:?}"
        );
        assert!(
            out.contains("cleanup"),
            "the warning should tag the cleanup op: {out:?}"
        );
    }

    #[test]
    fn temp_file_guard_removes_every_registered_file_on_drop() {
        let dir = tempdir().unwrap();
        let out = dir.path().join("out.webp");
        let init = dir.path().join("out-init.png");
        std::fs::write(&out, b"image").unwrap();
        std::fs::write(&init, b"init").unwrap();
        {
            let mut guard = TempFileGuard::new();
            guard.push(out.clone());
            guard.push(init.clone());
            assert!(out.exists() && init.exists(), "files present before drop");
        }
        assert!(!out.exists(), "output temp must be removed on drop");
        assert!(!init.exists(), "init-image temp must be removed on drop");
    }

    #[test]
    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
        // A path registered before its download runs (so an early
        // failure drops a guard pointing at a file that never existed)
        // is the desired end state, not a cleanup warning.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never-written.webp");
        let out = crate::test_support::capture(move || {
            let mut guard = TempFileGuard::new();
            guard.push(missing);
            drop(guard);
        });
        assert!(
            !out.contains("failed to remove temp file"),
            "a never-created temp file must not warn on cleanup: {out:?}"
        );
    }
}