// spn_core/backend.rs
1//! Backend types for model management.
2//!
3//! These types are used by spn-ollama (and future backends like llama.cpp)
4//! to provide a unified interface for local model management.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────────────────┐
10//! │  spn-core (this module)                                                    │
11//! │  ├── PullProgress       Progress updates during model download              │
12//! │  ├── ModelInfo          Information about an installed model                │
13//! │  ├── RunningModel       Currently loaded model with GPU allocation          │
14//! │  ├── GpuInfo            GPU device information                              │
15//! │  ├── LoadConfig         Configuration for loading a model                   │
16//! │  └── BackendError       Error types for backend operations                  │
17//! └─────────────────────────────────────────────────────────────────────────────┘
18//! ```
19//!
20//! # Example
21//!
22//! ```
23//! use spn_core::{LoadConfig, ModelInfo, PullProgress};
24//!
25//! // Create a load configuration
26//! let config = LoadConfig::default()
27//!     .with_gpu_layers(-1)  // Use all GPU layers
28//!     .with_context_size(4096);
29//!
30//! // Model info from backend
31//! let info = ModelInfo {
32//!     name: "llama3.2:7b".to_string(),
33//!     size: 4_000_000_000,
34//!     quantization: Some("Q4_K_M".to_string()),
35//!     parameters: Some("7B".to_string()),
36//!     digest: Some("sha256:abc123".to_string()),
37//! };
38//!
39//! assert!(info.size_gb() > 3.0);
40//! ```
41
42use std::fmt;
43
44#[cfg(feature = "serde")]
45use serde::{Deserialize, Serialize};
46
47// ============================================================================
48// Quantization Types
49// ============================================================================
50
/// Quantization levels for GGUF models.
///
/// Quantization trades quality for size and memory: fewer bits per weight
/// (Q4) means a smaller, faster, slightly less accurate model, while more
/// bits (up to F16/F32 full precision) means a larger, slower, more
/// accurate one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[allow(non_camel_case_types)]
pub enum Quantization {
    /// 4-bit quantization, small variant (smallest, fastest).
    Q4_K_S,
    /// 4-bit quantization, medium variant (recommended for most use cases).
    Q4_K_M,
    /// 5-bit quantization, small variant.
    Q5_K_S,
    /// 5-bit quantization, medium variant (balanced quality/size).
    Q5_K_M,
    /// 6-bit quantization.
    Q6_K,
    /// 8-bit quantization (high quality).
    Q8_0,
    /// 16-bit floating point (full precision).
    F16,
    /// 32-bit floating point (maximum precision, rarely used).
    F32,
}

impl Quantization {
    /// Bare level identifier without any descriptive suffix.
    #[must_use]
    pub const fn short_name(&self) -> &'static str {
        match self {
            Self::Q4_K_S => "Q4_K_S",
            Self::Q4_K_M => "Q4_K_M",
            Self::Q5_K_S => "Q5_K_S",
            Self::Q5_K_M => "Q5_K_M",
            Self::Q6_K => "Q6_K",
            Self::Q8_0 => "Q8_0",
            Self::F16 => "F16",
            Self::F32 => "F32",
        }
    }

    /// Human-readable name, annotated with a usage hint where one exists.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Q4_K_S => "Q4_K_S (smallest)",
            Self::Q4_K_M => "Q4_K_M (recommended)",
            Self::Q5_K_S => "Q5_K_S",
            Self::Q5_K_M => "Q5_K_M (balanced)",
            Self::Q6_K => "Q6_K",
            Self::Q8_0 => "Q8_0 (high quality)",
            Self::F16 => "F16 (full precision)",
            Self::F32 => "F32 (maximum)",
        }
    }

    /// Approximate memory cost in bytes per parameter.
    ///
    /// Estimate model memory as:
    /// `memory_gb = param_billions * multiplier`
    #[must_use]
    pub const fn memory_multiplier(&self) -> f32 {
        match self {
            Self::Q4_K_S => 0.45,
            Self::Q4_K_M => 0.50,
            Self::Q5_K_S => 0.55,
            Self::Q5_K_M => 0.60,
            Self::Q6_K => 0.70,
            Self::Q8_0 => 1.00,
            Self::F16 => 2.00,
            Self::F32 => 4.00,
        }
    }

    /// Every quantization level, ordered from smallest to largest footprint.
    #[must_use]
    pub const fn all() -> &'static [Quantization] {
        // A named const keeps the table in one place and makes the
        // 'static borrow explicit.
        const ALL: [Quantization; 8] = [
            Quantization::Q4_K_S,
            Quantization::Q4_K_M,
            Quantization::Q5_K_S,
            Quantization::Q5_K_M,
            Quantization::Q6_K,
            Quantization::Q8_0,
            Quantization::F16,
            Quantization::F32,
        ];
        &ALL
    }
}

impl fmt::Display for Quantization {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Display uses the bare identifier; `name()` carries the hint text.
        f.write_str(self.short_name())
    }
}

impl Default for Quantization {
    /// `Q4_K_M` provides the best balance of size and quality.
    fn default() -> Self {
        Self::Q4_K_M
    }
}
155
/// Progress information during model pull/download.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PullProgress {
    /// Current status message (e.g., "pulling manifest", "downloading").
    pub status: String,
    /// Bytes completed.
    pub completed: u64,
    /// Total bytes to download.
    pub total: u64,
}

impl PullProgress {
    /// Build a progress update from a status label and byte counts.
    #[must_use]
    pub fn new(status: impl Into<String>, completed: u64, total: u64) -> Self {
        Self {
            status: status.into(),
            completed,
            total,
        }
    }

    /// Progress as a percentage; 0.0 when the total is still unknown.
    #[must_use]
    pub fn percent(&self) -> f64 {
        match self.total {
            // Avoid dividing by zero before the manifest reports a size.
            0 => 0.0,
            total => (self.completed as f64 / total as f64) * 100.0,
        }
    }

    /// True once every byte is downloaded; never true for a zero total.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        self.completed >= self.total && self.total > 0
    }
}

impl fmt::Display for PullProgress {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pct = self.percent();
        write!(f, "{}: {pct:.1}%", self.status)
    }
}
201
/// Information about an installed model.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ModelInfo {
    /// Model name (e.g., "llama3.2:7b").
    pub name: String,
    /// Size in bytes.
    pub size: u64,
    /// Quantization level (e.g., "Q4_K_M", "Q8_0").
    pub quantization: Option<String>,
    /// Parameter count (e.g., "7B", "70B").
    pub parameters: Option<String>,
    /// Model digest/hash.
    pub digest: Option<String>,
}

impl ModelInfo {
    /// Get size in decimal gigabytes (1 GB = 10^9 bytes).
    #[must_use]
    pub fn size_gb(&self) -> f64 {
        self.size as f64 / 1_000_000_000.0
    }

    /// Get size as a human-readable string using decimal units.
    ///
    /// Picks the largest unit that yields a value >= 1: GB, then MB,
    /// then KB. The KB tier avoids rendering sub-megabyte sizes as
    /// "0 MB" (500 KB formats as "500 KB").
    #[must_use]
    pub fn size_human(&self) -> String {
        let bytes = self.size as f64;
        if bytes >= 1_000_000_000.0 {
            format!("{:.1} GB", bytes / 1_000_000_000.0)
        } else if bytes >= 1_000_000.0 {
            format!("{:.0} MB", bytes / 1_000_000.0)
        } else {
            format!("{:.0} KB", bytes / 1_000.0)
        }
    }
}

impl fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} ({})", self.name, self.size_human())
    }
}
242
/// Information about a currently running/loaded model.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RunningModel {
    /// Model name.
    pub name: String,
    /// VRAM used in bytes (if available).
    pub vram_used: Option<u64>,
    /// GPU IDs this model is loaded on.
    pub gpu_ids: Vec<u32>,
}

impl RunningModel {
    /// VRAM usage converted to decimal gigabytes, when reported.
    #[must_use]
    pub fn vram_gb(&self) -> Option<f64> {
        self.vram_used.map(|bytes| bytes as f64 / 1_000_000_000.0)
    }
}

impl fmt::Display for RunningModel {
    /// Renders as `name [GPU: [ids]] (N.N GB VRAM)`, omitting the GPU and
    /// VRAM sections when that information is absent.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.name)?;
        if !self.gpu_ids.is_empty() {
            write!(f, " [GPU: {:?}]", self.gpu_ids)?;
        }
        if let Some(vram) = self.vram_gb() {
            write!(f, " ({vram:.1} GB VRAM)")?;
        }
        Ok(())
    }
}
275
/// GPU device information.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GpuInfo {
    /// GPU device ID.
    pub id: u32,
    /// GPU name (e.g., "NVIDIA RTX 4090").
    pub name: String,
    /// Total memory in bytes.
    pub memory_total: u64,
    /// Free memory in bytes.
    pub memory_free: u64,
}

impl GpuInfo {
    /// Total memory in decimal gigabytes.
    #[must_use]
    pub fn memory_total_gb(&self) -> f64 {
        self.memory_total as f64 / 1_000_000_000.0
    }

    /// Free memory in decimal gigabytes.
    #[must_use]
    pub fn memory_free_gb(&self) -> f64 {
        self.memory_free as f64 / 1_000_000_000.0
    }

    /// Fraction of memory in use, as a percentage (0.0 when total is 0).
    #[must_use]
    pub fn memory_used_percent(&self) -> f64 {
        if self.memory_total == 0 {
            // Unknown capacity — report nothing used rather than divide by zero.
            return 0.0;
        }
        // saturating_sub guards against free > total from a racy driver read.
        let used = self.memory_total.saturating_sub(self.memory_free);
        (used as f64 / self.memory_total as f64) * 100.0
    }
}

impl fmt::Display for GpuInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let free = self.memory_free_gb();
        let total = self.memory_total_gb();
        write!(f, "GPU {}: {} ({free:.1}/{total:.1} GB free)", self.id, self.name)
    }
}
327
/// Error types for backend operations.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BackendError {
    /// Backend server is not running.
    NotRunning,
    /// Model not found in registry or locally.
    ModelNotFound(String),
    /// Model is already loaded.
    AlreadyLoaded(String),
    /// Insufficient GPU/system memory.
    InsufficientMemory,
    /// Network error during pull/API call.
    NetworkError(String),
    /// Process management error.
    ProcessError(String),
    /// Backend-specific error.
    BackendSpecific(String),
    /// Missing API key for cloud provider.
    MissingApiKey(String),
    /// API returned an error response.
    ApiError {
        /// HTTP status code.
        status: u16,
        /// Error message from API.
        message: String,
    },
    /// Failed to parse API response.
    ParseError(String),
    /// Model loading failed.
    LoadError(String),
    /// Inference failed.
    InferenceError(String),
    /// Invalid model configuration.
    InvalidConfig(String),
    /// Storage/filesystem error.
    StorageError(String),
    /// Download failed.
    DownloadError(String),
    /// Checksum verification failed.
    ChecksumError {
        /// Expected checksum.
        expected: String,
        /// Actual checksum.
        actual: String,
    },
}

impl std::error::Error for BackendError {}

impl fmt::Display for BackendError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotRunning => write!(f, "Backend server is not running"),
            Self::ModelNotFound(name) => write!(f, "Model not found: {name}"),
            Self::AlreadyLoaded(name) => write!(f, "Model already loaded: {name}"),
            Self::InsufficientMemory => write!(f, "Insufficient memory to load model"),
            Self::NetworkError(msg) => write!(f, "Network error: {msg}"),
            Self::ProcessError(msg) => write!(f, "Process error: {msg}"),
            Self::BackendSpecific(msg) => write!(f, "Backend error: {msg}"),
            Self::MissingApiKey(provider) => {
                write!(f, "Missing API key for provider: {provider}")
            }
            Self::ApiError { status, message } => {
                write!(f, "API error (HTTP {status}): {message}")
            }
            Self::ParseError(msg) => write!(f, "Parse error: {msg}"),
            Self::LoadError(msg) => write!(f, "Model load error: {msg}"),
            Self::InferenceError(msg) => write!(f, "Inference error: {msg}"),
            Self::InvalidConfig(msg) => write!(f, "Invalid configuration: {msg}"),
            Self::StorageError(msg) => write!(f, "Storage error: {msg}"),
            Self::DownloadError(msg) => write!(f, "Download error: {msg}"),
            Self::ChecksumError { expected, actual } => {
                write!(f, "Checksum mismatch: expected {expected}, got {actual}")
            }
        }
    }
}

impl BackendError {
    /// Returns `true` if this error is transient and the operation should be retried.
    ///
    /// Retryable errors include network failures, temporary backend
    /// unavailability, failed downloads, and API responses that signal
    /// rate limiting (HTTP 429) or a server-side fault (HTTP 5xx).
    /// Permanent failures — model not found, insufficient memory,
    /// client-side API errors — are not retryable.
    #[must_use]
    pub const fn is_retryable(&self) -> bool {
        match self {
            Self::NetworkError(_) | Self::NotRunning | Self::DownloadError(_) => true,
            // Rate limiting and server-side failures are transient conditions.
            Self::ApiError { status, .. } => *status == 429 || *status >= 500,
            _ => false,
        }
    }

    /// Returns `true` if this is an authentication/authorization error.
    #[must_use]
    pub fn is_auth_error(&self) -> bool {
        match self {
            Self::MissingApiKey(_) => true,
            Self::ApiError { status, .. } => *status == 401 || *status == 403,
            _ => false,
        }
    }
}
433
/// Configuration for loading a model.
///
/// `Default` selects GPUs automatically and offloads all layers
/// (`gpu_layers = -1`); refine a configuration with the `with_*`
/// builder methods on [`LoadConfig`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct LoadConfig {
    /// GPU IDs to use for this model (empty = auto).
    pub gpu_ids: Vec<u32>,
    /// Number of layers to offload to GPU (-1 = all, 0 = none).
    pub gpu_layers: i32,
    /// Context size (token window).
    pub context_size: Option<u32>,
    /// Keep model loaded in memory (prevent unload).
    pub keep_alive: bool,
}
447
448// ============================================================================
449// Chat Types
450// ============================================================================
451
/// Role in a chat conversation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum ChatRole {
    /// System message (instructions).
    System,
    /// User message.
    User,
    /// Assistant response.
    Assistant,
}

impl fmt::Display for ChatRole {
    /// Writes the lowercase wire-format name, matching the serde renaming.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
        })
    }
}
474
475/// A message in a chat conversation.
476#[derive(Debug, Clone, PartialEq, Eq)]
477#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
478pub struct ChatMessage {
479    /// Role of the message sender.
480    pub role: ChatRole,
481    /// Content of the message.
482    pub content: String,
483}
484
485impl ChatMessage {
486    /// Create a new system message.
487    #[must_use]
488    pub fn system(content: impl Into<String>) -> Self {
489        Self {
490            role: ChatRole::System,
491            content: content.into(),
492        }
493    }
494
495    /// Create a new user message.
496    #[must_use]
497    pub fn user(content: impl Into<String>) -> Self {
498        Self {
499            role: ChatRole::User,
500            content: content.into(),
501        }
502    }
503
504    /// Create a new assistant message.
505    #[must_use]
506    pub fn assistant(content: impl Into<String>) -> Self {
507        Self {
508            role: ChatRole::Assistant,
509            content: content.into(),
510        }
511    }
512}
513
/// Options for chat completion.
#[derive(Debug, Clone, PartialEq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ChatOptions {
    /// Temperature for sampling (0.0 to 2.0).
    pub temperature: Option<f32>,
    /// Top-p (nucleus) sampling.
    pub top_p: Option<f32>,
    /// Top-k sampling.
    pub top_k: Option<u32>,
    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,
    /// Stop sequences.
    pub stop: Vec<String>,
    /// Seed for reproducibility.
    pub seed: Option<u64>,
}

impl ChatOptions {
    /// Create new chat options with every field unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set temperature.
    #[must_use]
    pub fn with_temperature(self, temp: f32) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Set top-p sampling.
    #[must_use]
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Set top-k sampling.
    #[must_use]
    pub fn with_top_k(self, top_k: u32) -> Self {
        Self {
            top_k: Some(top_k),
            ..self
        }
    }

    /// Set maximum tokens.
    #[must_use]
    pub fn with_max_tokens(self, max: u32) -> Self {
        Self {
            max_tokens: Some(max),
            ..self
        }
    }

    /// Add a stop sequence (appends; earlier sequences are kept).
    #[must_use]
    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
        self.stop.push(stop.into());
        self
    }

    /// Set seed for reproducibility.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }
}
581
582/// Response from a chat completion.
583#[derive(Debug, Clone, PartialEq)]
584#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
585pub struct ChatResponse {
586    /// The assistant's response message.
587    pub message: ChatMessage,
588    /// Whether the response is complete (not streaming).
589    pub done: bool,
590    /// Total duration in nanoseconds.
591    pub total_duration: Option<u64>,
592    /// Tokens generated.
593    pub eval_count: Option<u32>,
594    /// Prompt tokens.
595    pub prompt_eval_count: Option<u32>,
596}
597
598impl ChatResponse {
599    /// Get the response content.
600    #[must_use]
601    pub fn content(&self) -> &str {
602        &self.message.content
603    }
604
605    /// Get tokens per second (if metrics available).
606    #[must_use]
607    pub fn tokens_per_second(&self) -> Option<f64> {
608        match (self.eval_count, self.total_duration) {
609            (Some(count), Some(duration)) if duration > 0 => {
610                Some(count as f64 / (duration as f64 / 1_000_000_000.0))
611            }
612            _ => None,
613        }
614    }
615}
616
617// ============================================================================
618// Embedding Types
619// ============================================================================
620
/// Response from an embedding request.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EmbeddingResponse {
    /// The embedding vector.
    pub embedding: Vec<f32>,
    /// Total duration in nanoseconds.
    pub total_duration: Option<u64>,
    /// Number of tokens in the input.
    pub prompt_eval_count: Option<u32>,
}

impl EmbeddingResponse {
    /// Number of components in the embedding vector.
    #[must_use]
    pub fn dimension(&self) -> usize {
        self.embedding.len()
    }

    /// Cosine similarity with another embedding.
    ///
    /// Returns 0.0 when the dimensions differ or either vector has zero
    /// magnitude (similarity is undefined in both cases).
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        if self.embedding.len() != other.embedding.len() {
            return 0.0;
        }

        let dot: f32 = self
            .embedding
            .iter()
            .zip(&other.embedding)
            .map(|(a, b)| a * b)
            .sum();
        let norm_a = self.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b = other.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        dot / (norm_a * norm_b)
    }
}
664
665impl Default for LoadConfig {
666    fn default() -> Self {
667        Self {
668            gpu_ids: Vec::new(),
669            gpu_layers: -1, // All layers on GPU by default
670            context_size: None,
671            keep_alive: false,
672        }
673    }
674}
675
676impl LoadConfig {
677    /// Create a new load configuration.
678    #[must_use]
679    pub fn new() -> Self {
680        Self::default()
681    }
682
683    /// Set specific GPU IDs.
684    #[must_use]
685    pub fn with_gpus(mut self, gpu_ids: Vec<u32>) -> Self {
686        self.gpu_ids = gpu_ids;
687        self
688    }
689
690    /// Set GPU layers (-1 = all, 0 = CPU only).
691    #[must_use]
692    pub fn with_gpu_layers(mut self, layers: i32) -> Self {
693        self.gpu_layers = layers;
694        self
695    }
696
697    /// Set context size.
698    #[must_use]
699    pub fn with_context_size(mut self, size: u32) -> Self {
700        self.context_size = Some(size);
701        self
702    }
703
704    /// Set keep alive.
705    #[must_use]
706    pub fn with_keep_alive(mut self, keep: bool) -> Self {
707        self.keep_alive = keep;
708        self
709    }
710
711    /// Check if this is a CPU-only configuration.
712    #[must_use]
713    pub fn is_cpu_only(&self) -> bool {
714        self.gpu_layers == 0
715    }
716
717    /// Check if using all GPU layers.
718    #[must_use]
719    pub fn is_full_gpu(&self) -> bool {
720        self.gpu_layers < 0
721    }
722}
723
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pull_progress() {
        let progress = PullProgress::new("downloading", 500, 1000);
        assert_eq!(progress.percent(), 50.0);
        assert!(!progress.is_complete());

        let complete = PullProgress::new("complete", 1000, 1000);
        assert!(complete.is_complete());
    }

    #[test]
    fn test_pull_progress_display() {
        let progress = PullProgress::new("pulling", 750, 1000);
        assert_eq!(progress.to_string(), "pulling: 75.0%");
    }

    #[test]
    fn test_pull_progress_zero_total() {
        let progress = PullProgress::new("starting", 0, 0);
        assert_eq!(progress.percent(), 0.0);
        assert!(!progress.is_complete());
    }

    #[test]
    fn test_model_info_size() {
        let info = ModelInfo {
            name: "llama3.2:7b".to_string(),
            size: 4_500_000_000,
            quantization: Some("Q4_K_M".to_string()),
            parameters: Some("7B".to_string()),
            digest: None,
        };

        assert!((info.size_gb() - 4.5).abs() < 0.01);
        assert_eq!(info.size_human(), "4.5 GB");
    }

    #[test]
    fn test_model_info_display() {
        let info = ModelInfo {
            name: "test:latest".to_string(),
            size: 500_000_000,
            quantization: None,
            parameters: None,
            digest: None,
        };

        assert!(info.to_string().contains("test:latest"));
        assert!(info.to_string().contains("500 MB"));
    }

    #[test]
    fn test_running_model() {
        let model = RunningModel {
            name: "llama3.2".to_string(),
            vram_used: Some(4_000_000_000),
            gpu_ids: vec![0],
        };

        assert!((model.vram_gb().unwrap() - 4.0).abs() < 0.01);
        assert!(model.to_string().contains("llama3.2"));
        assert!(model.to_string().contains("GPU"));
    }

    #[test]
    fn test_gpu_info() {
        let gpu = GpuInfo {
            id: 0,
            name: "RTX 4090".to_string(),
            memory_total: 24_000_000_000,
            memory_free: 20_000_000_000,
        };

        assert!((gpu.memory_total_gb() - 24.0).abs() < 0.01);
        assert!((gpu.memory_free_gb() - 20.0).abs() < 0.01);
        assert!((gpu.memory_used_percent() - 16.67).abs() < 0.5);
    }

    // Added: a zero-reported total must not divide by zero.
    #[test]
    fn test_gpu_info_zero_total_memory() {
        let gpu = GpuInfo {
            id: 1,
            name: "integrated".to_string(),
            memory_total: 0,
            memory_free: 0,
        };
        assert_eq!(gpu.memory_used_percent(), 0.0);
    }

    #[test]
    fn test_backend_error_display() {
        let err = BackendError::NotRunning;
        assert!(err.to_string().contains("not running"));

        let err = BackendError::ModelNotFound("test".to_string());
        assert!(err.to_string().contains("test"));
    }

    #[test]
    fn test_load_config_default() {
        let config = LoadConfig::default();
        assert!(config.gpu_ids.is_empty());
        assert_eq!(config.gpu_layers, -1);
        assert!(config.is_full_gpu());
        assert!(!config.is_cpu_only());
    }

    #[test]
    fn test_load_config_builder() {
        let config = LoadConfig::new()
            .with_gpus(vec![0, 1])
            .with_gpu_layers(32)
            .with_context_size(8192)
            .with_keep_alive(true);

        assert_eq!(config.gpu_ids, vec![0, 1]);
        assert_eq!(config.gpu_layers, 32);
        assert_eq!(config.context_size, Some(8192));
        assert!(config.keep_alive);
        assert!(!config.is_cpu_only());
        assert!(!config.is_full_gpu());
    }

    #[test]
    fn test_load_config_cpu_only() {
        let config = LoadConfig::new().with_gpu_layers(0);
        assert!(config.is_cpu_only());
        assert!(!config.is_full_gpu());
    }

    #[test]
    fn test_chat_role_display() {
        assert_eq!(ChatRole::System.to_string(), "system");
        assert_eq!(ChatRole::User.to_string(), "user");
        assert_eq!(ChatRole::Assistant.to_string(), "assistant");
    }

    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("You are helpful");
        assert_eq!(system.role, ChatRole::System);
        assert_eq!(system.content, "You are helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, ChatRole::User);

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_options_builder() {
        let options = ChatOptions::new()
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_top_k(40)
            .with_max_tokens(100)
            .with_stop("END")
            .with_seed(42);

        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.top_p, Some(0.9));
        assert_eq!(options.top_k, Some(40));
        assert_eq!(options.max_tokens, Some(100));
        assert_eq!(options.stop, vec!["END"]);
        assert_eq!(options.seed, Some(42));
    }

    #[test]
    fn test_chat_response_content() {
        let response = ChatResponse {
            message: ChatMessage::assistant("Hello!"),
            done: true,
            total_duration: Some(1_000_000_000),
            eval_count: Some(10),
            prompt_eval_count: Some(5),
        };

        assert_eq!(response.content(), "Hello!");
        assert!(response.done);
    }

    #[test]
    fn test_chat_response_tokens_per_second() {
        let response = ChatResponse {
            message: ChatMessage::assistant("Test"),
            done: true,
            total_duration: Some(2_000_000_000), // 2 seconds
            eval_count: Some(100),
            prompt_eval_count: None,
        };

        let tps = response.tokens_per_second().unwrap();
        assert!((tps - 50.0).abs() < 0.1);
    }

    #[test]
    fn test_embedding_response_dimension() {
        let response = EmbeddingResponse {
            embedding: vec![0.1, 0.2, 0.3, 0.4],
            total_duration: None,
            prompt_eval_count: None,
        };

        assert_eq!(response.dimension(), 4);
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let a = EmbeddingResponse {
            embedding: vec![1.0, 0.0, 0.0],
            total_duration: None,
            prompt_eval_count: None,
        };

        let b = EmbeddingResponse {
            embedding: vec![1.0, 0.0, 0.0],
            total_duration: None,
            prompt_eval_count: None,
        };

        // Identical vectors should have similarity of 1.0
        assert!((a.cosine_similarity(&b) - 1.0).abs() < 0.001);

        let c = EmbeddingResponse {
            embedding: vec![0.0, 1.0, 0.0],
            total_duration: None,
            prompt_eval_count: None,
        };

        // Orthogonal vectors should have similarity of 0.0
        assert!((a.cosine_similarity(&c)).abs() < 0.001);
    }

    #[test]
    fn test_embedding_cosine_similarity_different_dimensions() {
        let a = EmbeddingResponse {
            embedding: vec![1.0, 0.0],
            total_duration: None,
            prompt_eval_count: None,
        };

        let b = EmbeddingResponse {
            embedding: vec![1.0, 0.0, 0.0],
            total_duration: None,
            prompt_eval_count: None,
        };

        // Different dimensions should return 0.0
        assert_eq!(a.cosine_similarity(&b), 0.0);
    }

    #[test]
    fn test_backend_error_is_retryable() {
        // Retryable errors (transient failures)
        assert!(BackendError::NetworkError("timeout".to_string()).is_retryable());
        assert!(BackendError::NotRunning.is_retryable());

        // Non-retryable errors (permanent failures)
        assert!(!BackendError::ModelNotFound("model".to_string()).is_retryable());
        assert!(!BackendError::AlreadyLoaded("model".to_string()).is_retryable());
        assert!(!BackendError::InsufficientMemory.is_retryable());
        assert!(!BackendError::ProcessError("error".to_string()).is_retryable());
        assert!(!BackendError::BackendSpecific("error".to_string()).is_retryable());
    }

    // Added: is_auth_error had no coverage at all.
    #[test]
    fn test_backend_error_is_auth_error() {
        assert!(BackendError::MissingApiKey("openai".to_string()).is_auth_error());
        assert!(BackendError::ApiError { status: 401, message: String::new() }.is_auth_error());
        assert!(BackendError::ApiError { status: 403, message: String::new() }.is_auth_error());
        assert!(!BackendError::ApiError { status: 500, message: String::new() }.is_auth_error());
        assert!(!BackendError::NotRunning.is_auth_error());
    }

    #[test]
    fn test_quantization_default() {
        let quant = Quantization::default();
        assert_eq!(quant, Quantization::Q4_K_M);
    }

    // Added: `all()` documents smallest-to-largest ordering; pin it.
    #[test]
    fn test_quantization_all_ordered_by_size() {
        let all = Quantization::all();
        assert_eq!(all.len(), 8);
        for pair in all.windows(2) {
            assert!(pair[0].memory_multiplier() <= pair[1].memory_multiplier());
        }
    }
}