//! Model management types for CLI.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
6/// Model type classification
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
8pub enum ModelType {
9    /// Acoustic models (text-to-mel conversion)
10    Acoustic,
11    /// Vocoder models (mel-to-audio conversion)
12    Vocoder,
13    /// Grapheme-to-phoneme models
14    G2P,
15}
16
17/// Model information for CLI operations
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ModelInfo {
20    /// Unique model identifier
21    pub id: String,
22    /// Human-readable model name
23    pub name: String,
24    /// Type of model
25    pub model_type: ModelType,
26    /// Primary language supported
27    pub language: String,
28    /// Model description
29    pub description: String,
30    /// Model version
31    pub version: String,
32    /// Model size in megabytes
33    pub size_mb: f64,
34    /// Supported sample rate in Hz
35    pub sample_rate: u32,
36    /// Quality score (0.0-5.0)
37    pub quality_score: f32,
38    /// Supported inference backends
39    pub supported_backends: Vec<String>,
40    /// Whether model is installed locally
41    pub is_installed: bool,
42    /// Local installation path if installed
43    pub installation_path: Option<String>,
44    /// Additional metadata
45    pub metadata: HashMap<String, String>,
46}
47
48impl ModelInfo {
49    /// Create a new ModelInfo
50    pub fn new(
51        id: String,
52        name: String,
53        model_type: ModelType,
54        language: String,
55        description: String,
56    ) -> Self {
57        Self {
58            id,
59            name,
60            model_type,
61            language,
62            description,
63            version: "1.0.0".to_string(),
64            size_mb: 0.0,
65            sample_rate: 22050,
66            quality_score: 3.5,
67            supported_backends: vec!["pytorch".to_string()],
68            is_installed: false,
69            installation_path: None,
70            metadata: HashMap::new(),
71        }
72    }
73
74    /// Check if model supports a specific backend
75    pub fn supports_backend(&self, backend: &str) -> bool {
76        self.supported_backends.iter().any(|b| b == backend)
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience constructor used by the tests below.
    fn make_model(id: &str, name: &str, kind: ModelType) -> ModelInfo {
        ModelInfo::new(
            String::from(id),
            String::from(name),
            kind,
            String::from("en"),
            String::from(if name == "Test Model" {
                "A test model"
            } else {
                "Test"
            }),
        )
    }

    #[test]
    fn test_model_info_creation() {
        let model = make_model("test-model", "Test Model", ModelType::Acoustic);

        assert_eq!(model.id, "test-model");
        assert_eq!(model.model_type, ModelType::Acoustic);
        assert!(!model.is_installed);
    }

    #[test]
    fn test_supports_backend() {
        let mut model = make_model("test", "Test", ModelType::Vocoder);
        model.supported_backends =
            ["pytorch", "onnx"].iter().map(|s| s.to_string()).collect();

        assert!(model.supports_backend("pytorch"));
        assert!(model.supports_backend("onnx"));
        assert!(!model.supports_backend("tensorflow"));
    }
}
115}