// spn_core/backend.rs
1//! Backend types for model management.
2//!
3//! These types are used by spn-ollama (and future backends like llama.cpp)
4//! to provide a unified interface for local model management.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────────────────┐
10//! │  spn-core (this module)                                                    │
11//! │  ├── PullProgress       Progress updates during model download              │
12//! │  ├── ModelInfo          Information about an installed model                │
13//! │  ├── RunningModel       Currently loaded model with GPU allocation          │
14//! │  ├── GpuInfo            GPU device information                              │
15//! │  ├── LoadConfig         Configuration for loading a model                   │
16//! │  └── BackendError       Error types for backend operations                  │
17//! └─────────────────────────────────────────────────────────────────────────────┘
18//! ```
19//!
20//! # Example
21//!
22//! ```
23//! use spn_core::{LoadConfig, ModelInfo, PullProgress};
24//!
25//! // Create a load configuration
26//! let config = LoadConfig::default()
27//!     .with_gpu_layers(-1)  // Use all GPU layers
28//!     .with_context_size(4096);
29//!
30//! // Model info from backend
31//! let info = ModelInfo {
32//!     name: "llama3.2:7b".to_string(),
33//!     size: 4_000_000_000,
34//!     quantization: Some("Q4_K_M".to_string()),
35//!     parameters: Some("7B".to_string()),
36//!     digest: Some("sha256:abc123".to_string()),
37//! };
38//!
39//! assert!(info.size_gb() > 3.0);
40//! ```
41
42use std::fmt;
43
44#[cfg(feature = "serde")]
45use serde::{Deserialize, Serialize};
46
/// Progress information during model pull/download.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PullProgress {
    /// Current status message (e.g., "pulling manifest", "downloading").
    pub status: String,
    /// Bytes completed.
    pub completed: u64,
    /// Total bytes to download.
    pub total: u64,
}

impl PullProgress {
    /// Create a new progress update.
    #[must_use]
    pub fn new(status: impl Into<String>, completed: u64, total: u64) -> Self {
        Self {
            status: status.into(),
            completed,
            total,
        }
    }

    /// Get progress as a percentage (0.0 to 100.0).
    ///
    /// Returns 0.0 while `total` is unknown (zero). The result is clamped
    /// to 100.0 so a transient `completed > total` (e.g. a backend that
    /// reports an estimated total) never yields an out-of-range percentage —
    /// matching `is_complete`, which already treats overshoot as done.
    #[must_use]
    pub fn percent(&self) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            ((self.completed as f64 / self.total as f64) * 100.0).min(100.0)
        }
    }

    /// Check if download is complete.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.total > 0 && self.completed >= self.total
    }
}

impl fmt::Display for PullProgress {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {:.1}%", self.status, self.percent())
    }
}
92
/// Information about an installed model.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ModelInfo {
    /// Model name (e.g., "llama3.2:7b").
    pub name: String,
    /// Size in bytes.
    pub size: u64,
    /// Quantization level (e.g., "Q4_K_M", "Q8_0").
    pub quantization: Option<String>,
    /// Parameter count (e.g., "7B", "70B").
    pub parameters: Option<String>,
    /// Model digest/hash.
    pub digest: Option<String>,
}

impl ModelInfo {
    /// Get size in gigabytes (decimal: 1 GB = 10^9 bytes).
    #[must_use]
    pub fn size_gb(&self) -> f64 {
        self.size as f64 / 1_000_000_000.0
    }

    /// Get size as human-readable string.
    ///
    /// Uses decimal units: "X.X GB" at or above 1 GB, whole "MB" at or
    /// above 1 MB, and whole "KB" below that — so sub-megabyte sizes no
    /// longer degenerate to "0 MB".
    #[must_use]
    pub fn size_human(&self) -> String {
        let bytes = self.size as f64;
        let gb = self.size_gb();
        if gb >= 1.0 {
            format!("{gb:.1} GB")
        } else if bytes >= 1_000_000.0 {
            format!("{:.0} MB", bytes / 1_000_000.0)
        } else {
            format!("{:.0} KB", bytes / 1_000.0)
        }
    }
}

impl fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} ({})", self.name, self.size_human())
    }
}
133
/// Information about a currently running/loaded model.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RunningModel {
    /// Model name.
    pub name: String,
    /// VRAM used in bytes (if available).
    pub vram_used: Option<u64>,
    /// GPU IDs this model is loaded on.
    pub gpu_ids: Vec<u32>,
}

impl RunningModel {
    /// VRAM used in gigabytes, if the backend reported a figure.
    #[must_use]
    pub fn vram_gb(&self) -> Option<f64> {
        let bytes = self.vram_used?;
        Some(bytes as f64 / 1_000_000_000.0)
    }
}

impl fmt::Display for RunningModel {
    /// Renders as `name`, optionally followed by GPU IDs and VRAM usage.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.name)?;
        match self.gpu_ids.as_slice() {
            [] => {}
            ids => write!(f, " [GPU: {ids:?}]")?,
        }
        match self.vram_gb() {
            Some(vram) => write!(f, " ({vram:.1} GB VRAM)"),
            None => Ok(()),
        }
    }
}
166
/// GPU device information.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GpuInfo {
    /// GPU device ID.
    pub id: u32,
    /// GPU name (e.g., "NVIDIA RTX 4090").
    pub name: String,
    /// Total memory in bytes.
    pub memory_total: u64,
    /// Free memory in bytes.
    pub memory_free: u64,
}

impl GpuInfo {
    /// Bytes per decimal gigabyte, shared by the GB conversions below.
    const BYTES_PER_GB: f64 = 1_000_000_000.0;

    /// Total memory in gigabytes (decimal).
    #[must_use]
    pub fn memory_total_gb(&self) -> f64 {
        self.memory_total as f64 / Self::BYTES_PER_GB
    }

    /// Free memory in gigabytes (decimal).
    #[must_use]
    pub fn memory_free_gb(&self) -> f64 {
        self.memory_free as f64 / Self::BYTES_PER_GB
    }

    /// Percentage of total memory currently in use; 0.0 when total is unknown.
    #[must_use]
    pub fn memory_used_percent(&self) -> f64 {
        match self.memory_total {
            0 => 0.0,
            total => {
                // saturating_sub guards against free > total from a racy driver read
                let used = total.saturating_sub(self.memory_free);
                (used as f64 / total as f64) * 100.0
            }
        }
    }
}

impl fmt::Display for GpuInfo {
    /// Renders as `GPU <id>: <name> (<free>/<total> GB free)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let free = self.memory_free_gb();
        let total = self.memory_total_gb();
        write!(f, "GPU {}: {} ({free:.1}/{total:.1} GB free)", self.id, self.name)
    }
}
218
/// Error types for backend operations.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BackendError {
    /// Backend server is not running.
    NotRunning,
    /// Model not found in registry or locally.
    ModelNotFound(String),
    /// Model is already loaded.
    AlreadyLoaded(String),
    /// Insufficient GPU/system memory.
    InsufficientMemory,
    /// Network error during pull/API call.
    NetworkError(String),
    /// Process management error.
    ProcessError(String),
    /// Backend-specific error.
    BackendSpecific(String),
}

impl std::error::Error for BackendError {}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed-text variants first, then the ones carrying a payload.
        match self {
            Self::NotRunning => f.write_str("Backend server is not running"),
            Self::InsufficientMemory => f.write_str("Insufficient memory to load model"),
            Self::ModelNotFound(name) => write!(f, "Model not found: {name}"),
            Self::AlreadyLoaded(name) => write!(f, "Model already loaded: {name}"),
            Self::NetworkError(msg) => write!(f, "Network error: {msg}"),
            Self::ProcessError(msg) => write!(f, "Process error: {msg}"),
            Self::BackendSpecific(msg) => write!(f, "Backend error: {msg}"),
        }
    }
}
254
/// Configuration for loading a model.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LoadConfig {
    /// GPU IDs to use for this model (empty = auto).
    pub gpu_ids: Vec<u32>,
    /// Number of layers to offload to GPU (-1 = all, 0 = none).
    pub gpu_layers: i32,
    /// Context size (token window).
    pub context_size: Option<u32>,
    /// Keep model loaded in memory (prevent unload).
    pub keep_alive: bool,
}

impl Default for LoadConfig {
    /// Auto GPU selection, every layer offloaded, no pinned context size.
    fn default() -> Self {
        Self {
            gpu_ids: Vec::new(),
            gpu_layers: -1, // negative means "offload all layers"
            context_size: None,
            keep_alive: false,
        }
    }
}

impl LoadConfig {
    /// Create a new load configuration (same as [`LoadConfig::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set specific GPU IDs.
    #[must_use]
    pub fn with_gpus(self, gpu_ids: Vec<u32>) -> Self {
        Self { gpu_ids, ..self }
    }

    /// Set GPU layers (-1 = all, 0 = CPU only).
    #[must_use]
    pub fn with_gpu_layers(self, layers: i32) -> Self {
        Self {
            gpu_layers: layers,
            ..self
        }
    }

    /// Set context size.
    #[must_use]
    pub fn with_context_size(self, size: u32) -> Self {
        Self {
            context_size: Some(size),
            ..self
        }
    }

    /// Set keep alive.
    #[must_use]
    pub fn with_keep_alive(self, keep: bool) -> Self {
        Self {
            keep_alive: keep,
            ..self
        }
    }

    /// Check if this is a CPU-only configuration (zero offloaded layers).
    #[must_use]
    pub fn is_cpu_only(&self) -> bool {
        matches!(self.gpu_layers, 0)
    }

    /// Check if using all GPU layers (any negative layer count).
    #[must_use]
    pub fn is_full_gpu(&self) -> bool {
        self.gpu_layers.is_negative()
    }
}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor: a `ModelInfo` with only name and size set.
    fn model(name: &str, size: u64) -> ModelInfo {
        ModelInfo {
            name: name.to_string(),
            size,
            quantization: None,
            parameters: None,
            digest: None,
        }
    }

    #[test]
    fn pull_progress_percent_and_completion() {
        let halfway = PullProgress::new("downloading", 500, 1000);
        assert_eq!(halfway.percent(), 50.0);
        assert!(!halfway.is_complete());
        assert!(PullProgress::new("complete", 1000, 1000).is_complete());
    }

    #[test]
    fn pull_progress_display_format() {
        assert_eq!(
            PullProgress::new("pulling", 750, 1000).to_string(),
            "pulling: 75.0%"
        );
    }

    #[test]
    fn pull_progress_unknown_total_is_zero_percent() {
        let pending = PullProgress::new("starting", 0, 0);
        assert_eq!(pending.percent(), 0.0);
        assert!(!pending.is_complete());
    }

    #[test]
    fn model_info_size_conversions() {
        let mut info = model("llama3.2:7b", 4_500_000_000);
        info.quantization = Some("Q4_K_M".to_string());
        info.parameters = Some("7B".to_string());
        assert!((info.size_gb() - 4.5).abs() < 0.01);
        assert_eq!(info.size_human(), "4.5 GB");
    }

    #[test]
    fn model_info_display_shows_name_and_size() {
        let rendered = model("test:latest", 500_000_000).to_string();
        assert!(rendered.contains("test:latest"));
        assert!(rendered.contains("500 MB"));
    }

    #[test]
    fn running_model_vram_and_display() {
        let running = RunningModel {
            name: "llama3.2".to_string(),
            vram_used: Some(4_000_000_000),
            gpu_ids: vec![0],
        };
        assert!((running.vram_gb().unwrap() - 4.0).abs() < 0.01);
        let rendered = running.to_string();
        assert!(rendered.contains("llama3.2"));
        assert!(rendered.contains("GPU"));
    }

    #[test]
    fn gpu_info_memory_math() {
        let gpu = GpuInfo {
            id: 0,
            name: "RTX 4090".to_string(),
            memory_total: 24_000_000_000,
            memory_free: 20_000_000_000,
        };
        assert!((gpu.memory_total_gb() - 24.0).abs() < 0.01);
        assert!((gpu.memory_free_gb() - 20.0).abs() < 0.01);
        assert!((gpu.memory_used_percent() - 16.67).abs() < 0.5);
    }

    #[test]
    fn backend_error_messages() {
        assert!(BackendError::NotRunning.to_string().contains("not running"));
        let not_found = BackendError::ModelNotFound("test".to_string());
        assert!(not_found.to_string().contains("test"));
    }

    #[test]
    fn load_config_defaults_to_full_gpu() {
        let config = LoadConfig::default();
        assert!(config.gpu_ids.is_empty());
        assert_eq!(config.gpu_layers, -1);
        assert!(config.is_full_gpu());
        assert!(!config.is_cpu_only());
    }

    #[test]
    fn load_config_builder_sets_every_field() {
        let config = LoadConfig::new()
            .with_gpus(vec![0, 1])
            .with_gpu_layers(32)
            .with_context_size(8192)
            .with_keep_alive(true);
        assert_eq!(config.gpu_ids, vec![0, 1]);
        assert_eq!(config.gpu_layers, 32);
        assert_eq!(config.context_size, Some(8192));
        assert!(config.keep_alive);
        assert!(!config.is_cpu_only());
        assert!(!config.is_full_gpu());
    }

    #[test]
    fn load_config_zero_layers_means_cpu_only() {
        let config = LoadConfig::new().with_gpu_layers(0);
        assert!(config.is_cpu_only());
        assert!(!config.is_full_gpu());
    }
}