1use anyhow::{bail, Context, Result};
22use sha2::{Digest, Sha256};
23use std::io::Write;
24use std::path::{Component, Path, PathBuf};
25use std::time::Instant;
26use tracing::{info, warn};
27
28use crate::types::ModelFile;
29
/// `tracing` target attached to every event emitted by the download helpers in this file.
const TRACE_TARGET: &str = "studio_worker::engine::download";
33
/// Timeout, in seconds, configured on the HTTP client used for downloads (30 minutes).
const DOWNLOAD_TIMEOUT_SECS: u64 = 30 * 60;
37
38pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
42 let path = Path::new(filename);
43 let mut components = path.components();
44 match (components.next(), components.next()) {
45 (Some(Component::Normal(name)), None)
46 if !filename.contains('/') && !filename.contains('\\') =>
47 {
48 Ok(dir.join(name))
49 }
50 _ => bail!("model filename must be a plain file name: {filename:?}"),
51 }
52}
53
54pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
61 match expected {
62 Some(expected) if copied != expected => bail!(
63 "size mismatch: wrote {copied} bytes but the server declared \
64 Content-Length {expected} (download truncated or corrupt)"
65 ),
66 _ => Ok(()),
67 }
68}
69
70pub fn verify_sha256(actual_hex: &str, expected: Option<&str>) -> Result<()> {
76 match expected {
77 Some(expected) if !actual_hex.eq_ignore_ascii_case(expected.trim()) => bail!(
78 "sha256 mismatch: downloaded body hashes to {actual_hex} but the registry \
79 expects {expected} (corrupted or tampered download)"
80 ),
81 _ => Ok(()),
82 }
83}
84
85struct HashingWriter<W: Write> {
89 inner: W,
90 hasher: Sha256,
91}
92
93impl<W: Write> Write for HashingWriter<W> {
94 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
95 let written = self.inner.write(buf)?;
96 self.hasher.update(&buf[..written]);
97 Ok(written)
98 }
99
100 fn flush(&mut self) -> std::io::Result<()> {
101 self.inner.flush()
102 }
103}
104
105pub fn remove_temp_file(path: &Path) {
111 if let Err(e) = std::fs::remove_file(path) {
112 if e.kind() != std::io::ErrorKind::NotFound {
113 warn!(
114 target: TRACE_TARGET,
115 op = "cleanup",
116 path = %path.display(),
117 error = %e,
118 "failed to remove temp file"
119 );
120 }
121 }
122}
123
124#[derive(Default)]
132pub struct TempFileGuard {
133 paths: Vec<PathBuf>,
134}
135
136impl TempFileGuard {
137 pub fn new() -> Self {
138 Self { paths: Vec::new() }
139 }
140
141 pub fn push(&mut self, path: PathBuf) {
143 self.paths.push(path);
144 }
145}
146
147impl Drop for TempFileGuard {
148 fn drop(&mut self) {
149 for path in &self.paths {
150 remove_temp_file(path);
151 }
152 }
153}
154
155#[cfg_attr(coverage_nightly, coverage(off))]
159pub fn ensure_file(dir: &Path, file: &ModelFile) -> Result<PathBuf> {
160 let filename = file.filename.as_str();
161 let url = file.url.as_str();
162 let local = model_cache_path(dir, filename)?;
163 if local.is_file() {
164 tracing::debug!(
165 target: TRACE_TARGET,
166 op = "ensure_file",
167 filename,
168 path = %local.display(),
169 "cached"
170 );
171 return Ok(local);
172 }
173 download_file_verified(url, &local, file.sha256.as_deref())
174 .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
175 Ok(local)
176}
177
/// Downloads `url` to `dest` without a SHA-256 check.
///
/// Thin wrapper over [`download_file_verified`] with no expected digest.
#[cfg_attr(coverage_nightly, coverage(off))]
pub fn download_file(url: &str, dest: &Path) -> Result<()> {
    download_file_verified(url, dest, None)
}
189
190#[cfg_attr(coverage_nightly, coverage(off))]
194pub fn download_file_verified(url: &str, dest: &Path, expected_sha256: Option<&str>) -> Result<()> {
195 if let Some(parent) = dest.parent() {
196 std::fs::create_dir_all(parent)
197 .with_context(|| format!("creating {}", parent.display()))?;
198 }
199 let part = dest.with_extension("part");
200 let client = reqwest::blocking::Client::builder()
201 .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
202 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
203 .build()?;
204 info!(
205 target: TRACE_TARGET,
206 op = "download",
207 url,
208 dest = %dest.display(),
209 "starting"
210 );
211 let started = Instant::now();
212 let mut response = match client.get(url).send() {
213 Ok(response) => response,
214 Err(e) => {
215 warn!(
221 target: TRACE_TARGET,
222 op = "download",
223 url,
224 dest = %dest.display(),
225 elapsed_ms = started.elapsed().as_millis() as u64,
226 error = %e,
227 "download failed: request error"
228 );
229 return Err(e).context("GET");
230 }
231 };
232 let status = response.status();
233 if !status.is_success() {
234 warn!(
235 target: TRACE_TARGET,
236 op = "download",
237 url,
238 dest = %dest.display(),
239 status = status.as_u16(),
240 elapsed_ms = started.elapsed().as_millis() as u64,
241 "download failed: non-success status"
242 );
243 bail!("GET {url} -> {status}");
244 }
245 let expected_len = response.content_length();
246 let file =
247 std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
248 let mut writer = HashingWriter {
249 inner: file,
250 hasher: Sha256::new(),
251 };
252 let copied = std::io::copy(&mut response, &mut writer);
253 let digest = writer.hasher.finalize();
254 drop(writer.inner);
257 let bytes = match copied {
258 Ok(bytes) => bytes,
259 Err(e) => {
260 remove_temp_file(&part);
261 warn!(
262 target: TRACE_TARGET,
263 op = "download",
264 url,
265 dest = %dest.display(),
266 elapsed_ms = started.elapsed().as_millis() as u64,
267 error = %e,
268 "download failed: streaming body"
269 );
270 return Err(e).context("streaming body");
271 }
272 };
273 if let Err(e) = verify_download_len(bytes, expected_len) {
274 remove_temp_file(&part);
275 warn!(
276 target: TRACE_TARGET,
277 op = "download",
278 url,
279 dest = %dest.display(),
280 bytes,
281 elapsed_ms = started.elapsed().as_millis() as u64,
282 error = %e,
283 "download failed: size mismatch"
284 );
285 return Err(e).with_context(|| format!("downloading {url}"));
286 }
287 let actual_hex: String = digest.iter().map(|b| format!("{b:02x}")).collect();
288 if let Err(e) = verify_sha256(&actual_hex, expected_sha256) {
289 remove_temp_file(&part);
290 warn!(
291 target: TRACE_TARGET,
292 op = "download",
293 url,
294 dest = %dest.display(),
295 bytes,
296 elapsed_ms = started.elapsed().as_millis() as u64,
297 error = %e,
298 "download failed: sha256 mismatch"
299 );
300 return Err(e).with_context(|| format!("downloading {url}"));
301 }
302 std::fs::rename(&part, dest)
303 .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
304 let elapsed_ms = started.elapsed().as_millis() as u64;
305 info!(
306 target: TRACE_TARGET,
307 op = "download",
308 url,
309 dest = %dest.display(),
310 bytes,
311 elapsed_ms,
312 "done"
313 );
314 Ok(())
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A bare file name resolves inside the root; anything carrying a
    /// directory part, either separator style, or no name at all is rejected.
    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        let rejected = |name: &str| model_cache_path(root, name).is_err();

        let resolved = model_cache_path(root, "model.gguf").unwrap();
        assert_eq!(resolved, PathBuf::from("/models/model.gguf"));

        assert!(rejected("../outside.gguf"));
        assert!(rejected("nested/model.gguf"));
        assert!(rejected("/tmp/model.gguf"));
        assert!(rejected(r"nested\model.gguf"));
        assert!(rejected(""));
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        let size: u64 = 2_700_000_000;
        assert!(verify_download_len(size, Some(size)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        let outcome = verify_download_len(123, None);
        assert!(outcome.is_ok());
    }

    /// The error must mention the mismatch and both byte counts.
    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        for needle in ["size mismatch", "40", "100"] {
            assert!(err.contains(needle), "got: {err}");
        }
    }

    #[test]
    fn verify_download_len_rejects_overlong_download() {
        let outcome = verify_download_len(120, Some(100));
        assert!(outcome.is_err());
    }

    /// Builds a `ModelFile` with no size hint and no expected digest.
    fn test_file(filename: &str, url: &str) -> ModelFile {
        ModelFile {
            role: crate::types::ModelFileRole::Model,
            url: url.to_owned(),
            filename: filename.to_owned(),
            approx_bytes: None,
            sha256: None,
        }
    }

    /// A file already in the cache directory is returned untouched; the URL
    /// uses the `.invalid` TLD, so a network attempt could not succeed.
    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        let dir = tempdir().unwrap();
        let cached = dir.path().join("cached.gguf");
        std::fs::write(&cached, b"already here").unwrap();

        let file = test_file("cached.gguf", "https://example.invalid/x");
        let path = ensure_file(dir.path(), &file).unwrap();

        assert_eq!(path, cached);
        assert_eq!(std::fs::read(&path).unwrap(), b"already here");
    }

    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let dir = tempdir().unwrap();
        let file = test_file("../escape.gguf", "https://example.invalid/x");
        let err = ensure_file(dir.path(), &file).unwrap_err().to_string();
        assert!(err.contains("plain file name"), "got: {err}");
    }

    #[test]
    fn verify_sha256_accepts_match_and_absence() {
        let accepts = |expected: Option<&str>| verify_sha256("abc123", expected).is_ok();

        assert!(accepts(Some("abc123")));
        assert!(accepts(Some("ABC123")), "case-insensitive");
        assert!(accepts(Some(" abc123 ")), "whitespace-tolerant");
        assert!(accepts(None), "legacy rows have no hash");
    }

    #[test]
    fn verify_sha256_rejects_mismatch() {
        let err = verify_sha256("abc123", Some("def456"))
            .unwrap_err()
            .to_string();
        assert!(err.contains("sha256 mismatch"), "got: {err}");

        let names_both = ["abc123", "def456"]
            .iter()
            .all(|digest| err.contains(*digest));
        assert!(names_both, "must name both digests: {err}");
    }

    /// In-memory writer that accepts at most `max_per_write` bytes per call
    /// (to force short writes) and counts how often it is flushed.
    struct ProbeWriter {
        sink: Vec<u8>,
        max_per_write: usize,
        flushes: usize,
    }

    impl Write for ProbeWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let accepted = &buf[..buf.len().min(self.max_per_write)];
            self.sink.extend_from_slice(accepted);
            Ok(accepted.len())
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.flushes += 1;
            Ok(())
        }
    }

    /// Lower-case hex rendering of `bytes`.
    fn hex(bytes: &[u8]) -> String {
        let mut out = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            out.push_str(&format!("{byte:02x}"));
        }
        out
    }

    /// A `HashingWriter` over a fresh `ProbeWriter` with the given per-write cap.
    fn probe_writer(max_per_write: usize) -> HashingWriter<ProbeWriter> {
        HashingWriter {
            inner: ProbeWriter {
                sink: Vec::new(),
                max_per_write,
                flushes: 0,
            },
            hasher: Sha256::new(),
        }
    }

    /// On a short write the digest must cover only what the inner writer took.
    #[test]
    fn hashing_writer_hashes_only_the_bytes_the_inner_accepted() {
        let mut writer = probe_writer(3);

        let written = writer.write(b"abcdefgh").unwrap();

        assert_eq!(written, 3, "inner accepts at most 3 bytes per write");
        assert_eq!(writer.inner.sink, b"abc", "only the prefix reaches inner");
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(b"abc")),
            "hash covers only the accepted prefix"
        );
    }

    /// Streaming through a writer that only ever does short writes still
    /// stores every byte and yields the digest of the whole body.
    #[test]
    fn hashing_writer_digest_matches_a_short_writing_stream_end_to_end() {
        let source = b"the quick brown model weights".to_vec();
        let mut reader = source.as_slice();
        let mut writer = probe_writer(4);

        let copied = std::io::copy(&mut reader, &mut writer).unwrap();

        assert_eq!(copied as usize, source.len());
        assert_eq!(
            writer.inner.sink, source,
            "every byte reaches the cache file"
        );
        assert_eq!(
            hex(&writer.hasher.finalize()),
            hex(&Sha256::digest(&source)),
            "digest matches the full body"
        );
    }

    #[test]
    fn hashing_writer_flush_delegates_to_the_inner_writer() {
        let mut writer = probe_writer(usize::MAX);

        writer.flush().unwrap();
        writer.flush().unwrap();

        assert_eq!(writer.inner.flushes, 2, "flush is forwarded to inner");
    }

    #[test]
    fn remove_temp_file_deletes_an_existing_file_quietly() {
        let dir = tempdir().unwrap();
        let f = dir.path().join("artefact.webp");
        std::fs::write(&f, b"bytes").unwrap();

        let target = f.clone();
        let out = crate::test_support::capture(move || remove_temp_file(&target));

        assert!(!f.exists(), "file should be gone after cleanup");
        assert!(
            !out.contains("failed to remove temp file"),
            "the success path must not warn: {out:?}"
        );
    }

    #[test]
    fn remove_temp_file_ignores_a_missing_file() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never.part");

        let out = crate::test_support::capture(move || remove_temp_file(&missing));

        assert!(
            !out.contains("failed to remove temp file"),
            "a not-found temp file is the desired end state: {out:?}"
        );
    }

    /// The target is a directory: the test expects `remove_file` to fail on it
    /// with something other than "not found", which must produce the warning.
    #[test]
    fn remove_temp_file_surfaces_a_failed_removal() {
        let dir = tempdir().unwrap();
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();

        let out = crate::test_support::capture(move || remove_temp_file(&stubborn));

        assert!(
            out.contains("failed to remove temp file"),
            "a failed removal must surface in the logs: {out:?}"
        );
        assert!(
            out.contains("subdir"),
            "the warning must name the offending path: {out:?}"
        );
        assert!(
            out.contains("cleanup"),
            "the warning should tag the cleanup op: {out:?}"
        );
    }

    #[test]
    fn temp_file_guard_removes_every_registered_file_on_drop() {
        let dir = tempdir().unwrap();
        let out = dir.path().join("out.webp");
        let init = dir.path().join("out-init.png");
        std::fs::write(&out, b"image").unwrap();
        std::fs::write(&init, b"init").unwrap();

        let mut guard = TempFileGuard::new();
        guard.push(out.clone());
        guard.push(init.clone());
        assert!(out.exists() && init.exists(), "files present before drop");
        drop(guard);

        assert!(!out.exists(), "output temp must be removed on drop");
        assert!(!init.exists(), "init-image temp must be removed on drop");
    }

    /// A registered path that was never written must not cause a warning.
    #[test]
    fn temp_file_guard_tolerates_a_file_that_never_materialised() {
        let dir = tempdir().unwrap();
        let missing = dir.path().join("never-written.webp");

        let out = crate::test_support::capture(move || {
            let mut guard = TempFileGuard::new();
            guard.push(missing);
            // The guard is dropped here, at the end of the closure.
        });

        assert!(
            !out.contains("failed to remove temp file"),
            "a never-created temp file must not warn on cleanup: {out:?}"
        );
    }
}