Skip to main content

spn_native/
storage.rs

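//! HuggingFace model storage implementation.
//!
//! Downloads GGUF models from HuggingFace Hub with:
//! - Progress callbacks
//! - SHA256 checksum verification (when the Hub reports an LFS SHA256 for the file)
//! - Caching (the download is skipped if the target file already exists and `force`
//!   is not set; the cached file's checksum is not re-verified)
//!
//! HTTP Range requests are not used, so interrupted downloads are not resumed.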
8
9use crate::error::{NativeError, Result};
10use futures_util::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use sha2::{Digest, Sha256};
14use spn_core::{
15    BackendError, DownloadRequest, DownloadResult, ModelInfo, ModelStorage, PullProgress,
16};
17use std::path::{Path, PathBuf};
18use tokio::fs::{self, File};
19use tokio::io::AsyncWriteExt;
20
21// ============================================================================
22// HuggingFace API Types
23// ============================================================================
24
25/// File info from HuggingFace API.
26#[derive(Debug, Deserialize)]
27struct HfFileInfo {
28    /// Filename.
29    #[serde(rename = "rfilename")]
30    filename: String,
31    /// File size in bytes.
32    size: u64,
33    /// LFS info (contains SHA256).
34    lfs: Option<HfLfsInfo>,
35}
36
37/// LFS metadata from HuggingFace.
38#[derive(Debug, Deserialize)]
39struct HfLfsInfo {
40    /// SHA256 checksum.
41    sha256: String,
42}
43
44// ============================================================================
45// HuggingFace Storage
46// ============================================================================
47
48/// Storage backend for HuggingFace Hub models.
49///
50/// Downloads GGUF models from HuggingFace with progress tracking and
51/// checksum verification.
52///
53/// # Example
54///
55/// ```ignore
56/// use spn_native::{HuggingFaceStorage, default_model_dir, DownloadRequest, find_model};
57///
58/// #[tokio::main]
59/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
60///     let storage = HuggingFaceStorage::new(default_model_dir());
61///     let model = find_model("qwen3:8b").unwrap();
62///     let request = DownloadRequest::curated(model);
63///
64///     let result = storage.download(&request, |p| {
65///         println!("{}", p);
66///     }).await?;
67///
68///     println!("Downloaded: {:?}", result.path);
69///     Ok(())
70/// }
71/// ```
72pub struct HuggingFaceStorage {
73    /// Root directory for model storage.
74    storage_dir: PathBuf,
75    /// HTTP client.
76    client: Client,
77}
78
79impl HuggingFaceStorage {
80    /// Create a new HuggingFace storage with the given directory.
81    #[must_use]
82    pub fn new(storage_dir: PathBuf) -> Self {
83        Self {
84            storage_dir,
85            client: Client::builder()
86                .user_agent("spn-native/0.1.0")
87                .build()
88                .expect("Failed to create HTTP client"),
89        }
90    }
91
92    /// Create storage with a custom HTTP client.
93    #[must_use]
94    pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
95        Self {
96            storage_dir,
97            client,
98        }
99    }
100
101    /// Download a model with progress callback.
102    ///
103    /// # Arguments
104    ///
105    /// * `request` - Download request specifying model and quantization
106    /// * `progress` - Callback for download progress updates
107    ///
108    /// # Errors
109    ///
110    /// Returns error if:
111    /// - Model not found on HuggingFace
112    /// - Network error during download
113    /// - Checksum verification fails
114    /// - I/O error writing file
115    pub async fn download<F>(
116        &self,
117        request: &DownloadRequest<'_>,
118        progress: F,
119    ) -> Result<DownloadResult>
120    where
121        F: Fn(PullProgress) + Send + 'static,
122    {
123        // Resolve repo and filename
124        let (repo, filename) = self.resolve_request(request)?;
125
126        // Create storage directory
127        let model_dir = self.storage_dir.join(&repo);
128        fs::create_dir_all(&model_dir).await?;
129
130        let file_path = model_dir.join(&filename);
131
132        // Check if already downloaded
133        if !request.force && file_path.exists() {
134            progress(PullProgress::new("cached", 1, 1));
135            let metadata = fs::metadata(&file_path).await?;
136            return Ok(DownloadResult {
137                path: file_path,
138                size: metadata.len(),
139                checksum: None,
140                cached: true,
141            });
142        }
143
144        // Get file info from HuggingFace API
145        progress(PullProgress::new("fetching metadata", 0, 1));
146        let file_info = self.get_file_info(&repo, &filename).await?;
147
148        // Download the file
149        let download_url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
150
151        progress(PullProgress::new("downloading", 0, file_info.size));
152
153        let response = self.client.get(&download_url).send().await?;
154
155        if !response.status().is_success() {
156            return Err(NativeError::ModelNotFound {
157                repo: repo.clone(),
158                filename: filename.clone(),
159            });
160        }
161
162        // Stream download to file with progress
163        let mut file = File::create(&file_path).await?;
164        let mut stream = response.bytes_stream();
165        let mut downloaded: u64 = 0;
166        let mut hasher = Sha256::new();
167
168        while let Some(chunk) = stream.next().await {
169            let chunk = chunk?;
170            hasher.update(&chunk);
171            file.write_all(&chunk).await?;
172            downloaded += chunk.len() as u64;
173
174            progress(PullProgress::new("downloading", downloaded, file_info.size));
175        }
176
177        file.flush().await?;
178        drop(file);
179
180        // Verify checksum
181        let checksum = format!("{:x}", hasher.finalize());
182        if let Some(ref lfs) = file_info.lfs {
183            if checksum != lfs.sha256 {
184                // Delete corrupted file
185                let _ = fs::remove_file(&file_path).await;
186                return Err(NativeError::ChecksumMismatch {
187                    path: file_path,
188                    expected: lfs.sha256.clone(),
189                    actual: checksum,
190                });
191            }
192        }
193
194        progress(PullProgress::new(
195            "complete",
196            file_info.size,
197            file_info.size,
198        ));
199
200        Ok(DownloadResult {
201            path: file_path,
202            size: file_info.size,
203            checksum: Some(checksum),
204            cached: false,
205        })
206    }
207
208    /// Resolve download request to HuggingFace repo and filename.
209    fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
210        if let Some(hf_repo) = &request.hf_repo {
211            let filename = request.filename.clone().ok_or_else(|| {
212                NativeError::InvalidConfig("HuggingFace download requires filename".into())
213            })?;
214            return Ok((hf_repo.clone(), filename));
215        }
216
217        if let Some(model) = request.model {
218            let filename = request.target_filename().ok_or_else(|| {
219                NativeError::InvalidConfig("No quantization available for model".into())
220            })?;
221            return Ok((model.hf_repo.to_string(), filename));
222        }
223
224        Err(NativeError::InvalidConfig(
225            "Download request must specify model or HuggingFace repo".into(),
226        ))
227    }
228
229    /// Get file info from HuggingFace API.
230    async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
231        let api_url = format!("https://huggingface.co/api/models/{}/tree/main", repo);
232
233        let response = self.client.get(&api_url).send().await?;
234
235        if !response.status().is_success() {
236            return Err(NativeError::ModelNotFound {
237                repo: repo.to_string(),
238                filename: filename.to_string(),
239            });
240        }
241
242        let files: Vec<HfFileInfo> = response.json().await?;
243
244        files
245            .into_iter()
246            .find(|f| f.filename == filename)
247            .ok_or_else(|| NativeError::ModelNotFound {
248                repo: repo.to_string(),
249                filename: filename.to_string(),
250            })
251    }
252}
253
254// ============================================================================
255// ModelStorage Implementation
256// ============================================================================
257
258impl ModelStorage for HuggingFaceStorage {
259    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
260        let mut models = Vec::new();
261
262        if !self.storage_dir.exists() {
263            return Ok(models);
264        }
265
266        // Walk the storage directory
267        let entries = std::fs::read_dir(&self.storage_dir)
268            .map_err(|e| BackendError::StorageError(e.to_string()))?;
269
270        for entry in entries.flatten() {
271            let path = entry.path();
272            if path.is_dir() {
273                // This is a repo directory
274                let repo_name = entry.file_name().to_string_lossy().to_string();
275
276                // List GGUF files in this directory
277                if let Ok(files) = std::fs::read_dir(&path) {
278                    for file in files.flatten() {
279                        let filename = file.file_name().to_string_lossy().to_string();
280                        if filename.ends_with(".gguf") {
281                            if let Ok(metadata) = file.metadata() {
282                                let quant = extract_quantization(&filename);
283                                models.push(ModelInfo {
284                                    name: format!("{}/{}", repo_name, filename),
285                                    size: metadata.len(),
286                                    quantization: quant,
287                                    parameters: None,
288                                    digest: None,
289                                });
290                            }
291                        }
292                    }
293                }
294            }
295        }
296
297        Ok(models)
298    }
299
300    fn exists(&self, model_id: &str) -> bool {
301        self.model_path(model_id).exists()
302    }
303
304    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
305        let path = self.model_path(model_id);
306        if !path.exists() {
307            return Err(BackendError::ModelNotFound(model_id.to_string()));
308        }
309
310        let metadata =
311            std::fs::metadata(&path).map_err(|e| BackendError::StorageError(e.to_string()))?;
312
313        let filename = path.file_name().unwrap_or_default().to_string_lossy();
314
315        Ok(ModelInfo {
316            name: model_id.to_string(),
317            size: metadata.len(),
318            quantization: extract_quantization(&filename),
319            parameters: None,
320            digest: None,
321        })
322    }
323
324    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
325        let path = self.model_path(model_id);
326        if !path.exists() {
327            return Err(BackendError::ModelNotFound(model_id.to_string()));
328        }
329
330        std::fs::remove_file(&path).map_err(|e| BackendError::StorageError(e.to_string()))?;
331
332        Ok(())
333    }
334
335    fn model_path(&self, model_id: &str) -> PathBuf {
336        // model_id format: "repo/filename" or just "filename"
337        // Both cases join to storage_dir
338        self.storage_dir.join(model_id)
339    }
340
341    fn storage_dir(&self) -> &Path {
342        &self.storage_dir
343    }
344}
345
346// ============================================================================
347// Helpers
348// ============================================================================
349
350// Use the shared extract_quantization function from crate root.
351use crate::extract_quantization;
352
353// ============================================================================
354// Tests
355// ============================================================================
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_extract_quantization() {
        assert_eq!(
            extract_quantization("model-q4_k_m.gguf"),
            Some("Q4_K_M".to_string())
        );
        assert_eq!(
            extract_quantization("model-Q8_0.gguf"),
            Some("Q8_0".to_string())
        );
        assert_eq!(
            extract_quantization("model-f16.gguf"),
            Some("F16".to_string())
        );
        assert_eq!(extract_quantization("model.gguf"), None);
    }

    #[test]
    fn test_storage_new() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf());
        assert_eq!(storage.storage_dir(), dir.path());
    }

    #[test]
    fn test_model_path() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf());

        let path = storage.model_path("repo/model.gguf");
        assert!(path.ends_with("repo/model.gguf"));

        let path = storage.model_path("model.gguf");
        assert!(path.ends_with("model.gguf"));
    }

    #[test]
    fn test_list_models_empty() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf());
        let models = storage.list_models().unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_list_models_missing_dir() {
        // A storage root that was never created is simply empty, not an error.
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().join("does-not-exist"));
        assert!(storage.list_models().unwrap().is_empty());
    }

    #[test]
    fn test_exists_false() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf());
        assert!(!storage.exists("nonexistent/model.gguf"));
    }

    #[test]
    fn test_populated_store_round_trip() {
        // `std::fs` is spelled out because `super::*` brings `tokio::fs` into scope as `fs`.
        let dir = tempdir().unwrap();
        let repo_dir = dir.path().join("repo");
        std::fs::create_dir_all(&repo_dir).unwrap();
        std::fs::write(repo_dir.join("model-q4_k_m.gguf"), b"12345").unwrap();
        std::fs::write(repo_dir.join("README.md"), b"not a model").unwrap();

        let storage = HuggingFaceStorage::new(dir.path().to_path_buf());

        // Only the GGUF file is listed, and its name is a usable model id.
        let models = storage.list_models().unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "repo/model-q4_k_m.gguf");
        assert_eq!(models[0].size, 5);
        assert_eq!(models[0].quantization, Some("Q4_K_M".to_string()));

        let id = "repo/model-q4_k_m.gguf";
        assert!(storage.exists(id));

        let info = storage.model_info(id).unwrap();
        assert_eq!(info.name, id);
        assert_eq!(info.size, 5);

        storage.delete(id).unwrap();
        assert!(!storage.exists(id));
        assert!(matches!(
            storage.delete(id),
            Err(BackendError::ModelNotFound(_))
        ));
    }
}