// voirs_cli/plugins/voices.rs

1use super::{Plugin, PluginError, PluginResult, PluginType};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use voirs_sdk::types::{AgeRange, Gender, QualityLevel, SpeakingStyle};
6use voirs_sdk::voice::VoiceInfo;
7use voirs_sdk::VoiceCharacteristics;
8
/// Tunable synthesis settings for a single voice.
///
/// A reference to this struct is what [`VoicePlugin::synthesize`] and
/// [`VoicePlugin::estimate_duration`] receive. All fields are public and the
/// type derives serde's `Serialize` and `Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoicePluginConfig {
    /// Voice identifier; the `Default` impl uses `"default"`.
    pub voice_id: String,
    /// Language label as free-form text; the `Default` impl uses `"en-US"`.
    pub language: String,
    /// Gender label as free-form text; the `Default` impl uses `"neutral"`.
    pub gender: String,
    /// Speaking-style label as free-form text; the `Default` impl uses `"standard"`.
    pub style: String,
    /// Speaking-rate factor; `1.0` in the `Default` impl.
    pub speed_multiplier: f32,
    /// Pitch offset; `0.0` in the `Default` impl. `DefaultVoicePlugin` treats
    /// the value as semitones (it scales its base frequency by
    /// `2^(pitch_shift / 12)`).
    pub pitch_shift: f32,
    /// Gain offset; `0.0` in the `Default` impl. `DefaultVoicePlugin` treats
    /// the value as decibels (it scales amplitude by `10^(volume_gain / 20)`).
    pub volume_gain: f32,
    /// Optional emotion label; `None` in the `Default` impl.
    pub emotion: Option<String>,
    /// Extra settings as JSON values keyed by name. Nothing in this file reads
    /// them; presumably they are meant for plugin-specific options.
    pub custom_parameters: HashMap<String, serde_json::Value>,
}
21
22impl Default for VoicePluginConfig {
23    fn default() -> Self {
24        Self {
25            voice_id: "default".to_string(),
26            language: "en-US".to_string(),
27            gender: "neutral".to_string(),
28            style: "standard".to_string(),
29            speed_multiplier: 1.0,
30            pitch_shift: 0.0,
31            volume_gain: 0.0,
32            emotion: None,
33            custom_parameters: HashMap::new(),
34        }
35    }
36}
37
/// Interface every voice plugin exposes on top of the generic [`Plugin`]
/// contract: synthesis, self-description, and text pre-checks.
pub trait VoicePlugin: Plugin {
    /// Renders `text` to a buffer of `f32` audio samples using `config`.
    ///
    /// Presumably the samples are mono at [`get_sample_rate`](Self::get_sample_rate);
    /// the trait itself does not state the channel layout.
    fn synthesize(&self, text: &str, config: &VoicePluginConfig) -> PluginResult<Vec<f32>>;
    /// Returns the SDK-level description of this voice.
    fn get_voice_info(&self) -> VoiceInfo;
    /// Returns the languages this voice accepts, as strings.
    /// `VoicePluginManager::search_voices` matches them by exact equality.
    fn get_supported_languages(&self) -> Vec<String>;
    /// Returns the voice's characteristics as an SDK `VoiceCharacteristics` value.
    fn get_voice_characteristics(&self) -> VoiceCharacteristics;
    /// Reports whether the plugin accepts SSML input.
    fn supports_ssml(&self) -> bool;
    /// Reports whether the plugin can render emotions.
    fn supports_emotions(&self) -> bool;
    /// Returns the emotion labels the plugin supports.
    fn get_supported_emotions(&self) -> Vec<String>;
    /// Returns the output sample rate (presumably in Hz; the built-in plugin
    /// returns `22050`).
    fn get_sample_rate(&self) -> u32;
    /// Returns the names of the quality levels the plugin offers.
    fn get_quality_levels(&self) -> Vec<String>;
    /// Checks whether `text` is acceptable input, returning an error if not.
    fn validate_text(&self, text: &str) -> PluginResult<()>;
    /// Estimates how long the audio for `text` would be under `config`
    /// (presumably in seconds, as in the built-in plugin).
    fn estimate_duration(&self, text: &str, config: &VoicePluginConfig) -> PluginResult<f32>;
}
51
/// Registry of voice plugins together with a stored configuration per voice
/// and an optional "active" voice selection.
pub struct VoicePluginManager {
    /// Registered plugins, keyed by voice id.
    voices: HashMap<String, Arc<dyn VoicePlugin>>,
    /// Stored configuration for each voice, keyed by the same voice id.
    voice_configs: HashMap<String, VoicePluginConfig>,
    /// Id of the currently active voice, if one has been selected.
    active_voice: Option<String>,
}
57
58impl VoicePluginManager {
59    pub fn new() -> Self {
60        Self {
61            voices: HashMap::new(),
62            voice_configs: HashMap::new(),
63            active_voice: None,
64        }
65    }
66
67    pub fn register_voice(&mut self, voice_id: String, voice: Arc<dyn VoicePlugin>) {
68        self.voices.insert(voice_id.clone(), voice);
69        self.voice_configs
70            .insert(voice_id, VoicePluginConfig::default());
71    }
72
73    pub fn unregister_voice(&mut self, voice_id: &str) -> bool {
74        let removed = self.voices.remove(voice_id).is_some();
75        self.voice_configs.remove(voice_id);
76
77        if self.active_voice.as_ref() == Some(&voice_id.to_string()) {
78            self.active_voice = None;
79        }
80
81        removed
82    }
83
84    pub fn list_voices(&self) -> Vec<String> {
85        self.voices.keys().cloned().collect()
86    }
87
88    pub fn get_voice(&self, voice_id: &str) -> Option<Arc<dyn VoicePlugin>> {
89        self.voices.get(voice_id).cloned()
90    }
91
92    pub fn set_active_voice(&mut self, voice_id: &str) -> PluginResult<()> {
93        if self.voices.contains_key(voice_id) {
94            self.active_voice = Some(voice_id.to_string());
95            Ok(())
96        } else {
97            Err(PluginError::NotFound(voice_id.to_string()))
98        }
99    }
100
101    pub fn get_active_voice(&self) -> Option<&String> {
102        self.active_voice.as_ref()
103    }
104
105    pub fn synthesize_with_voice(
106        &self,
107        voice_id: &str,
108        text: &str,
109        config: Option<&VoicePluginConfig>,
110    ) -> PluginResult<Vec<f32>> {
111        let voice = self
112            .voices
113            .get(voice_id)
114            .ok_or_else(|| PluginError::NotFound(voice_id.to_string()))?;
115
116        let default_config = VoicePluginConfig::default();
117        let config = match config {
118            Some(c) => c,
119            None => self.voice_configs.get(voice_id).unwrap_or(&default_config),
120        };
121
122        voice.synthesize(text, config)
123    }
124
125    pub fn synthesize_with_active_voice(
126        &self,
127        text: &str,
128        config: Option<&VoicePluginConfig>,
129    ) -> PluginResult<Vec<f32>> {
130        let voice_id = self
131            .active_voice
132            .as_ref()
133            .ok_or_else(|| PluginError::ExecutionFailed("No active voice set".to_string()))?;
134
135        self.synthesize_with_voice(voice_id, text, config)
136    }
137
138    pub fn update_voice_config(
139        &mut self,
140        voice_id: &str,
141        config: VoicePluginConfig,
142    ) -> PluginResult<()> {
143        if self.voices.contains_key(voice_id) {
144            self.voice_configs.insert(voice_id.to_string(), config);
145            Ok(())
146        } else {
147            Err(PluginError::NotFound(voice_id.to_string()))
148        }
149    }
150
151    pub fn get_voice_config(&self, voice_id: &str) -> Option<&VoicePluginConfig> {
152        self.voice_configs.get(voice_id)
153    }
154
155    pub fn get_voice_info(&self, voice_id: &str) -> PluginResult<VoiceInfo> {
156        let voice = self
157            .voices
158            .get(voice_id)
159            .ok_or_else(|| PluginError::NotFound(voice_id.to_string()))?;
160
161        Ok(voice.get_voice_info())
162    }
163
164    pub fn search_voices(
165        &self,
166        language: Option<&str>,
167        gender: Option<&str>,
168        style: Option<&str>,
169    ) -> Vec<String> {
170        self.voices
171            .iter()
172            .filter(|(voice_id, voice)| {
173                let characteristics = voice.get_voice_characteristics();
174                let voice_info = voice.get_voice_info();
175
176                if let Some(lang) = language {
177                    if !voice.get_supported_languages().contains(&lang.to_string()) {
178                        return false;
179                    }
180                }
181
182                if let Some(g) = gender {
183                    if let Some(voice_gender) = &voice_info.config.characteristics.gender {
184                        if voice_gender.to_string().to_lowercase() != g.to_lowercase() {
185                            return false;
186                        }
187                    }
188                }
189
190                if let Some(s) = style {
191                    if voice_info
192                        .config
193                        .characteristics
194                        .style
195                        .to_string()
196                        .to_lowercase()
197                        != s.to_lowercase()
198                    {
199                        return false;
200                    }
201                }
202
203                true
204            })
205            .map(|(voice_id, _)| voice_id.clone())
206            .collect()
207    }
208
209    pub fn validate_text_for_voice(&self, voice_id: &str, text: &str) -> PluginResult<()> {
210        let voice = self
211            .voices
212            .get(voice_id)
213            .ok_or_else(|| PluginError::NotFound(voice_id.to_string()))?;
214
215        voice.validate_text(text)
216    }
217
218    pub fn estimate_synthesis_duration(
219        &self,
220        voice_id: &str,
221        text: &str,
222        config: Option<&VoicePluginConfig>,
223    ) -> PluginResult<f32> {
224        let voice = self
225            .voices
226            .get(voice_id)
227            .ok_or_else(|| PluginError::NotFound(voice_id.to_string()))?;
228
229        let default_config = VoicePluginConfig::default();
230        let config = match config {
231            Some(c) => c,
232            None => self.voice_configs.get(voice_id).unwrap_or(&default_config),
233        };
234
235        voice.estimate_duration(text, config)
236    }
237}
238
239impl Default for VoicePluginManager {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
/// Example built-in voice plugin implementation.
///
/// Its `synthesize` implementation sums fixed sinusoids rather than rendering
/// the actual text, so it serves as a placeholder voice.
pub struct DefaultVoicePlugin {
    /// Plugin name reported through `Plugin::name`; `new` sets it to
    /// `default-voice-<voice_id>`.
    name: String,
    /// Plugin version reported through `Plugin::version`; `new` sets it to `1.0.0`.
    version: String,
    /// Voice id this instance was created for.
    voice_id: String,
}
251
252impl DefaultVoicePlugin {
253    pub fn new(voice_id: &str) -> Self {
254        Self {
255            name: format!("default-voice-{}", voice_id),
256            version: "1.0.0".to_string(),
257            voice_id: voice_id.to_string(),
258        }
259    }
260}
261
262impl Plugin for DefaultVoicePlugin {
263    fn name(&self) -> &str {
264        &self.name
265    }
266
267    fn version(&self) -> &str {
268        &self.version
269    }
270
271    fn description(&self) -> &str {
272        "Built-in default voice plugin"
273    }
274
275    fn plugin_type(&self) -> PluginType {
276        PluginType::Voice
277    }
278
279    fn initialize(&mut self, _config: &serde_json::Value) -> PluginResult<()> {
280        Ok(())
281    }
282
283    fn cleanup(&mut self) -> PluginResult<()> {
284        Ok(())
285    }
286
287    fn get_capabilities(&self) -> Vec<String> {
288        vec![
289            "synthesize".to_string(),
290            "get_voice_info".to_string(),
291            "get_supported_languages".to_string(),
292            "validate_text".to_string(),
293            "estimate_duration".to_string(),
294        ]
295    }
296
297    fn execute(&self, command: &str, args: &serde_json::Value) -> PluginResult<serde_json::Value> {
298        match command {
299            "get_voice_info" => {
300                let info = self.get_voice_info();
301                Ok(serde_json::to_value(info).map_err(PluginError::SerializationError)?)
302            }
303            "get_supported_languages" => Ok(serde_json::json!(self.get_supported_languages())),
304            "validate_text" => {
305                let text = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
306                    PluginError::ExecutionFailed("Missing text parameter".to_string())
307                })?;
308
309                self.validate_text(text)?;
310                Ok(serde_json::json!({"valid": true}))
311            }
312            _ => Err(PluginError::ExecutionFailed(format!(
313                "Unknown command: {}",
314                command
315            ))),
316        }
317    }
318}
319
320impl VoicePlugin for DefaultVoicePlugin {
321    fn synthesize(&self, text: &str, config: &VoicePluginConfig) -> PluginResult<Vec<f32>> {
322        // Enhanced formant-based synthesis for more speech-like audio
323        // Duration based on character count with speed multiplier
324        let base_duration = text.len() as f32 * 0.08 * config.speed_multiplier;
325        let sample_rate = self.get_sample_rate() as f32;
326        let samples = (base_duration * sample_rate) as usize;
327
328        // Fundamental frequency (F0) with pitch shift
329        let f0 = 150.0 * (2.0_f32).powf(config.pitch_shift / 12.0);
330
331        // Speech formant frequencies (approximate vowel /a/)
332        let formants = [
333            (800.0, 0.3),   // F1: First formant
334            (1200.0, 0.2),  // F2: Second formant
335            (2500.0, 0.15), // F3: Third formant
336            (3500.0, 0.1),  // F4: Fourth formant
337        ];
338
339        // Base amplitude with volume control
340        let base_amplitude = 0.08 * (10.0_f32).powf(config.volume_gain / 20.0);
341
342        let mut audio = Vec::with_capacity(samples);
343        let two_pi = 2.0 * std::f32::consts::PI;
344
345        for i in 0..samples {
346            let t = i as f32 / sample_rate;
347
348            // Generate fundamental frequency with harmonics
349            let mut sample = 0.0;
350
351            // Add harmonics (up to 8th harmonic for richness)
352            for harmonic in 1..=8 {
353                let freq = f0 * harmonic as f32;
354                let amplitude = base_amplitude / harmonic as f32; // Harmonic rolloff
355                sample += amplitude * (two_pi * freq * t).sin();
356            }
357
358            // Apply formant filtering (simplified resonance)
359            for (formant_freq, formant_amp) in &formants {
360                let formant_phase = (two_pi * formant_freq * t).sin();
361                sample += formant_phase * formant_amp * base_amplitude;
362            }
363
364            // Apply simple envelope (attack-sustain-release)
365            let envelope = if t < 0.02 {
366                // Attack (20ms)
367                t / 0.02
368            } else if t > base_duration - 0.05 {
369                // Release (50ms)
370                (base_duration - t) / 0.05
371            } else {
372                // Sustain
373                1.0
374            };
375
376            sample *= envelope;
377
378            // Soft clipping to prevent distortion
379            sample = sample.clamp(-0.95, 0.95);
380
381            audio.push(sample);
382        }
383
384        Ok(audio)
385    }
386
387    fn get_voice_info(&self) -> VoiceInfo {
388        use voirs_sdk::types::{Gender, QualityLevel, SpeakingStyle};
389        use voirs_sdk::VoiceConfig;
390
391        let config = VoiceConfig {
392            id: self.voice_id.clone(),
393            name: format!("Default Voice {}", self.voice_id),
394            characteristics: VoiceCharacteristics {
395                gender: Some(Gender::NonBinary),
396                age: Some(AgeRange::Adult),
397                style: SpeakingStyle::Neutral,
398                emotion_support: true,
399                quality: QualityLevel::Medium,
400            },
401            language: voirs_sdk::types::LanguageCode::EnUs,
402            model_config: Default::default(),
403            metadata: HashMap::new(),
404        };
405
406        VoiceInfo::from_config(config)
407    }
408
409    fn get_supported_languages(&self) -> Vec<String> {
410        vec!["en-US".to_string(), "en-GB".to_string()]
411    }
412
413    fn get_voice_characteristics(&self) -> VoiceCharacteristics {
414        VoiceCharacteristics {
415            gender: Some(Gender::NonBinary),
416            age: Some(AgeRange::Adult),
417            style: SpeakingStyle::Neutral,
418            emotion_support: true,
419            quality: QualityLevel::Medium,
420        }
421    }
422
423    fn supports_ssml(&self) -> bool {
424        false
425    }
426
427    fn supports_emotions(&self) -> bool {
428        false
429    }
430
431    fn get_supported_emotions(&self) -> Vec<String> {
432        vec![]
433    }
434
435    fn get_sample_rate(&self) -> u32 {
436        22050
437    }
438
439    fn get_quality_levels(&self) -> Vec<String> {
440        vec!["low".to_string(), "medium".to_string(), "high".to_string()]
441    }
442
443    fn validate_text(&self, text: &str) -> PluginResult<()> {
444        if text.is_empty() {
445            return Err(PluginError::ExecutionFailed(
446                "Empty text not allowed".to_string(),
447            ));
448        }
449
450        if text.len() > 10000 {
451            return Err(PluginError::ExecutionFailed(
452                "Text too long (max 10000 characters)".to_string(),
453            ));
454        }
455
456        Ok(())
457    }
458
459    fn estimate_duration(&self, text: &str, config: &VoicePluginConfig) -> PluginResult<f32> {
460        // Simple estimation: ~10 characters per second, adjusted by speed
461        let base_duration = text.len() as f32 * 0.1;
462        Ok(base_duration / config.speed_multiplier)
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Voice id shared by every test in this module.
    const TEST_VOICE: &str = "test";

    /// Builds a manager with one `DefaultVoicePlugin` registered under `TEST_VOICE`.
    fn manager_with_test_voice() -> VoicePluginManager {
        let mut manager = VoicePluginManager::new();
        manager.register_voice(
            TEST_VOICE.to_string(),
            Arc::new(DefaultVoicePlugin::new(TEST_VOICE)),
        );
        manager
    }

    /// A registered voice is listed and can be fetched back.
    #[test]
    fn test_voice_plugin_manager() {
        let manager = manager_with_test_voice();

        let listed = manager.list_voices();
        assert!(listed.iter().any(|id| id == TEST_VOICE));

        assert!(manager.get_voice(TEST_VOICE).is_some());
    }

    /// The built-in plugin reports its metadata and produces audio.
    #[test]
    fn test_default_voice_plugin() {
        let plugin = DefaultVoicePlugin::new(TEST_VOICE);
        assert_eq!(plugin.name(), "default-voice-test");
        assert_eq!(plugin.version(), "1.0.0");

        let samples = plugin
            .synthesize("Hello world", &VoicePluginConfig::default())
            .unwrap();
        assert!(!samples.is_empty());
    }

    /// Validation accepts ordinary text and rejects empty or over-long text.
    #[test]
    fn test_voice_validation() {
        let plugin = DefaultVoicePlugin::new(TEST_VOICE);

        // Valid text
        assert!(plugin.validate_text("Hello world").is_ok());

        // Empty text should fail
        assert!(plugin.validate_text("").is_err());

        // Very long text should fail
        let too_long = "a".repeat(10001);
        assert!(plugin.validate_text(&too_long).is_err());
    }

    /// Searching by language finds supported languages only.
    #[test]
    fn test_voice_search() {
        let manager = manager_with_test_voice();

        let english = manager.search_voices(Some("en-US"), None, None);
        assert!(english.iter().any(|id| id == TEST_VOICE));

        let french = manager.search_voices(Some("fr-FR"), None, None);
        assert!(french.is_empty());
    }

    /// The active voice can be selected and then used for synthesis.
    #[test]
    fn test_active_voice() {
        let mut manager = manager_with_test_voice();

        assert!(manager.set_active_voice(TEST_VOICE).is_ok());
        assert_eq!(
            manager.get_active_voice().map(String::as_str),
            Some(TEST_VOICE)
        );

        let samples = manager.synthesize_with_active_voice("Hello", None).unwrap();
        assert!(!samples.is_empty());
    }
}