Skip to main content

spn_core/
storage.rs

1//! Model storage trait (download-only, no inference).
2//!
3//! This module defines the [`ModelStorage`] trait for downloading and managing
4//! local models. The actual implementation lives in `spn-native`.
5//!
6//! **Note:** This trait is for spn (package manager) - download only, no inference.
7//! For inference, see `NativeRuntime` in Nika.
8
9use crate::backend::{BackendError, ModelInfo, PullProgress, Quantization};
10use crate::model::KnownModel;
11use std::path::{Path, PathBuf};
12
13// ============================================================================
14// Storage Location
15// ============================================================================
16
/// Returns the directory where downloaded models are kept.
///
/// Normally this is `~/.spn/models/`, built from the user's home directory.
/// If the home directory cannot be determined, the relative path
/// `.spn/models` is returned instead.
///
/// Only compiled when the `dirs` feature is enabled.
#[cfg(feature = "dirs")]
#[must_use]
pub fn default_model_dir() -> PathBuf {
    match dirs::home_dir() {
        Some(home) => home.join(".spn").join("models"),
        // No resolvable home directory: fall back to a relative location.
        None => PathBuf::from(".spn/models"),
    }
}
29
/// Returns the directory where downloaded models are kept (build without `dirs`).
///
/// Without the `dirs` feature the home directory cannot be looked up. This
/// variant therefore always returns the relative path `.spn/models`.
#[cfg(not(feature = "dirs"))]
#[must_use]
pub fn default_model_dir() -> PathBuf {
    const RELATIVE_MODEL_DIR: &str = ".spn/models";
    PathBuf::from(RELATIVE_MODEL_DIR)
}
38
39// ============================================================================
40// Sync Storage Trait (zero-dep)
41// ============================================================================
42
43/// Model storage backend (sync version).
44///
45/// This trait defines operations for downloading and managing local models.
46/// Implementations include `HuggingFaceStorage` in spn-native.
47///
48/// For the async version, enable the `async-storage` feature.
49pub trait ModelStorage: Send + Sync {
50    /// List downloaded models.
51    ///
52    /// # Errors
53    ///
54    /// Returns error if the storage directory cannot be read.
55    fn list_models(&self) -> Result<Vec<ModelInfo>, BackendError>;
56
57    /// Check if a model exists locally.
58    fn exists(&self, model_id: &str) -> bool;
59
60    /// Get model info for a specific model.
61    ///
62    /// # Errors
63    ///
64    /// Returns error if the model is not found.
65    fn model_info(&self, model_id: &str) -> Result<ModelInfo, BackendError>;
66
67    /// Delete a model.
68    ///
69    /// # Errors
70    ///
71    /// Returns error if the model cannot be deleted.
72    fn delete(&self, model_id: &str) -> Result<(), BackendError>;
73
74    /// Get the local path for a model.
75    fn model_path(&self, model_id: &str) -> PathBuf;
76
77    /// Get the storage root directory.
78    fn storage_dir(&self) -> &Path;
79}
80
81// ============================================================================
82// Download Progress Callback
83// ============================================================================
84
85/// Type alias for download progress callbacks.
86pub type ProgressCallback = Box<dyn Fn(PullProgress) + Send + 'static>;
87
88// ============================================================================
89// Download Request
90// ============================================================================
91
92/// Request to download a model.
93#[derive(Debug, Clone)]
94pub struct DownloadRequest<'a> {
95    /// The model to download (curated).
96    pub model: Option<&'a KnownModel>,
97
98    /// HuggingFace repo (for passthrough).
99    pub hf_repo: Option<String>,
100
101    /// Specific filename to download.
102    pub filename: Option<String>,
103
104    /// Quantization level (for curated models).
105    pub quantization: Option<Quantization>,
106
107    /// Force re-download even if exists.
108    pub force: bool,
109}
110
111impl<'a> DownloadRequest<'a> {
112    /// Create a request for a curated model.
113    #[must_use]
114    pub fn curated(model: &'a KnownModel) -> Self {
115        Self {
116            model: Some(model),
117            hf_repo: None,
118            filename: None,
119            quantization: None,
120            force: false,
121        }
122    }
123
124    /// Create a request for a HuggingFace model.
125    #[must_use]
126    pub fn huggingface(repo: impl Into<String>, filename: impl Into<String>) -> Self {
127        Self {
128            model: None,
129            hf_repo: Some(repo.into()),
130            filename: Some(filename.into()),
131            quantization: None,
132            force: false,
133        }
134    }
135
136    /// Set the quantization level.
137    #[must_use]
138    pub fn with_quantization(mut self, quant: Quantization) -> Self {
139        self.quantization = Some(quant);
140        self
141    }
142
143    /// Force re-download.
144    #[must_use]
145    pub fn force(mut self) -> Self {
146        self.force = true;
147        self
148    }
149
150    /// Get the target filename for this download.
151    #[must_use]
152    pub fn target_filename(&self) -> Option<String> {
153        if let Some(filename) = &self.filename {
154            return Some(filename.clone());
155        }
156
157        if let Some(model) = self.model {
158            let quant = self.quantization.unwrap_or(Quantization::Q4_K_M);
159            return model.filename_for_quant(quant).map(String::from);
160        }
161
162        None
163    }
164}
165
166// ============================================================================
167// Download Result
168// ============================================================================
169
/// Outcome of a completed model download.
#[derive(Debug, Clone)]
pub struct DownloadResult {
    /// Where the model file now lives on disk.
    pub path: PathBuf,

    /// File size in bytes.
    pub size: u64,

    /// SHA256 checksum of the file, if one was computed.
    pub checksum: Option<String>,

    /// `true` when the file was already present and served from cache.
    pub cached: bool,
}
185
186// ============================================================================
187// Tests
188// ============================================================================
189
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_model_dir() {
        // Runs under both builds: with `dirs` the path is `<home>/.spn/models`,
        // without it the path is the relative `.spn/models`. Comparing path
        // components, rather than substrings, avoids false positives from paths
        // that merely contain "spn" or "models" somewhere.
        let dir = default_model_dir();
        assert!(
            dir.ends_with(".spn/models"),
            "unexpected model dir: {}",
            dir.display()
        );
    }

    #[test]
    fn test_download_request_curated() {
        use crate::model::find_model;

        let model = find_model("qwen3:8b").expect("qwen3:8b should be a curated model");
        let request = DownloadRequest::curated(model).with_quantization(Quantization::Q4_K_M);

        assert!(request.model.is_some());
        assert!(request.hf_repo.is_none());
        assert_eq!(request.quantization, Some(Quantization::Q4_K_M));

        let filename = request
            .target_filename()
            .expect("curated model should provide a Q4_K_M file");
        assert!(filename.contains("q4_k_m"));
    }

    #[test]
    fn test_download_request_curated_defaults_to_q4_k_m() {
        use crate::model::find_model;

        // Without an explicit quantization, target_filename falls back to Q4_K_M.
        let model = find_model("qwen3:8b").expect("qwen3:8b should be a curated model");
        let implicit = DownloadRequest::curated(model).target_filename();
        let explicit = DownloadRequest::curated(model)
            .with_quantization(Quantization::Q4_K_M)
            .target_filename();

        assert!(implicit.is_some());
        assert_eq!(implicit, explicit);
    }

    #[test]
    fn test_explicit_filename_takes_precedence() {
        use crate::model::find_model;

        let model = find_model("qwen3:8b").expect("qwen3:8b should be a curated model");
        let mut request = DownloadRequest::curated(model);
        request.filename = Some("custom.gguf".to_string());

        assert_eq!(request.target_filename().as_deref(), Some("custom.gguf"));
    }

    #[test]
    fn test_target_filename_none_without_model_or_filename() {
        let request = DownloadRequest {
            model: None,
            hf_repo: None,
            filename: None,
            quantization: None,
            force: false,
        };

        assert_eq!(request.target_filename(), None);
    }

    #[test]
    fn test_download_request_huggingface() {
        let request =
            DownloadRequest::huggingface("bartowski/Model", "model-q4_k_m.gguf").force();

        assert!(request.model.is_none());
        assert_eq!(request.hf_repo.as_deref(), Some("bartowski/Model"));
        assert!(request.force);

        let filename = request.target_filename();
        assert_eq!(filename.as_deref(), Some("model-q4_k_m.gguf"));
    }
}