Skip to main content

vona_model_provisioning/
lib.rs

1use futures_util::StreamExt;
2use ring::digest::{Context, SHA256};
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use thiserror::Error;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7
/// Name of the environment variable consulted by [`ModelCache::from_env_or`]
/// to override the cache root directory.
pub const DEFAULT_CACHE_ENV: &str = "VONA_MODEL_CACHE_DIR";
9
/// Provider descriptor carried by a [`ModelManifest`].
///
/// Variant names are serialized in `snake_case` (for example `hugging_face`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LocalModelProvider {
    /// A Hugging Face repository, optionally pinned to a revision.
    HuggingFace {
        repo: String,
        revision: Option<String>,
    },
    /// An Ollama model reference. [`validate_manifest`] accepts manifests of
    /// this provider even when they list no artifacts.
    Ollama {
        model: String,
    },
    /// A provider with no extra parameters (presumably files already on
    /// disk — confirm against callers).
    LocalFile,
    /// A named provider not covered by the other variants.
    Custom {
        name: String,
    },
    /// A named provider for which [`validate_manifest`] accepts a manifest
    /// with no artifacts.
    ProviderManaged {
        name: String,
    },
}
28
/// One file belonging to a model, plus the optional data used to fetch and
/// verify it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelArtifact {
    /// Label reported in provisioning errors for this artifact.
    pub name: String,
    /// Location of the file below the model directory; see
    /// [`ModelCache::artifact_path`]. [`validate_manifest`] rejects absolute paths.
    pub relative_path: PathBuf,
    /// URL the artifact is downloaded from when it has to be fetched. Without
    /// it a download attempt fails with [`ProvisioningError::MissingSourceUrl`].
    pub source_url: Option<String>,
    /// When set, the exact byte length the file must have.
    pub expected_size_bytes: Option<u64>,
    /// When set, the hex-encoded SHA-256 digest the file must have; the
    /// comparison ignores ASCII case.
    pub sha256: Option<String>,
}
37
/// Describes a model: its identifier, its provider and the files it consists of.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelManifest {
    /// Model identifier; a sanitized form of it names the model's cache
    /// directory (see [`ModelCache::model_dir`]).
    pub id: String,
    /// Provider descriptor for the model.
    pub provider: LocalModelProvider,
    /// Files that make up the model. May be empty only for the providers
    /// that [`validate_manifest`] exempts.
    pub artifacts: Vec<ModelArtifact>,
}
44
/// On-disk model cache: each model gets its own directory directly below `root`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCache {
    /// Directory under which per-model directories are placed.
    pub root: PathBuf,
}
49
50impl ModelCache {
51    pub fn from_env_or(root: impl Into<PathBuf>) -> Self {
52        Self {
53            root: std::env::var(DEFAULT_CACHE_ENV)
54                .map(PathBuf::from)
55                .unwrap_or_else(|_| root.into()),
56        }
57    }
58
59    pub fn model_dir(&self, manifest: &ModelManifest) -> PathBuf {
60        self.root.join(sanitize_model_id(&manifest.id))
61    }
62
63    pub fn artifact_path(&self, manifest: &ModelManifest, artifact: &ModelArtifact) -> PathBuf {
64        self.model_dir(manifest).join(&artifact.relative_path)
65    }
66
67    pub fn inspect(&self, manifest: &ModelManifest) -> ProvisionPlan {
68        let mut present = Vec::new();
69        let mut missing = Vec::new();
70        for artifact in &manifest.artifacts {
71            let path = self.artifact_path(manifest, artifact);
72            if path.is_file() {
73                present.push(PlannedArtifact {
74                    artifact: artifact.clone(),
75                    path,
76                });
77            } else {
78                missing.push(PlannedArtifact {
79                    artifact: artifact.clone(),
80                    path,
81                });
82            }
83        }
84        ProvisionPlan {
85            manifest: manifest.clone(),
86            model_dir: self.model_dir(manifest),
87            present,
88            missing,
89        }
90    }
91
92    pub fn ensure_dirs(&self, manifest: &ModelManifest) -> Result<(), ProvisioningError> {
93        std::fs::create_dir_all(self.model_dir(manifest))
94            .map_err(|err| ProvisioningError::Io(err.to_string()))
95    }
96}
97
/// Result of inspecting a manifest against a cache: which artifacts are
/// already on disk and which are not.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProvisionPlan {
    /// Copy of the manifest that was inspected.
    pub manifest: ModelManifest,
    /// Directory in the cache that belongs to the model.
    pub model_dir: PathBuf,
    /// Artifacts whose path exists as a regular file.
    pub present: Vec<PlannedArtifact>,
    /// Artifacts whose path does not exist as a regular file.
    pub missing: Vec<PlannedArtifact>,
}
105
106impl ProvisionPlan {
107    pub fn is_ready(&self) -> bool {
108        self.missing.is_empty()
109    }
110
111    pub fn missing_urls(&self) -> Vec<&str> {
112        self.missing
113            .iter()
114            .filter_map(|artifact| artifact.artifact.source_url.as_deref())
115            .collect()
116    }
117}
118
/// An artifact paired with the concrete cache path it resolves to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlannedArtifact {
    /// The artifact description taken from the manifest.
    pub artifact: ModelArtifact,
    /// Location of the artifact inside the cache.
    pub path: PathBuf,
}
124
/// Failures reported while validating a manifest or provisioning its artifacts.
///
/// Underlying I/O and HTTP errors are stored as their display text, which
/// keeps the type `Clone` and comparable.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ProvisioningError {
    /// The manifest (identified by its id) lists no artifacts although its
    /// provider requires some.
    #[error("model manifest has no artifacts: {0}")]
    EmptyManifest(String),
    /// An artifact path (given as display text) was rejected by validation.
    #[error("artifact path must be relative: {0}")]
    AbsoluteArtifactPath(String),
    /// A filesystem operation failed; the payload is the error text.
    #[error("io error: {0}")]
    Io(String),
    /// The named artifact has to be downloaded but has no `source_url`.
    #[error("artifact has no source URL: {0}")]
    MissingSourceUrl(String),
    /// The HTTP request for `url`, or reading its body, failed.
    #[error("download failed for {url}: {message}")]
    Download { url: String, message: String },
    /// The artifact's byte length differs from `expected_size_bytes`.
    #[error("artifact size mismatch for {name}: expected {expected} bytes, got {actual} bytes")]
    SizeMismatch {
        name: String,
        expected: u64,
        actual: u64,
    },
    /// The artifact's SHA-256 digest differs from the manifest's `sha256`.
    #[error("artifact checksum mismatch for {name}: expected sha256 {expected}, got {actual}")]
    ChecksumMismatch {
        name: String,
        expected: String,
        actual: String,
    },
}
150
/// Fetches model artifacts over HTTP into a [`ModelCache`].
#[derive(Debug, Clone)]
pub struct HttpModelProvisioner {
    /// HTTP client used for every download request.
    client: reqwest::Client,
}
155
156impl Default for HttpModelProvisioner {
157    fn default() -> Self {
158        Self {
159            client: reqwest::Client::new(),
160        }
161    }
162}
163
164impl HttpModelProvisioner {
165    pub fn new(client: reqwest::Client) -> Self {
166        Self { client }
167    }
168
169    pub async fn provision_missing(
170        &self,
171        cache: &ModelCache,
172        manifest: &ModelManifest,
173    ) -> Result<ProvisionPlan, ProvisioningError> {
174        validate_manifest(manifest)?;
175        cache.ensure_dirs(manifest)?;
176        let plan = cache.inspect(manifest);
177        let mut to_download = plan.missing;
178        for planned in plan.present {
179            if let Err(err) = verify_artifact_file(&planned).await {
180                let _ = tokio::fs::remove_file(&planned.path).await;
181                if matches!(
182                    err,
183                    ProvisioningError::SizeMismatch { .. }
184                        | ProvisioningError::ChecksumMismatch { .. }
185                ) {
186                    to_download.push(planned);
187                } else {
188                    return Err(err);
189                }
190            }
191        }
192
193        for planned in &to_download {
194            self.download_artifact(planned).await?;
195        }
196        Ok(cache.inspect(manifest))
197    }
198
199    async fn download_artifact(&self, planned: &PlannedArtifact) -> Result<(), ProvisioningError> {
200        let url =
201            planned.artifact.source_url.as_ref().ok_or_else(|| {
202                ProvisioningError::MissingSourceUrl(planned.artifact.name.clone())
203            })?;
204        if let Some(parent) = planned.path.parent() {
205            tokio::fs::create_dir_all(parent)
206                .await
207                .map_err(|err| ProvisioningError::Io(err.to_string()))?;
208        }
209
210        let temp_path = planned
211            .path
212            .with_extension(format!("{}.tmp", std::process::id()));
213        let mut file = tokio::fs::File::create(&temp_path)
214            .await
215            .map_err(|err| ProvisioningError::Io(err.to_string()))?;
216        let mut response = self
217            .client
218            .get(url)
219            .send()
220            .await
221            .map_err(|err| ProvisioningError::Download {
222                url: url.clone(),
223                message: err.to_string(),
224            })?
225            .error_for_status()
226            .map_err(|err| ProvisioningError::Download {
227                url: url.clone(),
228                message: err.to_string(),
229            })?
230            .bytes_stream();
231
232        let mut hasher = Context::new(&SHA256);
233        let mut size = 0_u64;
234        while let Some(chunk) = response.next().await {
235            let chunk = chunk.map_err(|err| ProvisioningError::Download {
236                url: url.clone(),
237                message: err.to_string(),
238            })?;
239            size += chunk.len() as u64;
240            hasher.update(&chunk);
241            file.write_all(&chunk)
242                .await
243                .map_err(|err| ProvisioningError::Io(err.to_string()))?;
244        }
245        file.flush()
246            .await
247            .map_err(|err| ProvisioningError::Io(err.to_string()))?;
248        drop(file);
249
250        verify_size(&planned.artifact, size)?;
251        verify_sha256(&planned.artifact, encode_hex(hasher.finish().as_ref()))?;
252
253        tokio::fs::rename(&temp_path, &planned.path)
254            .await
255            .map_err(|err| ProvisioningError::Io(err.to_string()))?;
256        Ok(())
257    }
258}
259
260pub fn validate_manifest(manifest: &ModelManifest) -> Result<(), ProvisioningError> {
261    if manifest.artifacts.is_empty()
262        && !matches!(
263            manifest.provider,
264            LocalModelProvider::Ollama { .. } | LocalModelProvider::ProviderManaged { .. }
265        )
266    {
267        return Err(ProvisioningError::EmptyManifest(manifest.id.clone()));
268    }
269    for artifact in &manifest.artifacts {
270        if artifact.relative_path.is_absolute() {
271            return Err(ProvisioningError::AbsoluteArtifactPath(
272                artifact.relative_path.display().to_string(),
273            ));
274        }
275    }
276    Ok(())
277}
278
279pub fn seamless_m4t_onnx_manifest(
280    model_id: impl Into<String>,
281    onnx_url: impl Into<String>,
282) -> ModelManifest {
283    ModelManifest {
284        id: model_id.into(),
285        provider: LocalModelProvider::HuggingFace {
286            repo: "facebook/hf-seamless-m4t-medium".to_string(),
287            revision: None,
288        },
289        artifacts: vec![ModelArtifact {
290            name: "encoder-decoder-onnx".to_string(),
291            relative_path: PathBuf::from("model.onnx"),
292            source_url: Some(onnx_url.into()),
293            expected_size_bytes: None,
294            sha256: None,
295        }],
296    }
297}
298
299pub fn moshi_server_manifest(model: impl Into<String>) -> ModelManifest {
300    let model = model.into();
301    ModelManifest {
302        id: format!("moshi/{model}"),
303        provider: LocalModelProvider::ProviderManaged {
304            name: format!("moshi/{model}"),
305        },
306        artifacts: Vec::new(),
307    }
308}
309
310async fn verify_artifact_file(planned: &PlannedArtifact) -> Result<(), ProvisioningError> {
311    let metadata = tokio::fs::metadata(&planned.path)
312        .await
313        .map_err(|err| ProvisioningError::Io(err.to_string()))?;
314    verify_size(&planned.artifact, metadata.len())?;
315
316    if planned.artifact.sha256.is_some() {
317        let mut file = tokio::fs::File::open(&planned.path)
318            .await
319            .map_err(|err| ProvisioningError::Io(err.to_string()))?;
320        let mut hasher = Context::new(&SHA256);
321        let mut buffer = vec![0_u8; 64 * 1024];
322        loop {
323            let read = file
324                .read(&mut buffer)
325                .await
326                .map_err(|err| ProvisioningError::Io(err.to_string()))?;
327            if read == 0 {
328                break;
329            }
330            hasher.update(&buffer[..read]);
331        }
332        verify_sha256(&planned.artifact, encode_hex(hasher.finish().as_ref()))?;
333    }
334    Ok(())
335}
336
337fn verify_size(artifact: &ModelArtifact, actual: u64) -> Result<(), ProvisioningError> {
338    if let Some(expected) = artifact.expected_size_bytes
339        && actual != expected
340    {
341        return Err(ProvisioningError::SizeMismatch {
342            name: artifact.name.clone(),
343            expected,
344            actual,
345        });
346    }
347    Ok(())
348}
349
350fn verify_sha256(artifact: &ModelArtifact, actual: String) -> Result<(), ProvisioningError> {
351    if let Some(expected) = &artifact.sha256
352        && !expected.eq_ignore_ascii_case(&actual)
353    {
354        return Err(ProvisioningError::ChecksumMismatch {
355            name: artifact.name.clone(),
356            expected: expected.clone(),
357            actual,
358        });
359    }
360    Ok(())
361}
362
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn encode_hex(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    bytes
        .iter()
        .flat_map(|&byte| {
            // High nibble first, then low nibble.
            [
                DIGITS[usize::from(byte >> 4)],
                DIGITS[usize::from(byte & 0x0f)],
            ]
        })
        .collect()
}
372
/// Turns a model id into a single, safe directory name.
///
/// Path separators and colons (`/`, `\`, `:`) are replaced by `_`. Ids that
/// would otherwise name the containing directory itself or its parent (the
/// empty string, `.` and `..`) are mapped to underscore-only names, so the
/// resulting directory is always a child of the cache root.
fn sanitize_model_id(id: &str) -> String {
    let sanitized: String = id
        .chars()
        .map(|ch| match ch {
            '/' | ':' | '\\' => '_',
            other => other,
        })
        .collect();

    match sanitized.as_str() {
        // Joining "" onto the root yields the root itself.
        "" => String::from("_"),
        // "." and ".." would resolve to the root and its parent.
        "." | ".." => sanitized.replace('.', "_"),
        _ => sanitized,
    }
}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// A manifest without artifacts is rejected for providers that need them.
    #[test]
    fn validate_rejects_empty_manifest() {
        let manifest = ModelManifest {
            id: "empty".to_string(),
            provider: LocalModelProvider::LocalFile,
            artifacts: Vec::new(),
        };
        assert_eq!(
            validate_manifest(&manifest),
            Err(ProvisioningError::EmptyManifest("empty".to_string()))
        );
    }

    /// Artifact paths must be relative to the model directory.
    #[test]
    fn validate_rejects_absolute_artifact_paths() {
        let manifest = ModelManifest {
            id: "bad".to_string(),
            provider: LocalModelProvider::LocalFile,
            artifacts: vec![ModelArtifact {
                name: "bad".to_string(),
                relative_path: PathBuf::from("/tmp/model.onnx"),
                source_url: None,
                expected_size_bytes: None,
                sha256: None,
            }],
        };
        assert!(matches!(
            validate_manifest(&manifest),
            Err(ProvisioningError::AbsoluteArtifactPath(_))
        ));
    }

    /// `inspect` reports an artifact as missing until its file exists and as
    /// present afterwards.
    #[test]
    fn inspect_splits_present_and_missing_artifacts() {
        let root =
            std::env::temp_dir().join(format!("vona-provisioning-test-{}", std::process::id()));
        // Start clean in case an earlier process with the same id left files behind.
        let _ = std::fs::remove_dir_all(&root);
        let cache = ModelCache { root };
        let manifest = seamless_m4t_onnx_manifest(
            "facebook/hf-seamless-m4t-medium",
            "https://example.test/model.onnx",
        );
        cache.ensure_dirs(&manifest).unwrap();

        let before = cache.inspect(&manifest);
        std::fs::write(cache.model_dir(&manifest).join("model.onnx"), b"onnx").unwrap();
        let after = cache.inspect(&manifest);

        // Clean up before asserting so a failing assertion cannot leak the directory.
        let _ = std::fs::remove_dir_all(&cache.root);

        assert!(!before.is_ready());
        assert!(before.present.is_empty());
        assert_eq!(before.missing.len(), 1);
        assert_eq!(
            before.missing_urls(),
            vec!["https://example.test/model.onnx"]
        );

        assert!(after.is_ready());
        assert_eq!(after.present.len(), 1);
        assert!(after.missing_urls().is_empty());
    }

    /// Provider-managed manifests are valid even though they list no artifacts.
    #[test]
    fn moshi_manifest_is_provider_managed_and_valid_without_artifacts() {
        let manifest = moshi_server_manifest("kyutai/moshi");
        assert!(matches!(
            manifest.provider,
            LocalModelProvider::ProviderManaged { .. }
        ));
        assert!(validate_manifest(&manifest).is_ok());
    }

    /// A differing digest is reported while a matching size passes.
    #[test]
    fn sha256_verification_detects_mismatch() {
        let artifact = ModelArtifact {
            name: "model".to_string(),
            relative_path: PathBuf::from("model.bin"),
            source_url: None,
            expected_size_bytes: Some(4),
            sha256: Some("0000".to_string()),
        };
        assert!(matches!(
            verify_sha256(&artifact, "abcd".to_string()),
            Err(ProvisioningError::ChecksumMismatch { .. })
        ));
        assert!(verify_size(&artifact, 4).is_ok());
    }
}
460}