stem_splitter_core/model/
model_manager.rs

1use crate::{
2    error::{Result, StemError},
3    io::{
4        crypto::verify_sha256,
5        net::{download_with_progress, http_client},
6        paths::models_cache_dir,
7    },
8    model::registry::resolve_manifest_url,
9    types::ModelManifest,
10};
11
12use std::{fs, path::PathBuf};
13
14pub struct ModelHandle {
15    pub manifest: ModelManifest,
16    pub local_path: PathBuf,
17}
18
19pub fn ensure_model(model_name: &str, manifest_url_override: Option<&str>) -> Result<ModelHandle> {
20    let manifest_url = manifest_url_override
21        .map(|s| s.to_string())
22        .unwrap_or_else(|| resolve_manifest_url(model_name).expect("resolve_manifest_url failed"));
23
24    let client = http_client();
25    let manifest: ModelManifest = client
26        .get(&manifest_url)
27        .send()?
28        .error_for_status()?
29        .json()?;
30
31    let a = manifest
32        .resolve_primary_artifact()
33        .map_err(|msg| StemError::Manifest(msg))?;
34
35    let cache_dir = models_cache_dir()?;
36    fs::create_dir_all(&cache_dir)?;
37    let ext = a
38        .file
39        .rsplit('.')
40        .next()
41        .map(|s| format!(".{s}"))
42        .unwrap_or_default();
43    let file_name = format!("{}-{}{}", manifest.name, &a.sha256[..8], ext);
44    let local_path = cache_dir.join(file_name);
45
46    let need_download = !matches!(verify_sha256(&local_path, &a.sha256), Ok(true));
47    if need_download {
48        download_with_progress(&client, &a.url, &local_path)?;
49        if !verify_sha256(&local_path, &a.sha256)? {
50            return Err(StemError::Checksum {
51                path: local_path.display().to_string(),
52            });
53        }
54        if a.size_bytes > 0 {
55            let size = fs::metadata(&local_path).map(|m| m.len()).unwrap_or(0);
56            if size != a.size_bytes {
57                eprintln!(
58                    "warn: size mismatch for {}, expected {}, got {}",
59                    local_path.display(),
60                    a.size_bytes,
61                    size
62                );
63            }
64        }
65    }
66
67    Ok(ModelHandle {
68        manifest,
69        local_path,
70    })
71}