1use anyhow::{bail, Context, Result};
22use sha2::{Digest, Sha256};
23use std::io::Write;
24use std::path::{Component, Path, PathBuf};
25use std::time::Instant;
26use tracing::{info, warn};
27
28use crate::types::ModelFile;
29
/// `tracing` target attached to the log events emitted by the functions in this file.
const TRACE_TARGET: &str = "studio_worker::engine::download";

/// Timeout in seconds (30 minutes) handed to the blocking HTTP client that is
/// built for each download.
const DOWNLOAD_TIMEOUT_SECS: u64 = 30 * 60;
37
38pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
42 let path = Path::new(filename);
43 let mut components = path.components();
44 match (components.next(), components.next()) {
45 (Some(Component::Normal(name)), None)
46 if !filename.contains('/') && !filename.contains('\\') =>
47 {
48 Ok(dir.join(name))
49 }
50 _ => bail!("model filename must be a plain file name: {filename:?}"),
51 }
52}
53
54pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
61 match expected {
62 Some(expected) if copied != expected => bail!(
63 "size mismatch: wrote {copied} bytes but the server declared \
64 Content-Length {expected} (download truncated or corrupt)"
65 ),
66 _ => Ok(()),
67 }
68}
69
70pub fn verify_sha256(actual_hex: &str, expected: Option<&str>) -> Result<()> {
76 match expected {
77 Some(expected) if !actual_hex.eq_ignore_ascii_case(expected.trim()) => bail!(
78 "sha256 mismatch: downloaded body hashes to {actual_hex} but the registry \
79 expects {expected} (corrupted or tampered download)"
80 ),
81 _ => Ok(()),
82 }
83}
84
85struct HashingWriter<W: Write> {
89 inner: W,
90 hasher: Sha256,
91}
92
93impl<W: Write> Write for HashingWriter<W> {
94 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
95 let written = self.inner.write(buf)?;
96 self.hasher.update(&buf[..written]);
97 Ok(written)
98 }
99
100 fn flush(&mut self) -> std::io::Result<()> {
101 self.inner.flush()
102 }
103}
104
/// Guesses an image file extension from the leading bytes of a file.
///
/// Recognises JPEG (`"jpg"`), PNG (`"png"`), WebP (`"webp"`), GIF (`"gif"`),
/// BMP (`"bmp"`) and TIFF in either byte order (`"tif"`). Returns `None` when
/// no signature matches, including when `bytes` is too short to hold one.
pub fn sniff_image_extension(bytes: &[u8]) -> Option<&'static str> {
    const JPEG: &[u8] = &[0xff, 0xd8, 0xff];
    const PNG: &[u8] = &[0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a];
    const TIFF_LITTLE_ENDIAN: &[u8] = &[0x49, 0x49, 0x2a, 0x00];
    const TIFF_BIG_ENDIAN: &[u8] = &[0x4d, 0x4d, 0x00, 0x2a];

    let begins_with = |signature: &[u8]| bytes.starts_with(signature);

    if begins_with(JPEG) {
        return Some("jpg");
    }
    if begins_with(PNG) {
        return Some("png");
    }
    // WebP is a RIFF container: "RIFF", four size bytes, then the "WEBP" tag.
    if begins_with(b"RIFF") && bytes.get(8..12) == Some(&b"WEBP"[..]) {
        return Some("webp");
    }
    if begins_with(b"GIF87a") || begins_with(b"GIF89a") {
        return Some("gif");
    }
    if begins_with(b"BM") {
        return Some("bmp");
    }
    if begins_with(TIFF_LITTLE_ENDIAN) || begins_with(TIFF_BIG_ENDIAN) {
        return Some("tif");
    }
    None
}
133
134pub fn ensure_correct_image_extension(path: &Path) -> Result<PathBuf> {
151 let mut header = [0u8; 16];
152 let read = {
153 use std::io::Read;
154 let mut file = std::fs::File::open(path)
155 .with_context(|| format!("opening input image {}", path.display()))?;
156 file.read(&mut header)
157 .with_context(|| format!("reading input image header {}", path.display()))?
158 };
159 let Some(actual_ext) = sniff_image_extension(&header[..read]) else {
160 return Ok(path.to_path_buf());
161 };
162 let current_ext = path
163 .extension()
164 .and_then(|e| e.to_str())
165 .map(|e| e.to_ascii_lowercase());
166 let matches = current_ext.as_deref() == Some(actual_ext)
169 || (actual_ext == "jpg" && current_ext.as_deref() == Some("jpeg"));
170 if matches {
171 return Ok(path.to_path_buf());
172 }
173 let corrected = path.with_extension(actual_ext);
174 std::fs::rename(path, &corrected)
175 .with_context(|| format!("renaming {} -> {}", path.display(), corrected.display()))?;
176 info!(
177 target: TRACE_TARGET,
178 op = "sniff",
179 from = %path.display(),
180 to = %corrected.display(),
181 actual_ext,
182 "renamed input image to match its actual format for sd-cli"
183 );
184 Ok(corrected)
185}
186
187pub fn remove_temp_file(path: &Path) {
193 if let Err(e) = std::fs::remove_file(path) {
194 if e.kind() != std::io::ErrorKind::NotFound {
195 warn!(
196 target: TRACE_TARGET,
197 op = "cleanup",
198 path = %path.display(),
199 error = %e,
200 "failed to remove temp file"
201 );
202 }
203 }
204}
205
206#[derive(Default)]
214pub struct TempFileGuard {
215 paths: Vec<PathBuf>,
216}
217
218impl TempFileGuard {
219 pub fn new() -> Self {
220 Self { paths: Vec::new() }
221 }
222
223 pub fn push(&mut self, path: PathBuf) {
225 self.paths.push(path);
226 }
227}
228
229impl Drop for TempFileGuard {
230 fn drop(&mut self) {
231 for path in &self.paths {
232 remove_temp_file(path);
233 }
234 }
235}
236
237#[cfg_attr(coverage_nightly, coverage(off))]
241pub fn ensure_file(dir: &Path, file: &ModelFile) -> Result<PathBuf> {
242 let filename = file.filename.as_str();
243 let url = file.url.as_str();
244 let local = model_cache_path(dir, filename)?;
245 if local.is_file() {
246 tracing::debug!(
247 target: TRACE_TARGET,
248 op = "ensure_file",
249 filename,
250 path = %local.display(),
251 "cached"
252 );
253 return Ok(local);
254 }
255 download_file_verified(url, &local, file.sha256.as_deref())
256 .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
257 Ok(local)
258}
259
260#[cfg_attr(coverage_nightly, coverage(off))]
268pub fn download_file(url: &str, dest: &Path) -> Result<()> {
269 download_file_verified(url, dest, None)
270}
271
272#[cfg_attr(coverage_nightly, coverage(off))]
276pub fn download_file_verified(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> {
277 if let Some(parent) = dest.parent() {
278 std::fs::create_dir_all(parent)
279 .with_context(|| format!("creating {}", parent.display()))?;
280 }
281 let part = dest.with_extension("part");
282 let client = reqwest::blocking::Client::builder()
283 .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
284 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
285 .build()?;
286 info!(
287 target: TRACE_TARGET,
288 op = "download",
289 url,
290 dest = %dest.display(),
291 "starting"
292 );
293 let started = Instant::now();
294 let mut response = match client.get(url).send() {
295 Ok(response) => response,
296 Err(e) => {
297 warn!(
303 target: TRACE_TARGET,
304 op = "download",
305 url,
306 dest = %dest.display(),
307 elapsed_ms = started.elapsed().as_millis() as u64,
308 error = %e,
309 "download failed: request error"
310 );
311 return Err(e).context("GET");
312 }
313 };
314 let status = response.status();
315 if !status.is_success() {
316 warn!(
317 target: TRACE_TARGET,
318 op = "download",
319 url,
320 dest = %dest.display(),
321 status = status.as_u16(),
322 elapsed_ms = started.elapsed().as_millis() as u64,
323 "download failed: non-success status"
324 );
325 bail!("GET {url} -> {status}");
326 }
327 let expected_len = response.content_length();
328 let file =
329 std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
330 let mut writer = HashingWriter {
331 inner: file,
332 hasher: Sha256::new(),
333 };
334 let copied = std::io::copy(&mut response, &mut writer);
335 let digest = writer.hasher.finalize();
336 drop(writer.inner);
339 let bytes = match copied {
340 Ok(bytes) => bytes,
341 Err(e) => {
342 remove_temp_file(&part);
343 warn!(
344 target: TRACE_TARGET,
345 op = "download",
346 url,
347 dest = %dest.display(),
348 elapsed_ms = started.elapsed().as_millis() as u64,
349 error = %e,
350 "download failed: streaming body"
351 );
352 return Err(e).context("streaming body");
353 }
354 };
355 if let Err(e) = verify_download_len(bytes, expected_len) {
356 remove_temp_file(&part);
357 warn!(
358 target: TRACE_TARGET,
359 op = "download",
360 url,
361 dest = %dest.display(),
362 bytes,
363 elapsed_ms = started.elapsed().as_millis() as u64,
364 error = %e,
365 "download failed: size mismatch"
366 );
367 return Err(e).with_context(|| format!("downloading {url}"));
368 }
369 let actual_hex: String = digest.iter().map(|b| format!("{b:02x}")).collect();
370 if let Err(e) = verify_sha256(&actual_hex, expected_sha256) {
371 remove_temp_file(&part);
372 warn!(
373 target: TRACE_TARGET,
374 op = "download",
375 url,
376 dest = %dest.display(),
377 bytes,
378 elapsed_ms = started.elapsed().as_millis() as u64,
379 error = %e,
380 "download failed: sha256 mismatch"
381 );
382 return Err(e).with_context(|| format!("downloading {url}"));
383 }
384 std::fs::rename(&part, dest)
385 .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
386 let elapsed_ms = started.elapsed().as_millis() as u64;
387 info!(
388 target: TRACE_TARGET,
389 op = "download",
390 url,
391 dest = %dest.display(),
392 bytes,
393 elapsed_ms,
394 "done"
395 );
396 Ok(())
397}
398
/// Unit tests for the download, verification and temp-file helpers above.
/// None of them performs a network request.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Image fixture embedded at compile time; the sniffing tests expect it to
    // be detected as WebP.
    const LOSSY_WEBP: &[u8] = include_bytes!("../../tests/fixtures/lossy-vp8.webp");

    // One positive case per recognised signature, plus inputs that must not match.
    #[test]
    fn sniff_image_extension_maps_each_magic_to_an_sd_cli_extension() {
        assert_eq!(sniff_image_extension(LOSSY_WEBP), Some("webp"));
        assert_eq!(
            sniff_image_extension(&[0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]),
            Some("jpg"),
            "JPEG (the bytes studio serves under .webp URLs)"
        );
        assert_eq!(
            sniff_image_extension(&[0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a]),
            Some("png")
        );
        assert_eq!(sniff_image_extension(b"GIF89a..."), Some("gif"));
        assert_eq!(sniff_image_extension(b"BM......"), Some("bmp"));
        assert_eq!(
            sniff_image_extension(&[0x49, 0x49, 0x2a, 0x00]),
            Some("tif")
        );
        // A RIFF container whose tag is not "WEBP" must not be reported as WebP.
        assert_eq!(sniff_image_extension(b"RIFF\x00\x00\x00\x00WAVEfmt "), None);
        // Unknown leading bytes.
        assert_eq!(sniff_image_extension(b"\x00\x01\x02"), None);
        // Empty input.
        assert_eq!(sniff_image_extension(b""), None);
    }

    // JPEG bytes stored under a `.webp` name are renamed to `.jpg`.
    #[test]
    fn ensure_correct_image_extension_renames_jpeg_served_as_webp() {
        let dir = tempdir().unwrap();
        let mislabelled = dir.path().join("out-init.webp");
        std::fs::write(
            &mislabelled,
            [0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 0x4a, 0x46],
        )
        .unwrap();

        let corrected = ensure_correct_image_extension(&mislabelled).unwrap();

        assert_eq!(corrected, dir.path().join("out-init.jpg"));
        assert!(corrected.exists(), "renamed file carries the bytes");
        assert!(
            !mislabelled.exists(),
            "the misnamed file is gone after rename"
        );
    }

    // WebP bytes stored under a `.png` name are renamed to `.webp`.
    #[test]
    fn ensure_correct_image_extension_renames_webp_served_as_png() {
        let dir = tempdir().unwrap();
        let mislabelled = dir.path().join("out-init.png");
        std::fs::write(&mislabelled, LOSSY_WEBP).unwrap();

        let corrected = ensure_correct_image_extension(&mislabelled).unwrap();

        assert_eq!(corrected, dir.path().join("out-init.webp"));
        assert!(corrected.exists() && !mislabelled.exists());
    }

    // No rename happens when the extension already matches or the format is unknown.
    #[test]
    fn ensure_correct_image_extension_leaves_correct_or_unknown_files_in_place() {
        let dir = tempdir().unwrap();
        let png = dir.path().join("out-mask.png");
        // PNG bytes under a `.png` name: already correct.
        std::fs::write(&png, [0x89, b'P', b'N', b'G', 0x0d, 0x0a, 0x1a, 0x0a]).unwrap();
        assert_eq!(ensure_correct_image_extension(&png).unwrap(), png);
        assert!(png.exists());

        let jpeg = dir.path().join("out-ref.jpeg");
        // `.jpeg` is accepted for JPEG bytes and must not be rewritten to `.jpg`.
        std::fs::write(&jpeg, [0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]).unwrap();
        assert_eq!(ensure_correct_image_extension(&jpeg).unwrap(), jpeg);
        assert!(jpeg.exists() && !dir.path().join("out-ref.jpg").exists());

        let unknown = dir.path().join("out-init.webp");
        // Bytes that match no signature: the file keeps its name.
        std::fs::write(&unknown, [0x00, 0x01, 0x02, 0x03]).unwrap();
        assert_eq!(ensure_correct_image_extension(&unknown).unwrap(), unknown);
        assert!(unknown.exists());
    }

    // Only a bare file name is accepted; traversal, nesting, absolute paths,
    // backslashes and the empty string are all rejected.
    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        assert_eq!(
            model_cache_path(root, "model.gguf").unwrap(),
            PathBuf::from("/models/model.gguf")
        );
        assert!(model_cache_path(root, "../outside.gguf").is_err());
        assert!(model_cache_path(root, "nested/model.gguf").is_err());
        assert!(model_cache_path(root, "/tmp/model.gguf").is_err());
        assert!(model_cache_path(root, r"nested\model.gguf").is_err());
        assert!(model_cache_path(root, "").is_err());
    }

    // Equal lengths pass, including values above the 32-bit range.
    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2_700_000_000, Some(2_700_000_000)).is_ok());
    }

    // Without a declared length there is nothing to compare against.
    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    // A short body is rejected and the message carries both numbers.
    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    // A body longer than declared is rejected as well.
    #[test]
    fn verify_download_len_rejects_overlong_download() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    // Builds a `ModelFile` with the given name and URL and no size or digest.
    fn test_file(filename: &str, url: &str) -> ModelFile {
        ModelFile {
            role: crate::types::ModelFileRole::Model,
            url: url.to_string(),
            filename: filename.to_string(),
            approx_bytes: None,
            sha256: None,
        }
    }

    // A file that already exists in the cache directory is returned untouched.
    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        let dir = tempdir().unwrap();
        // Pre-populate the cache so the download branch is never reached.
        std::fs::write(dir.path().join("cached.gguf"), b"already here").unwrap();
        let path = ensure_file(
            dir.path(),
            &test_file("cached.gguf", "https://example.invalid/x"),
        )
        .unwrap();
        assert_eq!(path, dir.path().join("cached.gguf"));
        assert_eq!(std::fs::read(&path).unwrap(), b"already here");
    }

    // File-name validation fails before the download step is reached.
    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let dir = tempdir().unwrap();
        let err = ensure_file(
            dir.path(),
            &test_file("../escape.gguf", "https://example.invalid/x"),
        )
        .unwrap_err()
        .to_string();
        assert!(err.contains("plain file name"), "got: {err}");
    }

    // Matching digests pass regardless of case or surrounding whitespace, and
    // a missing expected digest disables the check.
    #[test]
    fn verify_sha256_accepts_match_and_absence() {
        assert!(verify_sha256("abc123", Some("abc123")).is_ok());
        assert!(
            verify_sha256("abc123", Some("ABC123")).is_ok(),
            "case-insensitive"
        );
        assert!(
            verify_sha256("abc123", Some(" abc123 ")).is_ok(),
            "whitespace-tolerant"
        );
        assert!(
            verify_sha256("abc123", None).is_ok(),
            "legacy rows have no hash"
        );
    }

    // A differing digest is rejected and the message names both values.
    #[test]
    fn verify_sha256_rejects_mismatch() {
        let err = verify_sha256("abc123", Some("def456"))
            .unwrap_err()
            .to_string();
        assert!(err.contains("sha256 mismatch"), "got: {err}");
        assert!(
            err.contains("abc123") && err.contains("def456"),
            "must name both digests: {err}"
        );
    }

    // Test double for the wrapped writer: accepts at most `max_per_write`
    // bytes per `write` call (forcing short writes) and counts `flush` calls.
    struct ProbeWriter {
        // Every byte accepted so far.
        sink: Vec<u8>,
        // Upper bound on the bytes taken by a single `write` call.
        max_per_write: usize,
        // Number of times `flush` was called.
        flushes: usize,
    }

    impl Write for ProbeWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let take = buf.len().min(self.max_per_write);
            self.sink.extend_from_slice(&buf[..take]);
            Ok(take)
        }
        fn flush(&mut self) -> std::io::Result<()> {
            self.flushes += 1;
            Ok(())
        }
    }

    // Lower-case hex encoding, used to compare digests in assertions.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    // On a short write, only the accepted prefix may be fed to the hasher.
    #[test]
    fn hashing_writer_hashes_only_the_bytes_the_inner_accepted() {
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: 3,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        let written = writer.write(b"abcdefgh").unwrap();
        assert_eq!(written, 3, "inner accepts at most 3 bytes per write");
        assert_eq!(writer.inner.sink, b"abc", "only the prefix reaches inner");
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(b"abc")),
            "hash covers only the accepted prefix"
        );
    }

    // Driven through `std::io::copy` with a short-writing sink, the digest
    // must still equal the digest of the whole source.
    #[test]
    fn hashing_writer_digest_matches_a_short_writing_stream_end_to_end() {
        let source = b"the quick brown model weights".to_vec();
        let mut reader = source.as_slice();
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: 4,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        let copied = std::io::copy(&mut reader, &mut writer).unwrap();
        assert_eq!(copied as usize, source.len());
        assert_eq!(
            writer.inner.sink, source,
            "every byte reaches the cache file"
        );
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(&source)),
            "digest matches the full body"
        );
    }

    // `flush` must reach the wrapped writer once per call.
    #[test]
    fn hashing_writer_flush_delegates_to_the_inner_writer() {
        let mut writer = HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write: usize::MAX,
                flushes: 0,
            },
            hasher: Sha256::new(),
        };
        writer.flush().unwrap();
        writer.flush().unwrap();
        assert_eq!(writer.inner.flushes, 2, "flush is forwarded to inner");
    }

    // NOTE(review): `crate::test_support::capture` is defined outside this
    // file; the tests below assume it runs the closure and returns the log
    // output produced meanwhile as text. Confirm against `test_support`.

    // Removing an existing file succeeds and logs no warning.
    #[test]
    fn remove_temp_file_deletes_an_existing_file_quietly() {
        let dir = tempdir().unwrap();
        let f = dir.path().join("artefact.webp");
        std::fs::write(&f, b"bytes").unwrap();
        let out = crate::test_support::capture({
            let f = f.clone();
            move || remove_temp_file(&f)
        });
        assert!(!f.exists(), "file should be gone after cleanup");
        assert!(
            !out.contains("failed to remove temp file"),
            "the success path must not warn: {out:?}"
        );
    }

    // A path that does not exist is not an error and logs no warning.
    #[test]
    fn remove_temp_file_ignores_a_missing_file() {
        let dir = tempdir().unwrap();
        let out = crate::test_support::capture({
            let missing = dir.path().join("never.part");
            move || remove_temp_file(&missing)
        });
        assert!(
            !out.contains("failed to remove temp file"),
            "a not-found temp file is the desired end state: {out:?}"
        );
    }

    // Any failure other than "not found" must be logged with path and op.
    #[test]
    fn remove_temp_file_surfaces_a_failed_removal() {
        let dir = tempdir().unwrap();
        // The target is a directory, which the test expects `remove_file` to
        // refuse with an error other than "not found".
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_temp_file(&stubborn));
        assert!(
            out.contains("failed to remove temp file"),
            "a failed removal must surface in the logs: {out:?}"
        );
        assert!(
            out.contains("subdir"),
            "the warning must name the offending path: {out:?}"
        );
        assert!(
            out.contains("cleanup"),
            "the warning should tag the cleanup op: {out:?}"
        );
    }

    // Every path pushed onto the guard is removed when the guard goes out of scope.
    #[test]
    fn temp_file_guard_removes_every_registered_file_on_drop() {
        let dir = tempdir().unwrap();
        let out = dir.path().join("out.webp");
        let init = dir.path().join("out-init.png");
        std::fs::write(&out, b"image").unwrap();
        std::fs::write(&init, b"init").unwrap();
        {
            let mut guard = TempFileGuard::new();
            guard.push(out.clone());
            guard.push(init.clone());
            assert!(out.exists() && init.exists(), "files present before drop");
        }
        assert!(!out.exists(), "output temp must be removed on drop");
        assert!(!init.exists(), "init-image temp must be removed on drop");
    }

    // A registered path that was never created must not produce a warning on drop.
    #[test]
    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never-written.webp");
        let out = crate::test_support::capture(move || {
            let mut guard = TempFileGuard::new();
            guard.push(missing);
            drop(guard);
        });
        assert!(
            !out.contains("failed to remove temp file"),
            "a never-created temp file must not warn on cleanup: {out:?}"
        );
    }
}