1use futures_util::StreamExt;
2use ring::digest::{Context, SHA256};
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use thiserror::Error;
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7
/// Name of the environment variable that `ModelCache::from_env_or` reads to
/// locate the cache root directory.
pub const DEFAULT_CACHE_ENV: &str = "VONA_MODEL_CACHE_DIR";
9
/// Describes which provider a model manifest is associated with.
///
/// Variants are (de)serialized with `snake_case` names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LocalModelProvider {
    /// A Hugging Face repository, with an optional revision.
    HuggingFace {
        repo: String,
        revision: Option<String>,
    },
    /// An Ollama model; `validate_manifest` accepts this provider with an
    /// empty artifact list.
    Ollama {
        model: String,
    },
    /// NOTE(review): presumably artifacts already available as local files —
    /// confirm against callers.
    LocalFile,
    /// A provider identified only by a free-form name.
    Custom {
        name: String,
    },
    /// A named provider; `validate_manifest` accepts this provider with an
    /// empty artifact list.
    ProviderManaged {
        name: String,
    },
}
28
/// A single file belonging to a model, plus optional download and
/// verification metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelArtifact {
    /// Artifact name, used in error messages.
    pub name: String,
    /// Location of the file relative to the model's cache directory.
    pub relative_path: PathBuf,
    /// URL the artifact is downloaded from; `None` means it cannot be fetched.
    pub source_url: Option<String>,
    /// Expected file size in bytes; the size check is skipped when `None`.
    pub expected_size_bytes: Option<u64>,
    /// Expected SHA-256 digest as hex (compared case-insensitively); the
    /// checksum check is skipped when `None`.
    pub sha256: Option<String>,
}
37
/// A model identifier together with its provider and the artifacts it
/// consists of.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelManifest {
    /// Model identifier; a sanitized form of it names the cache directory.
    pub id: String,
    /// Provider this model is associated with.
    pub provider: LocalModelProvider,
    /// Files that make up the model; may be empty for some providers.
    pub artifacts: Vec<ModelArtifact>,
}
44
/// On-disk cache of model artifacts, organised as one directory per model
/// beneath `root`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelCache {
    /// Directory under which every model directory is placed.
    pub root: PathBuf,
}
49
50impl ModelCache {
51 pub fn from_env_or(root: impl Into<PathBuf>) -> Self {
52 Self {
53 root: std::env::var(DEFAULT_CACHE_ENV)
54 .map(PathBuf::from)
55 .unwrap_or_else(|_| root.into()),
56 }
57 }
58
59 pub fn model_dir(&self, manifest: &ModelManifest) -> PathBuf {
60 self.root.join(sanitize_model_id(&manifest.id))
61 }
62
63 pub fn artifact_path(&self, manifest: &ModelManifest, artifact: &ModelArtifact) -> PathBuf {
64 self.model_dir(manifest).join(&artifact.relative_path)
65 }
66
67 pub fn inspect(&self, manifest: &ModelManifest) -> ProvisionPlan {
68 let mut present = Vec::new();
69 let mut missing = Vec::new();
70 for artifact in &manifest.artifacts {
71 let path = self.artifact_path(manifest, artifact);
72 if path.is_file() {
73 present.push(PlannedArtifact {
74 artifact: artifact.clone(),
75 path,
76 });
77 } else {
78 missing.push(PlannedArtifact {
79 artifact: artifact.clone(),
80 path,
81 });
82 }
83 }
84 ProvisionPlan {
85 manifest: manifest.clone(),
86 model_dir: self.model_dir(manifest),
87 present,
88 missing,
89 }
90 }
91
92 pub fn ensure_dirs(&self, manifest: &ModelManifest) -> Result<(), ProvisioningError> {
93 std::fs::create_dir_all(self.model_dir(manifest))
94 .map_err(|err| ProvisioningError::Io(err.to_string()))
95 }
96}
97
/// Result of inspecting a manifest against a cache: which artifacts are
/// already on disk and which are not.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProvisionPlan {
    /// Copy of the manifest that was inspected.
    pub manifest: ModelManifest,
    /// Cache directory of the inspected model.
    pub model_dir: PathBuf,
    /// Artifacts whose path exists as a regular file.
    pub present: Vec<PlannedArtifact>,
    /// Artifacts whose path does not exist as a regular file.
    pub missing: Vec<PlannedArtifact>,
}
105
106impl ProvisionPlan {
107 pub fn is_ready(&self) -> bool {
108 self.missing.is_empty()
109 }
110
111 pub fn missing_urls(&self) -> Vec<&str> {
112 self.missing
113 .iter()
114 .filter_map(|artifact| artifact.artifact.source_url.as_deref())
115 .collect()
116 }
117}
118
/// An artifact paired with the path it occupies (or would occupy) in the
/// cache.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PlannedArtifact {
    /// The artifact description taken from the manifest.
    pub artifact: ModelArtifact,
    /// Location of the artifact inside the cache.
    pub path: PathBuf,
}
124
/// Errors produced while validating manifests and provisioning artifacts.
///
/// Underlying I/O and HTTP errors are stored as strings.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
pub enum ProvisioningError {
    /// The manifest (identified by its id) lists no artifacts although its
    /// provider requires some.
    #[error("model manifest has no artifacts: {0}")]
    EmptyManifest(String),
    /// An artifact path was rejected because it is not relative; carries the
    /// offending path.
    #[error("artifact path must be relative: {0}")]
    AbsoluteArtifactPath(String),
    /// A filesystem operation failed; carries the rendered I/O error.
    #[error("io error: {0}")]
    Io(String),
    /// The named artifact has to be downloaded but has no source URL.
    #[error("artifact has no source URL: {0}")]
    MissingSourceUrl(String),
    /// The HTTP request or the body transfer for `url` failed.
    #[error("download failed for {url}: {message}")]
    Download { url: String, message: String },
    /// The artifact's byte count differs from `expected_size_bytes`.
    #[error("artifact size mismatch for {name}: expected {expected} bytes, got {actual} bytes")]
    SizeMismatch {
        name: String,
        expected: u64,
        actual: u64,
    },
    /// The artifact's SHA-256 digest differs from the expected one.
    #[error("artifact checksum mismatch for {name}: expected sha256 {expected}, got {actual}")]
    ChecksumMismatch {
        name: String,
        expected: String,
        actual: String,
    },
}
150
/// Downloads model artifacts over HTTP into a `ModelCache`.
#[derive(Debug, Clone)]
pub struct HttpModelProvisioner {
    /// HTTP client used for every artifact download.
    client: reqwest::Client,
}
155
156impl Default for HttpModelProvisioner {
157 fn default() -> Self {
158 Self {
159 client: reqwest::Client::new(),
160 }
161 }
162}
163
164impl HttpModelProvisioner {
165 pub fn new(client: reqwest::Client) -> Self {
166 Self { client }
167 }
168
169 pub async fn provision_missing(
170 &self,
171 cache: &ModelCache,
172 manifest: &ModelManifest,
173 ) -> Result<ProvisionPlan, ProvisioningError> {
174 validate_manifest(manifest)?;
175 cache.ensure_dirs(manifest)?;
176 let plan = cache.inspect(manifest);
177 let mut to_download = plan.missing;
178 for planned in plan.present {
179 if let Err(err) = verify_artifact_file(&planned).await {
180 let _ = tokio::fs::remove_file(&planned.path).await;
181 if matches!(
182 err,
183 ProvisioningError::SizeMismatch { .. }
184 | ProvisioningError::ChecksumMismatch { .. }
185 ) {
186 to_download.push(planned);
187 } else {
188 return Err(err);
189 }
190 }
191 }
192
193 for planned in &to_download {
194 self.download_artifact(planned).await?;
195 }
196 Ok(cache.inspect(manifest))
197 }
198
199 async fn download_artifact(&self, planned: &PlannedArtifact) -> Result<(), ProvisioningError> {
200 let url =
201 planned.artifact.source_url.as_ref().ok_or_else(|| {
202 ProvisioningError::MissingSourceUrl(planned.artifact.name.clone())
203 })?;
204 if let Some(parent) = planned.path.parent() {
205 tokio::fs::create_dir_all(parent)
206 .await
207 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
208 }
209
210 let temp_path = planned
211 .path
212 .with_extension(format!("{}.tmp", std::process::id()));
213 let mut file = tokio::fs::File::create(&temp_path)
214 .await
215 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
216 let mut response = self
217 .client
218 .get(url)
219 .send()
220 .await
221 .map_err(|err| ProvisioningError::Download {
222 url: url.clone(),
223 message: err.to_string(),
224 })?
225 .error_for_status()
226 .map_err(|err| ProvisioningError::Download {
227 url: url.clone(),
228 message: err.to_string(),
229 })?
230 .bytes_stream();
231
232 let mut hasher = Context::new(&SHA256);
233 let mut size = 0_u64;
234 while let Some(chunk) = response.next().await {
235 let chunk = chunk.map_err(|err| ProvisioningError::Download {
236 url: url.clone(),
237 message: err.to_string(),
238 })?;
239 size += chunk.len() as u64;
240 hasher.update(&chunk);
241 file.write_all(&chunk)
242 .await
243 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
244 }
245 file.flush()
246 .await
247 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
248 drop(file);
249
250 verify_size(&planned.artifact, size)?;
251 verify_sha256(&planned.artifact, encode_hex(hasher.finish().as_ref()))?;
252
253 tokio::fs::rename(&temp_path, &planned.path)
254 .await
255 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
256 Ok(())
257 }
258}
259
260pub fn validate_manifest(manifest: &ModelManifest) -> Result<(), ProvisioningError> {
261 if manifest.artifacts.is_empty()
262 && !matches!(
263 manifest.provider,
264 LocalModelProvider::Ollama { .. } | LocalModelProvider::ProviderManaged { .. }
265 )
266 {
267 return Err(ProvisioningError::EmptyManifest(manifest.id.clone()));
268 }
269 for artifact in &manifest.artifacts {
270 if artifact.relative_path.is_absolute() {
271 return Err(ProvisioningError::AbsoluteArtifactPath(
272 artifact.relative_path.display().to_string(),
273 ));
274 }
275 }
276 Ok(())
277}
278
279pub fn seamless_m4t_onnx_manifest(
280 model_id: impl Into<String>,
281 onnx_url: impl Into<String>,
282) -> ModelManifest {
283 ModelManifest {
284 id: model_id.into(),
285 provider: LocalModelProvider::HuggingFace {
286 repo: "facebook/hf-seamless-m4t-medium".to_string(),
287 revision: None,
288 },
289 artifacts: vec![ModelArtifact {
290 name: "encoder-decoder-onnx".to_string(),
291 relative_path: PathBuf::from("model.onnx"),
292 source_url: Some(onnx_url.into()),
293 expected_size_bytes: None,
294 sha256: None,
295 }],
296 }
297}
298
299pub fn moshi_server_manifest(model: impl Into<String>) -> ModelManifest {
300 let model = model.into();
301 ModelManifest {
302 id: format!("moshi/{model}"),
303 provider: LocalModelProvider::ProviderManaged {
304 name: format!("moshi/{model}"),
305 },
306 artifacts: Vec::new(),
307 }
308}
309
310async fn verify_artifact_file(planned: &PlannedArtifact) -> Result<(), ProvisioningError> {
311 let metadata = tokio::fs::metadata(&planned.path)
312 .await
313 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
314 verify_size(&planned.artifact, metadata.len())?;
315
316 if planned.artifact.sha256.is_some() {
317 let mut file = tokio::fs::File::open(&planned.path)
318 .await
319 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
320 let mut hasher = Context::new(&SHA256);
321 let mut buffer = vec![0_u8; 64 * 1024];
322 loop {
323 let read = file
324 .read(&mut buffer)
325 .await
326 .map_err(|err| ProvisioningError::Io(err.to_string()))?;
327 if read == 0 {
328 break;
329 }
330 hasher.update(&buffer[..read]);
331 }
332 verify_sha256(&planned.artifact, encode_hex(hasher.finish().as_ref()))?;
333 }
334 Ok(())
335}
336
337fn verify_size(artifact: &ModelArtifact, actual: u64) -> Result<(), ProvisioningError> {
338 if let Some(expected) = artifact.expected_size_bytes
339 && actual != expected
340 {
341 return Err(ProvisioningError::SizeMismatch {
342 name: artifact.name.clone(),
343 expected,
344 actual,
345 });
346 }
347 Ok(())
348}
349
350fn verify_sha256(artifact: &ModelArtifact, actual: String) -> Result<(), ProvisioningError> {
351 if let Some(expected) = &artifact.sha256
352 && !expected.eq_ignore_ascii_case(&actual)
353 {
354 return Err(ProvisioningError::ChecksumMismatch {
355 name: artifact.name.clone(),
356 expected: expected.clone(),
357 actual,
358 });
359 }
360 Ok(())
361}
362
/// Renders `bytes` as a lowercase hexadecimal string, two digits per byte.
fn encode_hex(bytes: &[u8]) -> String {
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0x0f])
        .map(|nibble| {
            char::from_digit(u32::from(nibble), 16).expect("a nibble is always below 16")
        })
        .collect()
}
372
/// Turns a model id into a single directory name.
///
/// Path separators and colons (`/`, `\`, `:`) are replaced with `_`. Ids
/// that would still not name a proper subdirectory are rewritten as well:
/// an empty id becomes `_`, and `.` / `..` become `_` / `__`. Joined onto the
/// cache root unchanged, those would resolve to the root itself or to its
/// parent.
fn sanitize_model_id(id: &str) -> String {
    let flattened: String = id
        .chars()
        .map(|ch| match ch {
            '/' | ':' | '\\' => '_',
            other => other,
        })
        .collect();

    match flattened.as_str() {
        "" => String::from("_"),
        "." | ".." => flattened.replace('.', "_"),
        _ => flattened,
    }
}
381
// Unit tests for manifest validation, cache inspection and the verification
// helpers. None of them performs network access.
#[cfg(test)]
mod tests {
    use super::*;

    // A `LocalFile` manifest without artifacts must be rejected.
    #[test]
    fn validate_rejects_empty_manifest() {
        let manifest = ModelManifest {
            id: "empty".to_string(),
            provider: LocalModelProvider::LocalFile,
            artifacts: Vec::new(),
        };
        assert_eq!(
            validate_manifest(&manifest),
            Err(ProvisioningError::EmptyManifest("empty".to_string()))
        );
    }

    // An artifact whose path is absolute must be rejected.
    #[test]
    fn validate_rejects_absolute_artifact_paths() {
        let manifest = ModelManifest {
            id: "bad".to_string(),
            provider: LocalModelProvider::LocalFile,
            artifacts: vec![ModelArtifact {
                name: "bad".to_string(),
                relative_path: PathBuf::from("/tmp/model.onnx"),
                source_url: None,
                expected_size_bytes: None,
                sha256: None,
            }],
        };
        assert!(matches!(
            validate_manifest(&manifest),
            Err(ProvisioningError::AbsoluteArtifactPath(_))
        ));
    }

    // Writes the single artifact into a per-process temp cache and checks
    // that `inspect` reports it as present.
    #[test]
    fn inspect_splits_present_and_missing_artifacts() {
        let root =
            std::env::temp_dir().join(format!("vona-provisioning-test-{}", std::process::id()));
        let cache = ModelCache { root };
        let manifest = seamless_m4t_onnx_manifest(
            "facebook/hf-seamless-m4t-medium",
            "https://example.test/model.onnx",
        );
        cache.ensure_dirs(&manifest).unwrap();
        std::fs::write(cache.model_dir(&manifest).join("model.onnx"), b"onnx").unwrap();
        let plan = cache.inspect(&manifest);
        assert!(plan.is_ready());
        assert_eq!(plan.present.len(), 1);
        // Best-effort cleanup of the temporary cache directory.
        let _ = std::fs::remove_dir_all(cache.root);
    }

    // Provider-managed manifests are valid even with no artifacts.
    #[test]
    fn moshi_manifest_is_provider_managed_and_valid_without_artifacts() {
        let manifest = moshi_server_manifest("kyutai/moshi");
        assert!(matches!(
            manifest.provider,
            LocalModelProvider::ProviderManaged { .. }
        ));
        assert!(validate_manifest(&manifest).is_ok());
    }

    // A differing digest yields `ChecksumMismatch`; a matching size passes.
    #[test]
    fn sha256_verification_detects_mismatch() {
        let artifact = ModelArtifact {
            name: "model".to_string(),
            relative_path: PathBuf::from("model.bin"),
            source_url: None,
            expected_size_bytes: Some(4),
            sha256: Some("0000".to_string()),
        };
        assert!(matches!(
            verify_sha256(&artifact, "abcd".to_string()),
            Err(ProvisioningError::ChecksumMismatch { .. })
        ));
        assert!(verify_size(&artifact, 4).is_ok());
    }
}