Skip to main content

studio_worker/engine/
download.rs

//! Shared model-file provisioning used by every real engine.
//!
//! The studio attaches a [`ModelSource`](crate::types::ModelSource) to
//! each real offer listing the files the worker needs (diffusion model,
//! GGUF, VAE, ...) with a public URL + filename each.  Engines fetch
//! them on first use and cache them under their per-engine directory, so
//! a fresh worker provisions itself with no manual model placement.
//!
//! The streamed body is checked against the server's `Content-Length`,
//! so a truncated download is rejected and cleaned up instead of being
//! renamed into place as a corrupt model that every later job fails to
//! load.
//!
//! Every download emits a structured `tracing` breadcrumb at the
//! `studio_worker::engine::download` target: `info` on `starting` and
//! `done`, and a symmetric `warn` on each failure (non-success status,
//! a streaming error, or a length / sha256 mismatch) so an operator
//! never sees a dangling `starting` with no terminal event explaining
//! what went wrong — mirroring the `ApiClient` HTTP surface.
20
21use anyhow::{bail, Context, Result};
22use sha2::{Digest, Sha256};
23use std::io::Write;
24use std::path::{Component, Path, PathBuf};
25use std::time::Instant;
26use tracing::{info, warn};
27
28use crate::types::ModelFile;
29
/// `tracing` target attached to every event this module emits.  The
/// string is kept fixed so an operator's filter such as
/// `RUST_LOG=studio_worker::engine::download=debug` keeps matching.
const TRACE_TARGET: &str = "studio_worker::engine::download";

/// Per-request timeout handed to the HTTP client, in seconds
/// (30 minutes).  A GGUF / safetensors file runs to a few GiB, so the
/// ceiling is deliberately generous.
const DOWNLOAD_TIMEOUT_SECS: u64 = 60 * 30;
37
38/// Resolve `filename` to a path inside `dir`, refusing anything that
39/// is not a plain file name (no `/`, `\`, `..`, or absolute paths) so a
40/// malicious or buggy `ModelSource` can't write outside the cache.
41pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
42    let path = Path::new(filename);
43    let mut components = path.components();
44    match (components.next(), components.next()) {
45        (Some(Component::Normal(name)), None)
46            if !filename.contains('/') && !filename.contains('\\') =>
47        {
48            Ok(dir.join(name))
49        }
50        _ => bail!("model filename must be a plain file name: {filename:?}"),
51    }
52}
53
54/// Verify a streamed download wrote exactly the body the server
55/// promised.  `expected` is the response's `Content-Length`; it is
56/// `None` for chunked transfers, where there's nothing to check and we
57/// accept whatever arrived.  A mismatch in either direction means the
58/// download is truncated or corrupt, so we surface a clear error rather
59/// than cache a bad model.
60pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
61    match expected {
62        Some(expected) if copied != expected => bail!(
63            "size mismatch: wrote {copied} bytes but the server declared \
64             Content-Length {expected} (download truncated or corrupt)"
65        ),
66        _ => Ok(()),
67    }
68}
69
70/// Verify a downloaded body's sha256 against the registry's expected
71/// hex digest (case-insensitive).  `None` means the registry row
72/// predates integrity hashes — nothing to check.  A mismatch means a
73/// corrupted or tampered body that must never be committed to the
74/// cache.
75pub fn verify_sha256(actual_hex: &str, expected: Option<&str>) -> Result<()> {
76    match expected {
77        Some(expected) if !actual_hex.eq_ignore_ascii_case(expected.trim()) => bail!(
78            "sha256 mismatch: downloaded body hashes to {actual_hex} but the registry \
79             expects {expected} (corrupted or tampered download)"
80        ),
81        _ => Ok(()),
82    }
83}
84
85/// Writer adapter that feeds every chunk through a [`Sha256`] hasher
86/// on its way to the underlying file, so verification needs no second
87/// read pass over a multi-GiB model.
88struct HashingWriter<W: Write> {
89    inner: W,
90    hasher: Sha256,
91}
92
93impl<W: Write> Write for HashingWriter<W> {
94    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
95        let written = self.inner.write(buf)?;
96        self.hasher.update(&buf[..written]);
97        Ok(written)
98    }
99
100    fn flush(&mut self) -> std::io::Result<()> {
101        self.inner.flush()
102    }
103}
104
/// Identify an image's container format from its leading magic bytes.
///
/// Returns the file extension `sd-cli` expects for that format
/// (`jpg`, `png`, `webp`, `gif`, `bmp`, `tif`), or `None` when the
/// bytes match none of them (including input too short to tell).
///
/// `sd-cli`'s `media_io` loader chooses its decoder from the file
/// **extension** alone, so a JPEG saved as `foo.webp` — or a webp saved
/// as `foo.png` — fails with `load image from '...' failed`.  Studio
/// asset URLs such as `latest.webp` frequently carry JPEG bytes, so the
/// worker has to name the on-disk tempfile after the real content.
pub fn sniff_image_extension(bytes: &[u8]) -> Option<&'static str> {
    // Formats recognisable from a fixed prefix alone.  No two entries
    // can match the same input, so table order carries no meaning.
    const PREFIX_SIGNATURES: [(&[u8], &str); 7] = [
        (&[0xff, 0xd8, 0xff], "jpg"),
        (&[0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a], "png"),
        (b"GIF87a", "gif"),
        (b"GIF89a", "gif"),
        (b"BM", "bmp"),
        (&[0x49, 0x49, 0x2a, 0x00], "tif"),
        (&[0x4d, 0x4d, 0x00, 0x2a], "tif"),
    ];

    for &(signature, extension) in PREFIX_SIGNATURES.iter() {
        if bytes.starts_with(signature) {
            return Some(extension);
        }
    }

    // WebP is a RIFF container: the tag at offset 8 tells it apart from
    // other RIFF payloads (e.g. WAV), so a prefix test is not enough.
    let is_webp = bytes.starts_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]);
    if is_webp {
        Some("webp")
    } else {
        None
    }
}
133
134/// Make a downloaded input image (init / mask / reference) safe to hand
135/// to `sd-cli` by naming it after its **actual** content format.
136///
137/// The worker first names the tempfile from the URL's extension, but
138/// studio asset URLs lie (`latest.webp` is frequently JPEG bytes).
139/// `sd-cli` selects its image decoder from the file extension, so a
140/// mismatched name makes every img2img / edit / inpaint job fail with
141/// `load image from '...' failed`.  Here we sniff the real format from
142/// the file's magic bytes and, when it disagrees with the current
143/// extension, rename the file to a sibling with the correct one,
144/// returning the path the engine should consume.  Unknown or
145/// already-correct content passes straight through.
146///
147/// The caller owns cleanup: when the returned path differs from the
148/// input it is the same bytes under a new name, so it (not the
149/// original) must be registered with the job's [`TempFileGuard`].
150pub fn ensure_correct_image_extension(path: &Path) -> Result<PathBuf> {
151    let mut header = [0u8; 16];
152    let read = {
153        use std::io::Read;
154        let mut file = std::fs::File::open(path)
155            .with_context(|| format!("opening input image {}", path.display()))?;
156        file.read(&mut header)
157            .with_context(|| format!("reading input image header {}", path.display()))?
158    };
159    let Some(actual_ext) = sniff_image_extension(&header[..read]) else {
160        return Ok(path.to_path_buf());
161    };
162    let current_ext = path
163        .extension()
164        .and_then(|e| e.to_str())
165        .map(|e| e.to_ascii_lowercase());
166    // `jpeg` and `jpg` are the same decoder to sd-cli — don't churn the
167    // file when only the spelling differs.
168    let matches = current_ext.as_deref() == Some(actual_ext)
169        || (actual_ext == "jpg" && current_ext.as_deref() == Some("jpeg"));
170    if matches {
171        return Ok(path.to_path_buf());
172    }
173    let corrected = path.with_extension(actual_ext);
174    std::fs::rename(path, &corrected)
175        .with_context(|| format!("renaming {} -> {}", path.display(), corrected.display()))?;
176    info!(
177        target: TRACE_TARGET,
178        op = "sniff",
179        from = %path.display(),
180        to = %corrected.display(),
181        actual_ext,
182        "renamed input image to match its actual format for sd-cli"
183    );
184    Ok(corrected)
185}
186
187/// Best-effort removal of a temporary file — a partial `.part`
188/// download, an engine's per-job scratch image, or a downloaded init /
189/// mask.  A `NotFound` is the desired end state (something already
190/// cleaned it up); any other failure is surfaced so a stuck temp file
191/// can't silently fill the worker's disk over a long session.
192pub fn remove_temp_file(path: &Path) {
193    if let Err(e) = std::fs::remove_file(path) {
194        if e.kind() != std::io::ErrorKind::NotFound {
195            warn!(
196                target: TRACE_TARGET,
197                op = "cleanup",
198                path = %path.display(),
199                error = %e,
200                "failed to remove temp file"
201            );
202        }
203    }
204}
205
206/// RAII owner of a job's scratch files.  Registering a job's temp
207/// paths up front means every exit path — the success return, an
208/// engine error, even a panic mid-dispatch — removes them on drop
209/// instead of leaking them into the temp dir and slowly filling the
210/// worker's disk over a long-running session.  Removal is best-effort
211/// via [`remove_temp_file`], so a path that never materialised (the
212/// job failed before the file was written) is silently tolerated.
213#[derive(Default)]
214pub struct TempFileGuard {
215    paths: Vec<PathBuf>,
216}
217
218impl TempFileGuard {
219    pub fn new() -> Self {
220        Self { paths: Vec::new() }
221    }
222
223    /// Register a path to be removed when the guard drops.
224    pub fn push(&mut self, path: PathBuf) {
225        self.paths.push(path);
226    }
227}
228
229impl Drop for TempFileGuard {
230    fn drop(&mut self) {
231        for path in &self.paths {
232            remove_temp_file(path);
233        }
234    }
235}
236
237/// Ensure `file.filename` is present under `dir`, downloading it from
238/// `file.url` when missing (verified against `file.sha256` when the
239/// registry provides one).  Returns the resolved local path.
240#[cfg_attr(coverage_nightly, coverage(off))]
241pub fn ensure_file(dir: &Path, file: &ModelFile) -> Result<PathBuf> {
242    let filename = file.filename.as_str();
243    let url = file.url.as_str();
244    let local = model_cache_path(dir, filename)?;
245    if local.is_file() {
246        tracing::debug!(
247            target: TRACE_TARGET,
248            op = "ensure_file",
249            filename,
250            path = %local.display(),
251            "cached"
252        );
253        return Ok(local);
254    }
255    download_file_verified(url, &local, file.sha256.as_deref())
256        .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
257    Ok(local)
258}
259
260/// Stream `url` into `dest` (atomic via a `.part` rename so a killed
261/// download doesn't leave a half-written file on disk).
262///
263/// Excluded from coverage: requires real network + filesystem (and a
264/// multi-GiB download per model on the happy path).  Exercised
265/// end-to-end via the live dev loop; the pure guards
266/// ([`verify_download_len`], [`model_cache_path`]) are unit-tested.
267#[cfg_attr(coverage_nightly, coverage(off))]
268pub fn download_file(url: &str, dest: &Path) -> Result<()> {
269    download_file_verified(url, dest, None)
270}
271
272/// [`download_file`] with an optional expected sha256 — the body is
273/// hashed while it streams and a mismatch is rejected before the
274/// rename, so a bad body never lands in the cache.
275#[cfg_attr(coverage_nightly, coverage(off))]
276pub fn download_file_verified(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> {
277    if let Some(parent) = dest.parent() {
278        std::fs::create_dir_all(parent)
279            .with_context(|| format!("creating {}", parent.display()))?;
280    }
281    let part = dest.with_extension("part");
282    let client = reqwest::blocking::Client::builder()
283        .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
284        .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
285        .build()?;
286    info!(
287        target: TRACE_TARGET,
288        op = "download",
289        url,
290        dest = %dest.display(),
291        "starting"
292    );
293    let started = Instant::now();
294    let mut response = match client.get(url).send() {
295        Ok(response) => response,
296        Err(e) => {
297            // A connection-level failure (DNS, TLS, timeout, or a
298            // connection closed before the declared body completed)
299            // must leave the same terminal breadcrumb as the other
300            // failure modes below — otherwise an operator filtering
301            // this target sees the "starting" line then silence.
302            warn!(
303                target: TRACE_TARGET,
304                op = "download",
305                url,
306                dest = %dest.display(),
307                elapsed_ms = started.elapsed().as_millis() as u64,
308                error = %e,
309                "download failed: request error"
310            );
311            return Err(e).context("GET");
312        }
313    };
314    let status = response.status();
315    if !status.is_success() {
316        warn!(
317            target: TRACE_TARGET,
318            op = "download",
319            url,
320            dest = %dest.display(),
321            status = status.as_u16(),
322            elapsed_ms = started.elapsed().as_millis() as u64,
323            "download failed: non-success status"
324        );
325        bail!("GET {url} -> {status}");
326    }
327    let expected_len = response.content_length();
328    let file =
329        std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
330    let mut writer = HashingWriter {
331        inner: file,
332        hasher: Sha256::new(),
333    };
334    let copied = std::io::copy(&mut response, &mut writer);
335    let digest = writer.hasher.finalize();
336    // Close the handle before any remove / rename so cleanup works on
337    // Windows, where an open file can't be unlinked.
338    drop(writer.inner);
339    let bytes = match copied {
340        Ok(bytes) => bytes,
341        Err(e) => {
342            remove_temp_file(&part);
343            warn!(
344                target: TRACE_TARGET,
345                op = "download",
346                url,
347                dest = %dest.display(),
348                elapsed_ms = started.elapsed().as_millis() as u64,
349                error = %e,
350                "download failed: streaming body"
351            );
352            return Err(e).context("streaming body");
353        }
354    };
355    if let Err(e) = verify_download_len(bytes, expected_len) {
356        remove_temp_file(&part);
357        warn!(
358            target: TRACE_TARGET,
359            op = "download",
360            url,
361            dest = %dest.display(),
362            bytes,
363            elapsed_ms = started.elapsed().as_millis() as u64,
364            error = %e,
365            "download failed: size mismatch"
366        );
367        return Err(e).with_context(|| format!("downloading {url}"));
368    }
369    let actual_hex: String = digest.iter().map(|b| format!("{b:02x}")).collect();
370    if let Err(e) = verify_sha256(&actual_hex, expected_sha256) {
371        remove_temp_file(&part);
372        warn!(
373            target: TRACE_TARGET,
374            op = "download",
375            url,
376            dest = %dest.display(),
377            bytes,
378            elapsed_ms = started.elapsed().as_millis() as u64,
379            error = %e,
380            "download failed: sha256 mismatch"
381        );
382        return Err(e).with_context(|| format!("downloading {url}"));
383    }
384    std::fs::rename(&part, dest)
385        .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
386    let elapsed_ms = started.elapsed().as_millis() as u64;
387    info!(
388        target: TRACE_TARGET,
389        op = "download",
390        url,
391        dest = %dest.display(),
392        bytes,
393        elapsed_ms,
394        "done"
395    );
396    Ok(())
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // -----------------------------------------------------------------
    // Shared fixtures.
    // -----------------------------------------------------------------

    /// A tiny lossy-VP8 webp (one of the formats studio bases arrive
    /// in), used wherever a test needs genuine webp bytes.
    const LOSSY_WEBP: &[u8] = include_bytes!("../../tests/fixtures/lossy-vp8.webp");

    /// The opening bytes of a JFIF JPEG.
    const JPEG_HEAD: [u8; 6] = [0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10];

    /// The full eight-byte PNG signature.
    const PNG_MAGIC: [u8; 8] = [0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a];

    /// Warning text `remove_temp_file` logs when a removal fails.
    const CLEANUP_WARNING: &str = "failed to remove temp file";

    // -----------------------------------------------------------------
    // sniff_image_extension / ensure_correct_image_extension.
    //
    // sd-cli's `media_io` decoder is keyed on the file extension, and
    // studio asset URLs lie (`latest.webp` is often JPEG bytes).  These
    // tests pin the guard that names a downloaded base after its real
    // content; without it every img2img / edit / inpaint job failed
    // with `load image from '...' failed`.
    // -----------------------------------------------------------------

    #[test]
    fn sniff_image_extension_maps_each_magic_to_an_sd_cli_extension() {
        let cases: [(&[u8], Option<&str>, &str); 9] = [
            (LOSSY_WEBP, Some("webp"), "real webp fixture"),
            (
                &JPEG_HEAD,
                Some("jpg"),
                "JPEG (the bytes studio serves under .webp URLs)",
            ),
            (&PNG_MAGIC, Some("png"), "PNG signature"),
            (b"GIF89a...", Some("gif"), "GIF89a"),
            (b"BM......", Some("bmp"), "BMP"),
            (&[0x49, 0x49, 0x2a, 0x00], Some("tif"), "little-endian TIFF"),
            (
                b"RIFF\x00\x00\x00\x00WAVEfmt ",
                None,
                "a RIFF container that is not WEBP (e.g. a WAV) is not an image",
            ),
            (b"\x00\x01\x02", None, "unknown / too-short content"),
            (b"", None, "empty input"),
        ];
        for &(bytes, expected, why) in cases.iter() {
            assert_eq!(sniff_image_extension(bytes), expected, "{why}");
        }
    }

    #[test]
    fn ensure_correct_image_extension_renames_jpeg_served_as_webp() {
        // Reproduces the production failure: JPEG content stored under
        // a `.webp` name taken from the lying URL.  The file must end
        // up as `.jpg`.
        let scratch = tempdir().unwrap();
        let wrong_name = scratch.path().join("out-init.webp");
        let jpeg_bytes = [0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46];
        std::fs::write(&wrong_name, jpeg_bytes).unwrap();

        let fixed = ensure_correct_image_extension(&wrong_name).unwrap();

        assert_eq!(fixed, scratch.path().join("out-init.jpg"));
        assert!(fixed.exists(), "renamed file carries the bytes");
        assert!(
            !wrong_name.exists(),
            "the misnamed file is gone after rename"
        );
    }

    #[test]
    fn ensure_correct_image_extension_renames_webp_served_as_png() {
        let scratch = tempdir().unwrap();
        let wrong_name = scratch.path().join("out-init.png");
        std::fs::write(&wrong_name, LOSSY_WEBP).unwrap();

        let fixed = ensure_correct_image_extension(&wrong_name).unwrap();

        assert_eq!(fixed, scratch.path().join("out-init.webp"));
        assert!(fixed.exists() && !wrong_name.exists());
    }

    #[test]
    fn ensure_correct_image_extension_leaves_correct_or_unknown_files_in_place() {
        let scratch = tempdir().unwrap();

        // PNG content already named `.png`: same path back, no rename.
        let png = scratch.path().join("out-mask.png");
        std::fs::write(&png, PNG_MAGIC).unwrap();
        assert_eq!(ensure_correct_image_extension(&png).unwrap(), png);
        assert!(png.exists());

        // JPEG content spelled `.jpeg` must not be churned to `.jpg`.
        let jpeg = scratch.path().join("out-ref.jpeg");
        std::fs::write(&jpeg, JPEG_HEAD).unwrap();
        assert_eq!(ensure_correct_image_extension(&jpeg).unwrap(), jpeg);
        assert!(jpeg.exists() && !scratch.path().join("out-ref.jpg").exists());

        // Bytes with no recognised magic are left exactly where they are.
        let unknown = scratch.path().join("out-init.webp");
        std::fs::write(&unknown, [0x00, 0x01, 0x02, 0x03]).unwrap();
        assert_eq!(ensure_correct_image_extension(&unknown).unwrap(), unknown);
        assert!(unknown.exists());
    }

    // -----------------------------------------------------------------
    // model_cache_path — only bare file names may enter the cache.
    // -----------------------------------------------------------------

    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        assert_eq!(
            model_cache_path(root, "model.gguf").unwrap(),
            PathBuf::from("/models/model.gguf")
        );
        let refused = [
            "../outside.gguf",
            "nested/model.gguf",
            "/tmp/model.gguf",
            r"nested\model.gguf",
            "",
        ];
        for name in refused.iter() {
            assert!(
                model_cache_path(root, name).is_err(),
                "{name:?} must be refused"
            );
        }
    }

    // -----------------------------------------------------------------
    // verify_download_len — the Content-Length gate.
    // -----------------------------------------------------------------

    #[test]
    fn verify_download_len_accepts_exact_match() {
        let size = 2_700_000_000;
        assert!(verify_download_len(size, Some(size)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let message = verify_download_len(40, Some(100)).unwrap_err().to_string();
        for fragment in ["size mismatch", "40", "100"].iter() {
            assert!(message.contains(fragment), "got: {message}");
        }
    }

    #[test]
    fn verify_download_len_rejects_overlong_download() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    // -----------------------------------------------------------------
    // ensure_file — the paths that never reach the network.
    // -----------------------------------------------------------------

    /// Build a hash-less `ModelFile` row for `filename` served at `url`.
    fn test_file(filename: &str, url: &str) -> ModelFile {
        ModelFile {
            role: crate::types::ModelFileRole::Model,
            url: url.to_string(),
            filename: filename.to_string(),
            approx_bytes: None,
            sha256: None,
        }
    }

    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        // With the file already on disk `ensure_file` has no reason to
        // go to the network, so an unresolvable URL must not matter.
        let cache = tempdir().unwrap();
        let cached = cache.path().join("cached.gguf");
        std::fs::write(&cached, b"already here").unwrap();

        let row = test_file("cached.gguf", "https://example.invalid/x");
        let resolved = ensure_file(cache.path(), &row).unwrap();

        assert_eq!(resolved, cached);
        assert_eq!(std::fs::read(&resolved).unwrap(), b"already here");
    }

    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let cache = tempdir().unwrap();
        let row = test_file("../escape.gguf", "https://example.invalid/x");
        let message = ensure_file(cache.path(), &row).unwrap_err().to_string();
        assert!(message.contains("plain file name"), "got: {message}");
    }

    // -----------------------------------------------------------------
    // verify_sha256 — the integrity gate for registry-pinned hashes.
    // -----------------------------------------------------------------

    #[test]
    fn verify_sha256_accepts_match_and_absence() {
        let accepted = [
            (Some("abc123"), "exact match"),
            (Some("ABC123"), "case-insensitive"),
            (Some(" abc123 "), "whitespace-tolerant"),
            (None, "legacy rows have no hash"),
        ];
        for &(expected, why) in accepted.iter() {
            assert!(verify_sha256("abc123", expected).is_ok(), "{why}");
        }
    }

    #[test]
    fn verify_sha256_rejects_mismatch() {
        let message = verify_sha256("abc123", Some("def456"))
            .unwrap_err()
            .to_string();
        assert!(message.contains("sha256 mismatch"), "got: {message}");
        assert!(
            message.contains("abc123") && message.contains("def456"),
            "must name both digests: {message}"
        );
    }

    // -----------------------------------------------------------------
    // HashingWriter.
    //
    // The writer streams the body to the cache file while computing the
    // sha256 that `verify_sha256` checks afterwards, so the integrity
    // guarantee rests on hashing *exactly* the bytes the inner writer
    // accepted.  `write` hashes `&buf[..written]`: on a short write only
    // the accepted prefix is hashed, and `io::copy` re-offers the tail
    // on its next call.  Hashing all of `buf` instead would corrupt
    // every digest and turn the gate into a false-reject.  These tests
    // force short writes through a deliberately stingy inner writer.
    // -----------------------------------------------------------------

    /// Inner writer that takes at most `max_per_write` bytes per call
    /// (simulating short writes) and counts how often it is flushed.
    struct ProbeWriter {
        sink: Vec<u8>,
        max_per_write: usize,
        flushes: usize,
    }

    impl Write for ProbeWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let accepted = self.max_per_write.min(buf.len());
            self.sink.extend_from_slice(&buf[..accepted]);
            Ok(accepted)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.flushes += 1;
            Ok(())
        }
    }

    /// A fresh `HashingWriter` over an empty `ProbeWriter` that accepts
    /// at most `max_per_write` bytes per call.
    fn probing_writer(max_per_write: usize) -> HashingWriter<ProbeWriter> {
        HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write,
                flushes: 0,
            },
            hasher: Sha256::new(),
        }
    }

    /// Lower-case hex rendering of `bytes`.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn hashing_writer_hashes_only_the_bytes_the_inner_accepted() {
        // Eight bytes are offered but the inner writer takes three, so
        // the digest must be sha256("abc").  Were the whole buffer
        // hashed it would be sha256("abcdefgh") and the last assertion
        // would fail.
        let mut writer = probing_writer(3);

        let accepted = writer.write(b"abcdefgh").unwrap();

        assert_eq!(accepted, 3, "inner accepts at most 3 bytes per write");
        assert_eq!(writer.inner.sink, b"abc", "only the prefix reaches inner");
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(b"abc")),
            "hash covers only the accepted prefix"
        );
    }

    #[test]
    fn hashing_writer_digest_matches_a_short_writing_stream_end_to_end() {
        // Same driver as `download_file_verified`: `io::copy`, which
        // re-offers whatever a short write left behind.  With an inner
        // that takes four bytes at a time, both the stored bytes and
        // the digest must equal the whole source — nothing dropped and
        // nothing hashed twice.
        let body = b"the quick brown model weights".to_vec();
        let mut input = body.as_slice();
        let mut writer = probing_writer(4);

        let copied = std::io::copy(&mut input, &mut writer).unwrap();

        assert_eq!(copied as usize, body.len());
        assert_eq!(writer.inner.sink, body, "every byte reaches the cache file");
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(&body)),
            "digest matches the full body"
        );
    }

    #[test]
    fn hashing_writer_flush_delegates_to_the_inner_writer() {
        let mut writer = probing_writer(usize::MAX);
        for _ in 0..2 {
            writer.flush().unwrap();
        }
        assert_eq!(writer.inner.flushes, 2, "flush is forwarded to inner");
    }

    // -----------------------------------------------------------------
    // remove_temp_file + TempFileGuard.
    //
    // These are the shared best-effort cleanup primitives engines route
    // their per-job scratch files through.  They live in this shared
    // provisioning module so the sdcpp output guard and the onnx
    // init/mask cleanup rely on one tested implementation rather than
    // each hand-rolling a silent removal.
    // -----------------------------------------------------------------

    #[test]
    fn remove_temp_file_deletes_an_existing_file_quietly() {
        let scratch = tempdir().unwrap();
        let artefact = scratch.path().join("artefact.webp");
        std::fs::write(&artefact, b"bytes").unwrap();

        let target = artefact.clone();
        let logs = crate::test_support::capture(move || remove_temp_file(&target));

        assert!(!artefact.exists(), "file should be gone after cleanup");
        assert!(
            !logs.contains(CLEANUP_WARNING),
            "the success path must not warn: {logs:?}"
        );
    }

    #[test]
    fn remove_temp_file_ignores_a_missing_file() {
        let scratch = tempdir().unwrap();
        let absent = scratch.path().join("never.part");

        let logs = crate::test_support::capture(move || remove_temp_file(&absent));

        assert!(
            !logs.contains(CLEANUP_WARNING),
            "a not-found temp file is the desired end state: {logs:?}"
        );
    }

    #[test]
    fn remove_temp_file_surfaces_a_failed_removal() {
        // `remove_file` refuses to unlink a directory on every platform,
        // which makes a directory the most portable stand-in for a
        // locked or permission-denied temp file.
        let scratch = tempdir().unwrap();
        let stubborn = scratch.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();

        let logs = crate::test_support::capture(move || remove_temp_file(&stubborn));

        assert!(
            logs.contains(CLEANUP_WARNING),
            "a failed removal must surface in the logs: {logs:?}"
        );
        assert!(
            logs.contains("subdir"),
            "the warning must name the offending path: {logs:?}"
        );
        assert!(
            logs.contains("cleanup"),
            "the warning should tag the cleanup op: {logs:?}"
        );
    }

    #[test]
    fn temp_file_guard_removes_every_registered_file_on_drop() {
        let scratch = tempdir().unwrap();
        let output = scratch.path().join("out.webp");
        let init = scratch.path().join("out-init.png");
        std::fs::write(&output, b"image").unwrap();
        std::fs::write(&init, b"init").unwrap();

        let mut guard = TempFileGuard::new();
        guard.push(output.clone());
        guard.push(init.clone());
        assert!(
            output.exists() && init.exists(),
            "files present before drop"
        );
        drop(guard);

        assert!(!output.exists(), "output temp must be removed on drop");
        assert!(!init.exists(), "init-image temp must be removed on drop");
    }

    #[test]
    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
        // A path may be registered before its download runs.  When the
        // job fails early, the guard drops while pointing at a file
        // that never existed — that is the desired end state, not
        // something to warn about.
        let scratch = tempdir().unwrap();
        let never_written = scratch.path().join("never-written.webp");

        let logs = crate::test_support::capture(move || {
            let mut guard = TempFileGuard::new();
            guard.push(never_written);
            drop(guard);
        });

        assert!(
            !logs.contains(CLEANUP_WARNING),
            "a never-created temp file must not warn on cleanup: {logs:?}"
        );
    }
}