1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Category of model in the synthesis pipeline.
///
/// `Copy + Eq` so variants can be compared and passed by value cheaply;
/// serde derives allow the type to round-trip through config/manifest files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
    /// Acoustic model (text/phonemes to acoustic features — presumed; confirm with pipeline docs).
    Acoustic,
    /// Vocoder (acoustic features to waveform — presumed; confirm with pipeline docs).
    Vocoder,
    /// G2P model (presumably grapheme-to-phoneme conversion — confirm).
    G2P,
}
16
/// Descriptive metadata for a single model known to the application.
///
/// NOTE(review): field order is load-bearing for serde struct serialization;
/// do not reorder fields without checking downstream consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Unique identifier for the model (e.g. "test-model").
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Which pipeline stage this model implements.
    pub model_type: ModelType,
    /// Language code the model targets (e.g. "en").
    pub language: String,
    /// Free-form description of the model.
    pub description: String,
    /// Version string; `new` defaults this to "1.0.0".
    pub version: String,
    /// Download/installed size in megabytes; `new` defaults to 0.0.
    pub size_mb: f64,
    /// Audio sample rate (presumably Hz); `new` defaults to 22050.
    pub sample_rate: u32,
    /// Quality rating; `new` defaults to 3.5 (scale not defined here — confirm).
    pub quality_score: f32,
    /// Backend identifiers the model can run on (e.g. "pytorch", "onnx");
    /// queried by `supports_backend`.
    pub supported_backends: Vec<String>,
    /// Whether the model is present locally; `new` defaults to false.
    pub is_installed: bool,
    /// Local filesystem path when installed, otherwise `None`.
    pub installation_path: Option<String>,
    /// Arbitrary extra key/value metadata.
    pub metadata: HashMap<String, String>,
}
47
48impl ModelInfo {
49 pub fn new(
51 id: String,
52 name: String,
53 model_type: ModelType,
54 language: String,
55 description: String,
56 ) -> Self {
57 Self {
58 id,
59 name,
60 model_type,
61 language,
62 description,
63 version: "1.0.0".to_string(),
64 size_mb: 0.0,
65 sample_rate: 22050,
66 quality_score: 3.5,
67 supported_backends: vec!["pytorch".to_string()],
68 is_installed: false,
69 installation_path: None,
70 metadata: HashMap::new(),
71 }
72 }
73
74 pub fn supports_backend(&self, backend: &str) -> bool {
76 self.supported_backends.iter().any(|b| b == backend)
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModelInfo` through the public constructor with language "en".
    fn build(id: &str, name: &str, kind: ModelType, desc: &str) -> ModelInfo {
        ModelInfo::new(
            id.to_string(),
            name.to_string(),
            kind,
            "en".to_string(),
            desc.to_string(),
        )
    }

    // Constructor fills in identity fields and leaves the model uninstalled.
    #[test]
    fn test_model_info_creation() {
        let model = build("test-model", "Test Model", ModelType::Acoustic, "A test model");

        assert_eq!(model.id, "test-model");
        assert_eq!(model.model_type, ModelType::Acoustic);
        assert!(!model.is_installed);
    }

    // Exact-match backend lookup: listed backends hit, unlisted ones miss.
    #[test]
    fn test_supports_backend() {
        let mut model = build("test", "Test", ModelType::Vocoder, "Test");
        model.supported_backends = ["pytorch", "onnx"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        assert!(model.supports_backend("pytorch"));
        assert!(model.supports_backend("onnx"));
        assert!(!model.supports_backend("tensorflow"));
    }
}