1use crate::error::{NativeError, Result};
10use futures_util::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use sha2::{Digest, Sha256};
14use spn_core::{BackendError, DownloadRequest, DownloadResult, ModelInfo, ModelStorage, PullProgress};
15use std::path::{Path, PathBuf};
16use tokio::fs::{self, File};
17use tokio::io::AsyncWriteExt;
18
19#[derive(Debug, Deserialize)]
25struct HfFileInfo {
26 #[serde(rename = "rfilename")]
28 filename: String,
29 size: u64,
31 lfs: Option<HfLfsInfo>,
33}
34
35#[derive(Debug, Deserialize)]
37struct HfLfsInfo {
38 sha256: String,
40}
41
42pub struct HuggingFaceStorage {
71 storage_dir: PathBuf,
73 client: Client,
75}
76
77impl HuggingFaceStorage {
78 #[must_use]
80 pub fn new(storage_dir: PathBuf) -> Self {
81 Self {
82 storage_dir,
83 client: Client::builder()
84 .user_agent("spn-native/0.1.0")
85 .build()
86 .expect("Failed to create HTTP client"),
87 }
88 }
89
90 #[must_use]
92 pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
93 Self {
94 storage_dir,
95 client,
96 }
97 }
98
99 pub async fn download<F>(
114 &self,
115 request: &DownloadRequest<'_>,
116 progress: F,
117 ) -> Result<DownloadResult>
118 where
119 F: Fn(PullProgress) + Send + 'static,
120 {
121 let (repo, filename) = self.resolve_request(request)?;
123
124 let model_dir = self.storage_dir.join(&repo);
126 fs::create_dir_all(&model_dir).await?;
127
128 let file_path = model_dir.join(&filename);
129
130 if !request.force && file_path.exists() {
132 progress(PullProgress::new("cached", 1, 1));
133 let metadata = fs::metadata(&file_path).await?;
134 return Ok(DownloadResult {
135 path: file_path,
136 size: metadata.len(),
137 checksum: None,
138 cached: true,
139 });
140 }
141
142 progress(PullProgress::new("fetching metadata", 0, 1));
144 let file_info = self.get_file_info(&repo, &filename).await?;
145
146 let download_url = format!(
148 "https://huggingface.co/{}/resolve/main/{}",
149 repo, filename
150 );
151
152 progress(PullProgress::new("downloading", 0, file_info.size));
153
154 let response = self.client.get(&download_url).send().await?;
155
156 if !response.status().is_success() {
157 return Err(NativeError::ModelNotFound {
158 repo: repo.clone(),
159 filename: filename.clone(),
160 });
161 }
162
163 let mut file = File::create(&file_path).await?;
165 let mut stream = response.bytes_stream();
166 let mut downloaded: u64 = 0;
167 let mut hasher = Sha256::new();
168
169 while let Some(chunk) = stream.next().await {
170 let chunk = chunk?;
171 hasher.update(&chunk);
172 file.write_all(&chunk).await?;
173 downloaded += chunk.len() as u64;
174
175 progress(PullProgress::new("downloading", downloaded, file_info.size));
176 }
177
178 file.flush().await?;
179 drop(file);
180
181 let checksum = format!("{:x}", hasher.finalize());
183 if let Some(ref lfs) = file_info.lfs {
184 if checksum != lfs.sha256 {
185 let _ = fs::remove_file(&file_path).await;
187 return Err(NativeError::ChecksumMismatch {
188 path: file_path,
189 expected: lfs.sha256.clone(),
190 actual: checksum,
191 });
192 }
193 }
194
195 progress(PullProgress::new("complete", file_info.size, file_info.size));
196
197 Ok(DownloadResult {
198 path: file_path,
199 size: file_info.size,
200 checksum: Some(checksum),
201 cached: false,
202 })
203 }
204
205 fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
207 if let Some(hf_repo) = &request.hf_repo {
208 let filename = request
209 .filename
210 .clone()
211 .ok_or_else(|| NativeError::InvalidConfig("HuggingFace download requires filename".into()))?;
212 return Ok((hf_repo.clone(), filename));
213 }
214
215 if let Some(model) = request.model {
216 let filename = request
217 .target_filename()
218 .ok_or_else(|| NativeError::InvalidConfig("No quantization available for model".into()))?;
219 return Ok((model.hf_repo.to_string(), filename));
220 }
221
222 Err(NativeError::InvalidConfig(
223 "Download request must specify model or HuggingFace repo".into(),
224 ))
225 }
226
227 async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
229 let api_url = format!(
230 "https://huggingface.co/api/models/{}/tree/main",
231 repo
232 );
233
234 let response = self.client.get(&api_url).send().await?;
235
236 if !response.status().is_success() {
237 return Err(NativeError::ModelNotFound {
238 repo: repo.to_string(),
239 filename: filename.to_string(),
240 });
241 }
242
243 let files: Vec<HfFileInfo> = response.json().await?;
244
245 files
246 .into_iter()
247 .find(|f| f.filename == filename)
248 .ok_or_else(|| NativeError::ModelNotFound {
249 repo: repo.to_string(),
250 filename: filename.to_string(),
251 })
252 }
253}
254
255impl ModelStorage for HuggingFaceStorage {
260 fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
261 let mut models = Vec::new();
262
263 if !self.storage_dir.exists() {
264 return Ok(models);
265 }
266
267 let entries = std::fs::read_dir(&self.storage_dir)
269 .map_err(|e| BackendError::StorageError(e.to_string()))?;
270
271 for entry in entries.flatten() {
272 let path = entry.path();
273 if path.is_dir() {
274 let repo_name = entry.file_name().to_string_lossy().to_string();
276
277 if let Ok(files) = std::fs::read_dir(&path) {
279 for file in files.flatten() {
280 let filename = file.file_name().to_string_lossy().to_string();
281 if filename.ends_with(".gguf") {
282 if let Ok(metadata) = file.metadata() {
283 let quant = extract_quantization(&filename);
284 models.push(ModelInfo {
285 name: format!("{}/{}", repo_name, filename),
286 size: metadata.len(),
287 quantization: quant,
288 parameters: None,
289 digest: None,
290 });
291 }
292 }
293 }
294 }
295 }
296 }
297
298 Ok(models)
299 }
300
301 fn exists(&self, model_id: &str) -> bool {
302 self.model_path(model_id).exists()
303 }
304
305 fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
306 let path = self.model_path(model_id);
307 if !path.exists() {
308 return Err(BackendError::ModelNotFound(model_id.to_string()));
309 }
310
311 let metadata = std::fs::metadata(&path)
312 .map_err(|e| BackendError::StorageError(e.to_string()))?;
313
314 let filename = path.file_name().unwrap_or_default().to_string_lossy();
315
316 Ok(ModelInfo {
317 name: model_id.to_string(),
318 size: metadata.len(),
319 quantization: extract_quantization(&filename),
320 parameters: None,
321 digest: None,
322 })
323 }
324
325 fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
326 let path = self.model_path(model_id);
327 if !path.exists() {
328 return Err(BackendError::ModelNotFound(model_id.to_string()));
329 }
330
331 std::fs::remove_file(&path)
332 .map_err(|e| BackendError::StorageError(e.to_string()))?;
333
334 Ok(())
335 }
336
337 fn model_path(&self, model_id: &str) -> PathBuf {
338 self.storage_dir.join(model_id)
341 }
342
343 fn storage_dir(&self) -> &Path {
344 &self.storage_dir
345 }
346}
347
348use crate::extract_quantization;
354
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a storage over a fresh temporary directory.
    ///
    /// The returned guard owns the directory and must be kept alive for as long as the
    /// storage is used.
    fn scratch_storage() -> (tempfile::TempDir, HuggingFaceStorage) {
        let guard = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(guard.path().to_path_buf());
        (guard, storage)
    }

    /// Quantization tags are recognised regardless of case and reported in upper case.
    #[test]
    fn test_extract_quantization() {
        let cases = [
            ("model-q4_k_m.gguf", Some("Q4_K_M")),
            ("model-Q8_0.gguf", Some("Q8_0")),
            ("model-f16.gguf", Some("F16")),
            ("model.gguf", None),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(extract_quantization(input), expected.map(str::to_string));
        }
    }

    /// A new storage reports the directory it was created with.
    #[test]
    fn test_storage_new() {
        let (guard, storage) = scratch_storage();
        assert_eq!(storage.storage_dir(), guard.path());
    }

    /// Model identifiers are appended to the storage directory.
    #[test]
    fn test_model_path() {
        let (_guard, storage) = scratch_storage();

        let nested = storage.model_path("repo/model.gguf");
        assert!(nested.ends_with("repo/model.gguf"));

        let flat = storage.model_path("model.gguf");
        assert!(flat.ends_with("model.gguf"));
    }

    /// An empty storage directory lists no models.
    #[test]
    fn test_list_models_empty() {
        let (_guard, storage) = scratch_storage();
        let listed = storage.list_models().unwrap();
        assert!(listed.is_empty());
    }

    /// A model that was never stored does not exist.
    #[test]
    fn test_exists_false() {
        let (_guard, storage) = scratch_storage();
        assert!(!storage.exists("nonexistent/model.gguf"));
    }
}