// stem_splitter_core/model/model_manager.rs
use crate::{
    error::{Result, StemError},
    io::{
        crypto::verify_sha256,
        net::{download_with_progress, http_client},
        paths::models_cache_dir,
    },
    model::registry::resolve_manifest_url,
    types::ModelManifest,
};

use std::{
    fs,
    path::{Path, PathBuf},
};

14pub struct ModelHandle {
15 pub manifest: ModelManifest,
16 pub local_path: PathBuf,
17}
18
19pub fn ensure_model(model_name: &str, manifest_url_override: Option<&str>) -> Result<ModelHandle> {
20 let manifest_url = manifest_url_override
21 .map(|s| s.to_string())
22 .unwrap_or_else(|| resolve_manifest_url(model_name).expect("resolve_manifest_url failed"));
23
24 let client = http_client();
25 let manifest: ModelManifest = client
26 .get(&manifest_url)
27 .send()?
28 .error_for_status()?
29 .json()?;
30
31 let a = manifest
32 .resolve_primary_artifact()
33 .map_err(|msg| StemError::Manifest(msg))?;
34
35 let cache_dir = models_cache_dir()?;
36 fs::create_dir_all(&cache_dir)?;
37 let ext = a
38 .file
39 .rsplit('.')
40 .next()
41 .map(|s| format!(".{s}"))
42 .unwrap_or_default();
43 let file_name = format!("{}-{}{}", manifest.name, &a.sha256[..8], ext);
44 let local_path = cache_dir.join(file_name);
45
46 let need_download = !matches!(verify_sha256(&local_path, &a.sha256), Ok(true));
47 if need_download {
48 download_with_progress(&client, &a.url, &local_path)?;
49 if !verify_sha256(&local_path, &a.sha256)? {
50 return Err(StemError::Checksum {
51 path: local_path.display().to_string(),
52 });
53 }
54 if a.size_bytes > 0 {
55 let size = fs::metadata(&local_path).map(|m| m.len()).unwrap_or(0);
56 if size != a.size_bytes {
57 eprintln!(
58 "warn: size mismatch for {}, expected {}, got {}",
59 local_path.display(),
60 a.size_bytes,
61 size
62 );
63 }
64 }
65 }
66
67 Ok(ModelHandle {
68 manifest,
69 local_path,
70 })
71}