Skip to main content

spn_native/
storage.rs

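//! HuggingFace model storage implementation.
//!
//! Downloads GGUF models from HuggingFace Hub with:
//! - Progress callbacks
//! - SHA256 checksum verification (for files that carry LFS metadata)
//! - Resumable downloads via HTTP Range requests (NOTE: not implemented in this file yet;
//!   `download` issues a plain GET and always starts from the first byte)
//! - Caching (the download is skipped when the target file already exists, unless the
//!   request sets `force`; a cached file is not re-hashed)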
8
9use crate::error::{NativeError, Result};
10use futures_util::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use sha2::{Digest, Sha256};
14use spn_core::{BackendError, DownloadRequest, DownloadResult, ModelInfo, ModelStorage, PullProgress};
15use std::path::{Path, PathBuf};
16use tokio::fs::{self, File};
17use tokio::io::AsyncWriteExt;
18
19// ============================================================================
20// HuggingFace API Types
21// ============================================================================
22
23/// File info from HuggingFace API.
24#[derive(Debug, Deserialize)]
25struct HfFileInfo {
26    /// Filename.
27    #[serde(rename = "rfilename")]
28    filename: String,
29    /// File size in bytes.
30    size: u64,
31    /// LFS info (contains SHA256).
32    lfs: Option<HfLfsInfo>,
33}
34
35/// LFS metadata from HuggingFace.
36#[derive(Debug, Deserialize)]
37struct HfLfsInfo {
38    /// SHA256 checksum.
39    sha256: String,
40}
41
42// ============================================================================
43// HuggingFace Storage
44// ============================================================================
45
46/// Storage backend for HuggingFace Hub models.
47///
48/// Downloads GGUF models from HuggingFace with progress tracking and
49/// checksum verification.
50///
51/// # Example
52///
53/// ```ignore
54/// use spn_native::{HuggingFaceStorage, default_model_dir, DownloadRequest, find_model};
55///
56/// #[tokio::main]
57/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
58///     let storage = HuggingFaceStorage::new(default_model_dir());
59///     let model = find_model("qwen3:8b").unwrap();
60///     let request = DownloadRequest::curated(model);
61///
62///     let result = storage.download(&request, |p| {
63///         println!("{}", p);
64///     }).await?;
65///
66///     println!("Downloaded: {:?}", result.path);
67///     Ok(())
68/// }
69/// ```
70pub struct HuggingFaceStorage {
71    /// Root directory for model storage.
72    storage_dir: PathBuf,
73    /// HTTP client.
74    client: Client,
75}
76
77impl HuggingFaceStorage {
78    /// Create a new HuggingFace storage with the given directory.
79    #[must_use]
80    pub fn new(storage_dir: PathBuf) -> Self {
81        Self {
82            storage_dir,
83            client: Client::builder()
84                .user_agent("spn-native/0.1.0")
85                .build()
86                .expect("Failed to create HTTP client"),
87        }
88    }
89
90    /// Create storage with a custom HTTP client.
91    #[must_use]
92    pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
93        Self {
94            storage_dir,
95            client,
96        }
97    }
98
99    /// Download a model with progress callback.
100    ///
101    /// # Arguments
102    ///
103    /// * `request` - Download request specifying model and quantization
104    /// * `progress` - Callback for download progress updates
105    ///
106    /// # Errors
107    ///
108    /// Returns error if:
109    /// - Model not found on HuggingFace
110    /// - Network error during download
111    /// - Checksum verification fails
112    /// - I/O error writing file
113    pub async fn download<F>(
114        &self,
115        request: &DownloadRequest<'_>,
116        progress: F,
117    ) -> Result<DownloadResult>
118    where
119        F: Fn(PullProgress) + Send + 'static,
120    {
121        // Resolve repo and filename
122        let (repo, filename) = self.resolve_request(request)?;
123
124        // Create storage directory
125        let model_dir = self.storage_dir.join(&repo);
126        fs::create_dir_all(&model_dir).await?;
127
128        let file_path = model_dir.join(&filename);
129
130        // Check if already downloaded
131        if !request.force && file_path.exists() {
132            progress(PullProgress::new("cached", 1, 1));
133            let metadata = fs::metadata(&file_path).await?;
134            return Ok(DownloadResult {
135                path: file_path,
136                size: metadata.len(),
137                checksum: None,
138                cached: true,
139            });
140        }
141
142        // Get file info from HuggingFace API
143        progress(PullProgress::new("fetching metadata", 0, 1));
144        let file_info = self.get_file_info(&repo, &filename).await?;
145
146        // Download the file
147        let download_url = format!(
148            "https://huggingface.co/{}/resolve/main/{}",
149            repo, filename
150        );
151
152        progress(PullProgress::new("downloading", 0, file_info.size));
153
154        let response = self.client.get(&download_url).send().await?;
155
156        if !response.status().is_success() {
157            return Err(NativeError::ModelNotFound {
158                repo: repo.clone(),
159                filename: filename.clone(),
160            });
161        }
162
163        // Stream download to file with progress
164        let mut file = File::create(&file_path).await?;
165        let mut stream = response.bytes_stream();
166        let mut downloaded: u64 = 0;
167        let mut hasher = Sha256::new();
168
169        while let Some(chunk) = stream.next().await {
170            let chunk = chunk?;
171            hasher.update(&chunk);
172            file.write_all(&chunk).await?;
173            downloaded += chunk.len() as u64;
174
175            progress(PullProgress::new("downloading", downloaded, file_info.size));
176        }
177
178        file.flush().await?;
179        drop(file);
180
181        // Verify checksum
182        let checksum = format!("{:x}", hasher.finalize());
183        if let Some(ref lfs) = file_info.lfs {
184            if checksum != lfs.sha256 {
185                // Delete corrupted file
186                let _ = fs::remove_file(&file_path).await;
187                return Err(NativeError::ChecksumMismatch {
188                    path: file_path,
189                    expected: lfs.sha256.clone(),
190                    actual: checksum,
191                });
192            }
193        }
194
195        progress(PullProgress::new("complete", file_info.size, file_info.size));
196
197        Ok(DownloadResult {
198            path: file_path,
199            size: file_info.size,
200            checksum: Some(checksum),
201            cached: false,
202        })
203    }
204
205    /// Resolve download request to HuggingFace repo and filename.
206    fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
207        if let Some(hf_repo) = &request.hf_repo {
208            let filename = request
209                .filename
210                .clone()
211                .ok_or_else(|| NativeError::InvalidConfig("HuggingFace download requires filename".into()))?;
212            return Ok((hf_repo.clone(), filename));
213        }
214
215        if let Some(model) = request.model {
216            let filename = request
217                .target_filename()
218                .ok_or_else(|| NativeError::InvalidConfig("No quantization available for model".into()))?;
219            return Ok((model.hf_repo.to_string(), filename));
220        }
221
222        Err(NativeError::InvalidConfig(
223            "Download request must specify model or HuggingFace repo".into(),
224        ))
225    }
226
227    /// Get file info from HuggingFace API.
228    async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
229        let api_url = format!(
230            "https://huggingface.co/api/models/{}/tree/main",
231            repo
232        );
233
234        let response = self.client.get(&api_url).send().await?;
235
236        if !response.status().is_success() {
237            return Err(NativeError::ModelNotFound {
238                repo: repo.to_string(),
239                filename: filename.to_string(),
240            });
241        }
242
243        let files: Vec<HfFileInfo> = response.json().await?;
244
245        files
246            .into_iter()
247            .find(|f| f.filename == filename)
248            .ok_or_else(|| NativeError::ModelNotFound {
249                repo: repo.to_string(),
250                filename: filename.to_string(),
251            })
252    }
253}
254
255// ============================================================================
256// ModelStorage Implementation
257// ============================================================================
258
259impl ModelStorage for HuggingFaceStorage {
260    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
261        let mut models = Vec::new();
262
263        if !self.storage_dir.exists() {
264            return Ok(models);
265        }
266
267        // Walk the storage directory
268        let entries = std::fs::read_dir(&self.storage_dir)
269            .map_err(|e| BackendError::StorageError(e.to_string()))?;
270
271        for entry in entries.flatten() {
272            let path = entry.path();
273            if path.is_dir() {
274                // This is a repo directory
275                let repo_name = entry.file_name().to_string_lossy().to_string();
276
277                // List GGUF files in this directory
278                if let Ok(files) = std::fs::read_dir(&path) {
279                    for file in files.flatten() {
280                        let filename = file.file_name().to_string_lossy().to_string();
281                        if filename.ends_with(".gguf") {
282                            if let Ok(metadata) = file.metadata() {
283                                let quant = extract_quantization(&filename);
284                                models.push(ModelInfo {
285                                    name: format!("{}/{}", repo_name, filename),
286                                    size: metadata.len(),
287                                    quantization: quant,
288                                    parameters: None,
289                                    digest: None,
290                                });
291                            }
292                        }
293                    }
294                }
295            }
296        }
297
298        Ok(models)
299    }
300
301    fn exists(&self, model_id: &str) -> bool {
302        self.model_path(model_id).exists()
303    }
304
305    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
306        let path = self.model_path(model_id);
307        if !path.exists() {
308            return Err(BackendError::ModelNotFound(model_id.to_string()));
309        }
310
311        let metadata = std::fs::metadata(&path)
312            .map_err(|e| BackendError::StorageError(e.to_string()))?;
313
314        let filename = path.file_name().unwrap_or_default().to_string_lossy();
315
316        Ok(ModelInfo {
317            name: model_id.to_string(),
318            size: metadata.len(),
319            quantization: extract_quantization(&filename),
320            parameters: None,
321            digest: None,
322        })
323    }
324
325    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
326        let path = self.model_path(model_id);
327        if !path.exists() {
328            return Err(BackendError::ModelNotFound(model_id.to_string()));
329        }
330
331        std::fs::remove_file(&path)
332            .map_err(|e| BackendError::StorageError(e.to_string()))?;
333
334        Ok(())
335    }
336
337    fn model_path(&self, model_id: &str) -> PathBuf {
338        // model_id format: "repo/filename" or just "filename"
339        // Both cases join to storage_dir
340        self.storage_dir.join(model_id)
341    }
342
343    fn storage_dir(&self) -> &Path {
344        &self.storage_dir
345    }
346}
347
348// ============================================================================
349// Helpers
350// ============================================================================
351
352// Use the shared extract_quantization function from crate root.
353use crate::extract_quantization;
354
355// ============================================================================
356// Tests
357// ============================================================================
358
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a storage rooted in a fresh temporary directory.
    ///
    /// The `TempDir` handle is returned too so the directory stays alive for
    /// the duration of the calling test.
    fn scratch_storage() -> (tempfile::TempDir, HuggingFaceStorage) {
        let root = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(root.path().to_path_buf());
        (root, storage)
    }

    #[test]
    fn test_extract_quantization() {
        // (file name, expected quantization label)
        let cases = [
            ("model-q4_k_m.gguf", Some("Q4_K_M")),
            ("model-Q8_0.gguf", Some("Q8_0")),
            ("model-f16.gguf", Some("F16")),
            ("model.gguf", None),
        ];

        for &(file_name, expected) in cases.iter() {
            assert_eq!(extract_quantization(file_name), expected.map(String::from));
        }
    }

    #[test]
    fn test_storage_new() {
        let (root, storage) = scratch_storage();
        assert_eq!(storage.storage_dir(), root.path());
    }

    #[test]
    fn test_model_path() {
        let (_root, storage) = scratch_storage();

        // Both the "repo/filename" and the bare "filename" id forms are accepted.
        for model_id in ["repo/model.gguf", "model.gguf"].iter() {
            let resolved = storage.model_path(model_id);
            assert!(resolved.ends_with(model_id));
        }
    }

    #[test]
    fn test_list_models_empty() {
        let (_root, storage) = scratch_storage();
        let listed = storage.list_models().unwrap();
        assert!(listed.is_empty());
    }

    #[test]
    fn test_exists_false() {
        let (_root, storage) = scratch_storage();
        assert!(!storage.exists("nonexistent/model.gguf"));
    }
}