studio_worker/engine/
download.rs1use anyhow::{bail, Context, Result};
15use std::path::{Component, Path, PathBuf};
16use std::time::Instant;
17use tracing::{info, warn};
18
/// `tracing` target attached to every event emitted by this module, so download activity can be
/// filtered as one group.
const TRACE_TARGET: &str = "studio_worker::engine::download";

/// Timeout, in seconds, configured on the HTTP client used for a model download (30 minutes).
const DOWNLOAD_TIMEOUT_SECS: u64 = 60 * 30;
27pub fn model_cache_path(dir: &Path, filename: &str) -> Result<PathBuf> {
31 let path = Path::new(filename);
32 let mut components = path.components();
33 match (components.next(), components.next()) {
34 (Some(Component::Normal(name)), None)
35 if !filename.contains('/') && !filename.contains('\\') =>
36 {
37 Ok(dir.join(name))
38 }
39 _ => bail!("model filename must be a plain file name: {filename:?}"),
40 }
41}
42
43pub fn verify_download_len(copied: u64, expected: Option<u64>) -> Result<()> {
50 match expected {
51 Some(expected) if copied != expected => bail!(
52 "size mismatch: wrote {copied} bytes but the server declared \
53 Content-Length {expected} (download truncated or corrupt)"
54 ),
55 _ => Ok(()),
56 }
57}
58
59pub fn remove_partial(path: &Path) {
64 if let Err(e) = std::fs::remove_file(path) {
65 if e.kind() != std::io::ErrorKind::NotFound {
66 warn!(
67 target: TRACE_TARGET,
68 op = "cleanup",
69 path = %path.display(),
70 error = %e,
71 "failed to remove partial download"
72 );
73 }
74 }
75}
76
77#[cfg_attr(coverage_nightly, coverage(off))]
80pub fn ensure_file(dir: &Path, filename: &str, url: &str) -> Result<PathBuf> {
81 let local = model_cache_path(dir, filename)?;
82 if local.is_file() {
83 tracing::debug!(
84 target: TRACE_TARGET,
85 op = "ensure_file",
86 filename,
87 path = %local.display(),
88 "cached"
89 );
90 return Ok(local);
91 }
92 download_file(url, &local)
93 .with_context(|| format!("downloading {filename} ({url}) -> {}", local.display()))?;
94 Ok(local)
95}
96
97#[cfg_attr(coverage_nightly, coverage(off))]
105pub fn download_file(url: &str, dest: &Path) -> Result<()> {
106 if let Some(parent) = dest.parent() {
107 std::fs::create_dir_all(parent)
108 .with_context(|| format!("creating {}", parent.display()))?;
109 }
110 let part = dest.with_extension("part");
111 let client = reqwest::blocking::Client::builder()
112 .timeout(std::time::Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
113 .user_agent(concat!("studio-worker/", env!("CARGO_PKG_VERSION")))
114 .build()?;
115 info!(
116 target: TRACE_TARGET,
117 op = "download",
118 url,
119 dest = %dest.display(),
120 "starting"
121 );
122 let started = Instant::now();
123 let mut response = client.get(url).send().context("GET")?;
124 if !response.status().is_success() {
125 bail!("GET {url} -> {}", response.status());
126 }
127 let expected_len = response.content_length();
128 let mut file =
129 std::fs::File::create(&part).with_context(|| format!("creating {}", part.display()))?;
130 let copied = std::io::copy(&mut response, &mut file);
131 drop(file);
134 let bytes = match copied {
135 Ok(bytes) => bytes,
136 Err(e) => {
137 remove_partial(&part);
138 return Err(e).context("streaming body");
139 }
140 };
141 if let Err(e) = verify_download_len(bytes, expected_len) {
142 remove_partial(&part);
143 return Err(e).with_context(|| format!("downloading {url}"));
144 }
145 std::fs::rename(&part, dest)
146 .with_context(|| format!("renaming {} -> {}", part.display(), dest.display()))?;
147 let elapsed_ms = started.elapsed().as_millis() as u64;
148 info!(
149 target: TRACE_TARGET,
150 op = "download",
151 url,
152 dest = %dest.display(),
153 bytes,
154 elapsed_ms,
155 "done"
156 );
157 Ok(())
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn model_cache_path_accepts_plain_filenames_only() {
        let root = Path::new("/models");
        assert_eq!(
            model_cache_path(root, "model.gguf").unwrap(),
            PathBuf::from("/models/model.gguf")
        );
        assert!(model_cache_path(root, "../outside.gguf").is_err());
        assert!(model_cache_path(root, "nested/model.gguf").is_err());
        assert!(model_cache_path(root, "/tmp/model.gguf").is_err());
        assert!(model_cache_path(root, r"nested\model.gguf").is_err());
        assert!(model_cache_path(root, "").is_err());
    }

    #[test]
    fn model_cache_path_rejects_dot_components_and_trailing_separators() {
        let root = Path::new("/models");
        // "model.gguf/" parses to a single Normal component (the trailing separator is dropped
        // by `Path::components`), so only the textual separator check rejects it.
        for bad in [".", "..", "./model.gguf", "model.gguf/"] {
            assert!(model_cache_path(root, bad).is_err(), "accepted {bad:?}");
        }
    }

    #[test]
    fn verify_download_len_accepts_exact_match() {
        assert!(verify_download_len(2_700_000_000, Some(2_700_000_000)).is_ok());
    }

    #[test]
    fn verify_download_len_accepts_when_length_unknown() {
        assert!(verify_download_len(123, None).is_ok());
    }

    #[test]
    fn verify_download_len_rejects_truncated_download() {
        let err = verify_download_len(40, Some(100)).unwrap_err().to_string();
        assert!(err.contains("size mismatch"), "got: {err}");
        assert!(err.contains("40"), "got: {err}");
        assert!(err.contains("100"), "got: {err}");
    }

    #[test]
    fn verify_download_len_rejects_overlong_download() {
        assert!(verify_download_len(120, Some(100)).is_err());
    }

    #[test]
    fn verify_download_len_handles_empty_bodies() {
        assert!(verify_download_len(0, Some(0)).is_ok());
        assert!(verify_download_len(0, Some(100)).is_err());
    }

    #[test]
    fn ensure_file_returns_cached_path_without_network() {
        let dir = tempdir().unwrap();
        // `.invalid` is a reserved TLD that never resolves, so success here shows the cached
        // file was returned without attempting a download.
        std::fs::write(dir.path().join("cached.gguf"), b"already here").unwrap();
        let path = ensure_file(dir.path(), "cached.gguf", "https://example.invalid/x").unwrap();
        assert_eq!(path, dir.path().join("cached.gguf"));
        assert_eq!(std::fs::read(&path).unwrap(), b"already here");
    }

    #[test]
    fn ensure_file_rejects_path_traversal_before_any_network() {
        let dir = tempdir().unwrap();
        let err = ensure_file(dir.path(), "../escape.gguf", "https://example.invalid/x")
            .unwrap_err()
            .to_string();
        assert!(err.contains("plain file name"), "got: {err}");
    }

    #[test]
    fn remove_partial_ignores_a_missing_file() {
        let dir = tempdir().unwrap();
        let out = crate::test_support::capture({
            let missing = dir.path().join("never.part");
            move || remove_partial(&missing)
        });
        assert!(
            !out.contains("failed to remove partial download"),
            "a not-found partial is the desired end state: {out:?}"
        );
    }

    #[test]
    fn remove_partial_surfaces_a_failed_removal() {
        let dir = tempdir().unwrap();
        // `remove_file` refuses to delete a directory, which gives a deterministic failure whose
        // kind is not `NotFound`.
        let stubborn = dir.path().join("subdir");
        std::fs::create_dir(&stubborn).unwrap();
        let out = crate::test_support::capture(move || remove_partial(&stubborn));
        assert!(
            out.contains("failed to remove partial download"),
            "a failed removal must surface in the logs: {out:?}"
        );
    }
}