spn_core/storage.rs
1//! Model storage trait (download-only, no inference).
2//!
3//! This module defines the [`ModelStorage`] trait for downloading and managing
4//! local models. The actual implementation lives in `spn-native`.
5//!
6//! **Note:** This trait is for spn (package manager) - download only, no inference.
7//! For inference, see `NativeRuntime` in Nika.
8
9use crate::backend::{BackendError, ModelInfo, PullProgress, Quantization};
10use crate::model::KnownModel;
11use std::path::{Path, PathBuf};
12
13// ============================================================================
14// Storage Location
15// ============================================================================
16
/// Resolve the default directory in which downloaded models are kept.
///
/// This variant is compiled when the `dirs` feature is enabled. It returns
/// `<home>/.spn/models`, where `<home>` is whatever `dirs::home_dir()`
/// reports. If no home directory can be determined, the relative path
/// `.spn/models` is returned instead.
#[cfg(feature = "dirs")]
#[must_use]
pub fn default_model_dir() -> PathBuf {
    match dirs::home_dir() {
        Some(home) => home.join(".spn").join("models"),
        // No home directory available: fall back to a path relative to the
        // current working directory.
        None => PathBuf::from(".spn/models"),
    }
}
29
/// Resolve the default model directory when the `dirs` feature is disabled.
///
/// Without home-directory support there is nothing to anchor the path to, so
/// this always yields the relative path `.spn/models`.
#[cfg(not(feature = "dirs"))]
#[must_use]
pub fn default_model_dir() -> PathBuf {
    let relative = Path::new(".spn/models");
    relative.to_path_buf()
}
38
39// ============================================================================
40// Sync Storage Trait (zero-dep)
41// ============================================================================
42
43/// Model storage backend (sync version).
44///
45/// This trait defines operations for downloading and managing local models.
46/// Implementations include `HuggingFaceStorage` in spn-native.
47///
48/// For the async version, enable the `async-storage` feature.
49pub trait ModelStorage: Send + Sync {
50 /// List downloaded models.
51 ///
52 /// # Errors
53 ///
54 /// Returns error if the storage directory cannot be read.
55 fn list_models(&self) -> Result<Vec<ModelInfo>, BackendError>;
56
57 /// Check if a model exists locally.
58 fn exists(&self, model_id: &str) -> bool;
59
60 /// Get model info for a specific model.
61 ///
62 /// # Errors
63 ///
64 /// Returns error if the model is not found.
65 fn model_info(&self, model_id: &str) -> Result<ModelInfo, BackendError>;
66
67 /// Delete a model.
68 ///
69 /// # Errors
70 ///
71 /// Returns error if the model cannot be deleted.
72 fn delete(&self, model_id: &str) -> Result<(), BackendError>;
73
74 /// Get the local path for a model.
75 fn model_path(&self, model_id: &str) -> PathBuf;
76
77 /// Get the storage root directory.
78 fn storage_dir(&self) -> &Path;
79}
80
81// ============================================================================
82// Download Progress Callback
83// ============================================================================
84
85/// Type alias for download progress callbacks.
86pub type ProgressCallback = Box<dyn Fn(PullProgress) + Send + 'static>;
87
88// ============================================================================
89// Download Request
90// ============================================================================
91
92/// Request to download a model.
93#[derive(Debug, Clone)]
94pub struct DownloadRequest<'a> {
95 /// The model to download (curated).
96 pub model: Option<&'a KnownModel>,
97
98 /// HuggingFace repo (for passthrough).
99 pub hf_repo: Option<String>,
100
101 /// Specific filename to download.
102 pub filename: Option<String>,
103
104 /// Quantization level (for curated models).
105 pub quantization: Option<Quantization>,
106
107 /// Force re-download even if exists.
108 pub force: bool,
109}
110
111impl<'a> DownloadRequest<'a> {
112 /// Create a request for a curated model.
113 #[must_use]
114 pub fn curated(model: &'a KnownModel) -> Self {
115 Self {
116 model: Some(model),
117 hf_repo: None,
118 filename: None,
119 quantization: None,
120 force: false,
121 }
122 }
123
124 /// Create a request for a HuggingFace model.
125 #[must_use]
126 pub fn huggingface(repo: impl Into<String>, filename: impl Into<String>) -> Self {
127 Self {
128 model: None,
129 hf_repo: Some(repo.into()),
130 filename: Some(filename.into()),
131 quantization: None,
132 force: false,
133 }
134 }
135
136 /// Set the quantization level.
137 #[must_use]
138 pub fn with_quantization(mut self, quant: Quantization) -> Self {
139 self.quantization = Some(quant);
140 self
141 }
142
143 /// Force re-download.
144 #[must_use]
145 pub fn force(mut self) -> Self {
146 self.force = true;
147 self
148 }
149
150 /// Get the target filename for this download.
151 #[must_use]
152 pub fn target_filename(&self) -> Option<String> {
153 if let Some(filename) = &self.filename {
154 return Some(filename.clone());
155 }
156
157 if let Some(model) = self.model {
158 let quant = self.quantization.unwrap_or(Quantization::Q4_K_M);
159 return model.filename_for_quant(quant).map(String::from);
160 }
161
162 None
163 }
164}
165
166// ============================================================================
167// Download Result
168// ============================================================================
169
/// Outcome of a model download.
#[derive(Clone, Debug)]
pub struct DownloadResult {
    /// Where the downloaded model lives on local disk.
    pub path: PathBuf,

    /// Size of the downloaded file, in bytes.
    pub size: u64,

    /// SHA256 checksum of the file, if one is available.
    pub checksum: Option<String>,

    /// `true` when the file was already cached.
    pub cached: bool,
}
185
186// ============================================================================
187// Tests
188// ============================================================================
189
#[cfg(test)]
mod tests {
    use super::*;

    /// The default directory must end in the `.spn/models` path components,
    /// whether it is rooted at the home directory (`dirs` feature) or relative.
    #[test]
    fn test_default_model_dir() {
        let dir = default_model_dir();
        // Component-wise suffix check: stricter than substring matching, which
        // would also accept unrelated paths that merely contain "spn"/"models".
        assert!(
            dir.ends_with(".spn/models"),
            "unexpected default model dir: {}",
            dir.display()
        );
    }

    /// A curated request carries the model and quantization and resolves a
    /// filename for that quantization.
    #[test]
    fn test_download_request_curated() {
        use crate::model::find_model;

        let model = find_model("qwen3:8b").expect("qwen3:8b should be a known model");
        let request = DownloadRequest::curated(model).with_quantization(Quantization::Q4_K_M);

        assert!(request.model.is_some());
        assert!(request.hf_repo.is_none());
        assert_eq!(request.quantization, Some(Quantization::Q4_K_M));

        let filename = request
            .target_filename()
            .expect("curated model should resolve a filename for Q4_K_M");
        assert!(filename.contains("q4_k_m"));
    }

    /// A HuggingFace request carries the repo and filename verbatim, and
    /// `force()` flips the flag.
    #[test]
    fn test_download_request_huggingface() {
        let request =
            DownloadRequest::huggingface("bartowski/Model", "model-q4_k_m.gguf").force();

        assert!(request.model.is_none());
        assert_eq!(request.hf_repo.as_deref(), Some("bartowski/Model"));
        assert!(request.force);

        let filename = request.target_filename();
        assert_eq!(filename.as_deref(), Some("model-q4_k_m.gguf"));
    }
}