1use crate::error::{NativeError, Result};
10use futures_util::StreamExt;
11use reqwest::Client;
12use serde::Deserialize;
13use sha2::{Digest, Sha256};
14use spn_core::{
15 BackendError, DownloadRequest, DownloadResult, ModelInfo, ModelStorage, PullProgress,
16};
17use std::path::{Path, PathBuf};
18use tokio::fs::{self, File};
19use tokio::io::AsyncWriteExt;
20
21#[derive(Debug, Deserialize)]
27struct HfFileInfo {
28 #[serde(rename = "rfilename")]
30 filename: String,
31 size: u64,
33 lfs: Option<HfLfsInfo>,
35}
36
37#[derive(Debug, Deserialize)]
39struct HfLfsInfo {
40 sha256: String,
42}
43
44pub struct HuggingFaceStorage {
73 storage_dir: PathBuf,
75 client: Client,
77}
78
79impl HuggingFaceStorage {
80 #[must_use]
82 pub fn new(storage_dir: PathBuf) -> Self {
83 Self {
84 storage_dir,
85 client: Client::builder()
86 .user_agent("spn-native/0.1.0")
87 .build()
88 .expect("Failed to create HTTP client"),
89 }
90 }
91
92 #[must_use]
94 pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
95 Self {
96 storage_dir,
97 client,
98 }
99 }
100
101 pub async fn download<F>(
116 &self,
117 request: &DownloadRequest<'_>,
118 progress: F,
119 ) -> Result<DownloadResult>
120 where
121 F: Fn(PullProgress) + Send + 'static,
122 {
123 let (repo, filename) = self.resolve_request(request)?;
125
126 let model_dir = self.storage_dir.join(&repo);
128 fs::create_dir_all(&model_dir).await?;
129
130 let file_path = model_dir.join(&filename);
131
132 if !request.force && file_path.exists() {
134 progress(PullProgress::new("cached", 1, 1));
135 let metadata = fs::metadata(&file_path).await?;
136 return Ok(DownloadResult {
137 path: file_path,
138 size: metadata.len(),
139 checksum: None,
140 cached: true,
141 });
142 }
143
144 progress(PullProgress::new("fetching metadata", 0, 1));
146 let file_info = self.get_file_info(&repo, &filename).await?;
147
148 let download_url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
150
151 progress(PullProgress::new("downloading", 0, file_info.size));
152
153 let response = self.client.get(&download_url).send().await?;
154
155 if !response.status().is_success() {
156 return Err(NativeError::ModelNotFound {
157 repo: repo.clone(),
158 filename: filename.clone(),
159 });
160 }
161
162 let mut file = File::create(&file_path).await?;
164 let mut stream = response.bytes_stream();
165 let mut downloaded: u64 = 0;
166 let mut hasher = Sha256::new();
167
168 while let Some(chunk) = stream.next().await {
169 let chunk = chunk?;
170 hasher.update(&chunk);
171 file.write_all(&chunk).await?;
172 downloaded += chunk.len() as u64;
173
174 progress(PullProgress::new("downloading", downloaded, file_info.size));
175 }
176
177 file.flush().await?;
178 drop(file);
179
180 let checksum = format!("{:x}", hasher.finalize());
182 if let Some(ref lfs) = file_info.lfs {
183 if checksum != lfs.sha256 {
184 let _ = fs::remove_file(&file_path).await;
186 return Err(NativeError::ChecksumMismatch {
187 path: file_path,
188 expected: lfs.sha256.clone(),
189 actual: checksum,
190 });
191 }
192 }
193
194 progress(PullProgress::new(
195 "complete",
196 file_info.size,
197 file_info.size,
198 ));
199
200 Ok(DownloadResult {
201 path: file_path,
202 size: file_info.size,
203 checksum: Some(checksum),
204 cached: false,
205 })
206 }
207
208 fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
210 if let Some(hf_repo) = &request.hf_repo {
211 let filename = request.filename.clone().ok_or_else(|| {
212 NativeError::InvalidConfig("HuggingFace download requires filename".into())
213 })?;
214 return Ok((hf_repo.clone(), filename));
215 }
216
217 if let Some(model) = request.model {
218 let filename = request.target_filename().ok_or_else(|| {
219 NativeError::InvalidConfig("No quantization available for model".into())
220 })?;
221 return Ok((model.hf_repo.to_string(), filename));
222 }
223
224 Err(NativeError::InvalidConfig(
225 "Download request must specify model or HuggingFace repo".into(),
226 ))
227 }
228
229 async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
231 let api_url = format!("https://huggingface.co/api/models/{}/tree/main", repo);
232
233 let response = self.client.get(&api_url).send().await?;
234
235 if !response.status().is_success() {
236 return Err(NativeError::ModelNotFound {
237 repo: repo.to_string(),
238 filename: filename.to_string(),
239 });
240 }
241
242 let files: Vec<HfFileInfo> = response.json().await?;
243
244 files
245 .into_iter()
246 .find(|f| f.filename == filename)
247 .ok_or_else(|| NativeError::ModelNotFound {
248 repo: repo.to_string(),
249 filename: filename.to_string(),
250 })
251 }
252}
253
254impl ModelStorage for HuggingFaceStorage {
259 fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
260 let mut models = Vec::new();
261
262 if !self.storage_dir.exists() {
263 return Ok(models);
264 }
265
266 let entries = std::fs::read_dir(&self.storage_dir)
268 .map_err(|e| BackendError::StorageError(e.to_string()))?;
269
270 for entry in entries.flatten() {
271 let path = entry.path();
272 if path.is_dir() {
273 let repo_name = entry.file_name().to_string_lossy().to_string();
275
276 if let Ok(files) = std::fs::read_dir(&path) {
278 for file in files.flatten() {
279 let filename = file.file_name().to_string_lossy().to_string();
280 if filename.ends_with(".gguf") {
281 if let Ok(metadata) = file.metadata() {
282 let quant = extract_quantization(&filename);
283 models.push(ModelInfo {
284 name: format!("{}/{}", repo_name, filename),
285 size: metadata.len(),
286 quantization: quant,
287 parameters: None,
288 digest: None,
289 });
290 }
291 }
292 }
293 }
294 }
295 }
296
297 Ok(models)
298 }
299
300 fn exists(&self, model_id: &str) -> bool {
301 self.model_path(model_id).exists()
302 }
303
304 fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
305 let path = self.model_path(model_id);
306 if !path.exists() {
307 return Err(BackendError::ModelNotFound(model_id.to_string()));
308 }
309
310 let metadata =
311 std::fs::metadata(&path).map_err(|e| BackendError::StorageError(e.to_string()))?;
312
313 let filename = path.file_name().unwrap_or_default().to_string_lossy();
314
315 Ok(ModelInfo {
316 name: model_id.to_string(),
317 size: metadata.len(),
318 quantization: extract_quantization(&filename),
319 parameters: None,
320 digest: None,
321 })
322 }
323
324 fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
325 let path = self.model_path(model_id);
326 if !path.exists() {
327 return Err(BackendError::ModelNotFound(model_id.to_string()));
328 }
329
330 std::fs::remove_file(&path).map_err(|e| BackendError::StorageError(e.to_string()))?;
331
332 Ok(())
333 }
334
335 fn model_path(&self, model_id: &str) -> PathBuf {
336 self.storage_dir.join(model_id)
339 }
340
341 fn storage_dir(&self) -> &Path {
342 &self.storage_dir
343 }
344}
345
346use crate::extract_quantization;
352
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a storage rooted in a fresh temporary directory.
    ///
    /// The directory guard is returned as well so the directory stays alive
    /// for the duration of the test.
    fn temp_storage() -> (tempfile::TempDir, HuggingFaceStorage) {
        let root = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(root.path().to_path_buf());
        (root, storage)
    }

    #[test]
    fn test_extract_quantization() {
        // (file name, expected quantization label)
        let cases = [
            ("model-q4_k_m.gguf", Some("Q4_K_M")),
            ("model-Q8_0.gguf", Some("Q8_0")),
            ("model-f16.gguf", Some("F16")),
            ("model.gguf", None),
        ];

        for &(file_name, expected) in &cases {
            assert_eq!(extract_quantization(file_name), expected.map(String::from));
        }
    }

    #[test]
    fn test_storage_new() {
        let (root, storage) = temp_storage();
        assert_eq!(storage.storage_dir(), root.path());
    }

    #[test]
    fn test_model_path() {
        let (_root, storage) = temp_storage();

        let nested = storage.model_path("repo/model.gguf");
        assert!(nested.ends_with("repo/model.gguf"));

        let flat = storage.model_path("model.gguf");
        assert!(flat.ends_with("model.gguf"));
    }

    #[test]
    fn test_list_models_empty() {
        let (_root, storage) = temp_storage();
        let listed = storage.list_models().unwrap();
        assert!(listed.is_empty());
    }

    #[test]
    fn test_exists_false() {
        let (_root, storage) = temp_storage();
        assert!(!storage.exists("nonexistent/model.gguf"));
    }
}