Skip to main content

voirs_acoustic/
lib.rs

1//! # VoiRS Acoustic Models
2//!
3//! Neural acoustic models for converting phonemes to mel spectrograms.
4//! Supports VITS, FastSpeech2, and other state-of-the-art architectures.
5
6// Allow pedantic lints that are acceptable for audio/DSP processing code
7#![allow(clippy::cast_precision_loss)] // Acceptable for audio sample conversions
8#![allow(clippy::cast_possible_truncation)] // Controlled truncation in audio processing
9#![allow(clippy::cast_sign_loss)] // Intentional in index calculations
10#![allow(clippy::missing_errors_doc)] // Many internal functions with self-documenting error types
11#![allow(clippy::missing_panics_doc)] // Panics are documented where relevant
12#![allow(clippy::unused_self)] // Some trait implementations require &self for consistency
13#![allow(clippy::must_use_candidate)] // Not all return values need must_use annotation
14#![allow(clippy::doc_markdown)] // Technical terms don't all need backticks
15#![allow(clippy::unnecessary_wraps)] // Result wrappers maintained for API consistency
16#![allow(clippy::float_cmp)] // Exact float comparisons are intentional in some contexts
17#![allow(clippy::match_same_arms)] // Pattern matching clarity sometimes requires duplication
18#![allow(clippy::module_name_repetitions)] // Type names often repeat module names
19#![allow(clippy::struct_excessive_bools)] // Config structs naturally have many boolean flags
20#![allow(clippy::too_many_lines)] // Some DSP functions are inherently complex
21#![allow(clippy::needless_pass_by_value)] // Some functions designed for ownership transfer
22#![allow(clippy::similar_names)] // Many similar variable names in DSP algorithms
23
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use thiserror::Error;
28
29/// Result type for acoustic model operations
30pub type Result<T> = std::result::Result<T, AcousticError>;
31
32/// Acoustic model specific error types with enhanced diagnostic information
33#[derive(Error, Debug)]
34pub enum AcousticError {
35    /// Model inference failed during synthesis or processing
36    #[error("Model inference failed: {message}")]
37    InferenceError { message: String },
38
39    /// Model loading or initialization failed
40    #[error("Model loading failed: {message}")]
41    ModelError { message: String },
42
43    /// Invalid input provided to the model
44    #[error("Invalid input: {message}")]
45    InputError { message: String },
46
47    /// Configuration validation or parsing error
48    #[error("Configuration error: {message}")]
49    ConfigError { message: String },
50
51    /// Processing error during synthesis pipeline
52    #[error("Processing error: {message}")]
53    ProcessingError { message: String },
54
55    /// File operation error (reading, writing, or parsing)
56    #[error("File operation error: {message}")]
57    FileError { message: String },
58
59    /// Backend-specific error from Candle framework
60    #[cfg(feature = "candle")]
61    #[error("Candle error: {0}")]
62    CandleError(#[from] candle_core::Error),
63
64    /// Grapheme-to-Phoneme conversion error
65    #[error("G2P error: {0}")]
66    G2pError(#[from] voirs_g2p::G2pError),
67}
68
69impl Clone for AcousticError {
70    fn clone(&self) -> Self {
71        match self {
72            AcousticError::InferenceError { message } => AcousticError::InferenceError {
73                message: message.clone(),
74            },
75            AcousticError::ModelError { message } => AcousticError::ModelError {
76                message: message.clone(),
77            },
78            AcousticError::InputError { message } => AcousticError::InputError {
79                message: message.clone(),
80            },
81            AcousticError::ConfigError { message } => AcousticError::ConfigError {
82                message: message.clone(),
83            },
84            AcousticError::ProcessingError { message } => AcousticError::ProcessingError {
85                message: message.clone(),
86            },
87            AcousticError::FileError { message } => AcousticError::FileError {
88                message: message.clone(),
89            },
90            #[cfg(feature = "candle")]
91            AcousticError::CandleError(err) => AcousticError::InferenceError {
92                message: format!("Candle error: {err}"),
93            },
94            AcousticError::G2pError(err) => AcousticError::InferenceError {
95                message: format!("G2P error: {err}"),
96            },
97        }
98    }
99}
100
101/// Language codes supported by VoiRS
102///
103/// This enum represents the complete set of languages supported by the VoiRS acoustic models.
104/// Each language may have region-specific variations (e.g., en-US vs en-GB).
105#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
106pub enum LanguageCode {
107    /// English (United States)
108    EnUs,
109    /// English (United Kingdom)
110    EnGb,
111    /// Japanese (Japan)
112    JaJp,
113    /// Mandarin Chinese (China)
114    ZhCn,
115    /// Korean (South Korea)
116    KoKr,
117    /// German (Germany)
118    DeDe,
119    /// French (France)
120    FrFr,
121    /// Spanish (Spain)
122    EsEs,
123    /// Italian (Italy)
124    ItIt,
125    /// Portuguese (Brazil)
126    PtBr,
127    /// Portuguese (Portugal)
128    PtPt,
129    /// Russian (Russia)
130    RuRu,
131    /// Dutch (Netherlands)
132    NlNl,
133    /// Polish (Poland)
134    PlPl,
135    /// Turkish (Turkey)
136    TrTr,
137    /// Arabic (Saudi Arabia)
138    ArSa,
139    /// Hindi (India)
140    HiIn,
141    /// Swedish (Sweden)
142    SvSe,
143    /// Norwegian (Norway)
144    NoNo,
145    /// Finnish (Finland)
146    FiFi,
147    /// Danish (Denmark)
148    DaDk,
149    /// Czech (Czech Republic)
150    CsCz,
151    /// Greek (Greece)
152    ElGr,
153    /// Hebrew (Israel)
154    HeIl,
155    /// Thai (Thailand)
156    ThTh,
157    /// Vietnamese (Vietnam)
158    ViVn,
159    /// Indonesian (Indonesia)
160    IdId,
161    /// Malay (Malaysia)
162    MsMy,
163}
164
165impl LanguageCode {
166    /// Get string representation in BCP 47 format
167    pub fn as_str(&self) -> &'static str {
168        match self {
169            LanguageCode::EnUs => "en-US",
170            LanguageCode::EnGb => "en-GB",
171            LanguageCode::JaJp => "ja-JP",
172            LanguageCode::ZhCn => "zh-CN",
173            LanguageCode::KoKr => "ko-KR",
174            LanguageCode::DeDe => "de-DE",
175            LanguageCode::FrFr => "fr-FR",
176            LanguageCode::EsEs => "es-ES",
177            LanguageCode::ItIt => "it-IT",
178            LanguageCode::PtBr => "pt-BR",
179            LanguageCode::PtPt => "pt-PT",
180            LanguageCode::RuRu => "ru-RU",
181            LanguageCode::NlNl => "nl-NL",
182            LanguageCode::PlPl => "pl-PL",
183            LanguageCode::TrTr => "tr-TR",
184            LanguageCode::ArSa => "ar-SA",
185            LanguageCode::HiIn => "hi-IN",
186            LanguageCode::SvSe => "sv-SE",
187            LanguageCode::NoNo => "no-NO",
188            LanguageCode::FiFi => "fi-FI",
189            LanguageCode::DaDk => "da-DK",
190            LanguageCode::CsCz => "cs-CZ",
191            LanguageCode::ElGr => "el-GR",
192            LanguageCode::HeIl => "he-IL",
193            LanguageCode::ThTh => "th-TH",
194            LanguageCode::ViVn => "vi-VN",
195            LanguageCode::IdId => "id-ID",
196            LanguageCode::MsMy => "ms-MY",
197        }
198    }
199
200    /// Get ISO 639-1 language code (2-letter code)
201    pub fn language_code(&self) -> &'static str {
202        &self.as_str()[..2]
203    }
204
205    /// Get full language name in English
206    pub fn language_name(&self) -> &'static str {
207        match self {
208            LanguageCode::EnUs | LanguageCode::EnGb => "English",
209            LanguageCode::JaJp => "Japanese",
210            LanguageCode::ZhCn => "Chinese",
211            LanguageCode::KoKr => "Korean",
212            LanguageCode::DeDe => "German",
213            LanguageCode::FrFr => "French",
214            LanguageCode::EsEs => "Spanish",
215            LanguageCode::ItIt => "Italian",
216            LanguageCode::PtBr | LanguageCode::PtPt => "Portuguese",
217            LanguageCode::RuRu => "Russian",
218            LanguageCode::NlNl => "Dutch",
219            LanguageCode::PlPl => "Polish",
220            LanguageCode::TrTr => "Turkish",
221            LanguageCode::ArSa => "Arabic",
222            LanguageCode::HiIn => "Hindi",
223            LanguageCode::SvSe => "Swedish",
224            LanguageCode::NoNo => "Norwegian",
225            LanguageCode::FiFi => "Finnish",
226            LanguageCode::DaDk => "Danish",
227            LanguageCode::CsCz => "Czech",
228            LanguageCode::ElGr => "Greek",
229            LanguageCode::HeIl => "Hebrew",
230            LanguageCode::ThTh => "Thai",
231            LanguageCode::ViVn => "Vietnamese",
232            LanguageCode::IdId => "Indonesian",
233            LanguageCode::MsMy => "Malay",
234        }
235    }
236
237    /// Parse from BCP 47 language tag string (case-insensitive)
238    ///
239    /// Accepts language tags in any case and normalizes to standard format:
240    /// - "en-US", "EN-US", "en-us", "En-Us" all parse to EnUs
241    pub fn parse(s: &str) -> Option<Self> {
242        // Normalize to standard BCP 47 format: lowercase language, uppercase region
243        // Split on hyphen, lowercase first part, uppercase second part
244        let parts: Vec<&str> = s.split('-').collect();
245        if parts.len() != 2 {
246            return None;
247        }
248
249        let normalized = format!("{}-{}", parts[0].to_lowercase(), parts[1].to_uppercase());
250
251        match normalized.as_str() {
252            "en-US" => Some(LanguageCode::EnUs),
253            "en-GB" => Some(LanguageCode::EnGb),
254            "ja-JP" => Some(LanguageCode::JaJp),
255            "zh-CN" => Some(LanguageCode::ZhCn),
256            "ko-KR" => Some(LanguageCode::KoKr),
257            "de-DE" => Some(LanguageCode::DeDe),
258            "fr-FR" => Some(LanguageCode::FrFr),
259            "es-ES" => Some(LanguageCode::EsEs),
260            "it-IT" => Some(LanguageCode::ItIt),
261            "pt-BR" => Some(LanguageCode::PtBr),
262            "pt-PT" => Some(LanguageCode::PtPt),
263            "ru-RU" => Some(LanguageCode::RuRu),
264            "nl-NL" => Some(LanguageCode::NlNl),
265            "pl-PL" => Some(LanguageCode::PlPl),
266            "tr-TR" => Some(LanguageCode::TrTr),
267            "ar-SA" => Some(LanguageCode::ArSa),
268            "hi-IN" => Some(LanguageCode::HiIn),
269            "sv-SE" => Some(LanguageCode::SvSe),
270            "no-NO" => Some(LanguageCode::NoNo),
271            "fi-FI" => Some(LanguageCode::FiFi),
272            "da-DK" => Some(LanguageCode::DaDk),
273            "cs-CZ" => Some(LanguageCode::CsCz),
274            "el-GR" => Some(LanguageCode::ElGr),
275            "he-IL" => Some(LanguageCode::HeIl),
276            "th-TH" => Some(LanguageCode::ThTh),
277            "vi-VN" => Some(LanguageCode::ViVn),
278            "id-ID" => Some(LanguageCode::IdId),
279            "ms-MY" => Some(LanguageCode::MsMy),
280            _ => None,
281        }
282    }
283
284    /// Get all supported language codes
285    pub fn all() -> &'static [LanguageCode] {
286        &[
287            LanguageCode::EnUs,
288            LanguageCode::EnGb,
289            LanguageCode::JaJp,
290            LanguageCode::ZhCn,
291            LanguageCode::KoKr,
292            LanguageCode::DeDe,
293            LanguageCode::FrFr,
294            LanguageCode::EsEs,
295            LanguageCode::ItIt,
296            LanguageCode::PtBr,
297            LanguageCode::PtPt,
298            LanguageCode::RuRu,
299            LanguageCode::NlNl,
300            LanguageCode::PlPl,
301            LanguageCode::TrTr,
302            LanguageCode::ArSa,
303            LanguageCode::HiIn,
304            LanguageCode::SvSe,
305            LanguageCode::NoNo,
306            LanguageCode::FiFi,
307            LanguageCode::DaDk,
308            LanguageCode::CsCz,
309            LanguageCode::ElGr,
310            LanguageCode::HeIl,
311            LanguageCode::ThTh,
312            LanguageCode::ViVn,
313            LanguageCode::IdId,
314            LanguageCode::MsMy,
315        ]
316    }
317}
318
319/// A phoneme with its symbol and optional features
320#[derive(Debug, Clone, Serialize, Deserialize)]
321pub struct Phoneme {
322    /// Phoneme symbol (IPA or language-specific)
323    pub symbol: String,
324    /// Optional phoneme features
325    pub features: Option<HashMap<String, String>>,
326    /// Duration in seconds (if available)
327    pub duration: Option<f32>,
328}
329
330impl PartialEq for Phoneme {
331    fn eq(&self, other: &Self) -> bool {
332        // Only compare symbol for equality (features and duration may vary)
333        self.symbol == other.symbol
334    }
335}
336
337impl Eq for Phoneme {}
338
339impl std::hash::Hash for Phoneme {
340    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
341        // Only hash the symbol (features and duration may vary)
342        self.symbol.hash(state);
343    }
344}
345
346impl Phoneme {
347    /// Create new phoneme
348    pub fn new<S: Into<String>>(symbol: S) -> Self {
349        Self {
350            symbol: symbol.into(),
351            features: None,
352            duration: None,
353        }
354    }
355}
356
357/// Mel spectrogram representation
358#[derive(Debug, Clone, Serialize, Deserialize)]
359pub struct MelSpectrogram {
360    /// Mel filterbank data [n_mels, n_frames]
361    pub data: Vec<Vec<f32>>,
362    /// Number of mel channels
363    pub n_mels: usize,
364    /// Number of time frames
365    pub n_frames: usize,
366    /// Sample rate of original audio
367    pub sample_rate: u32,
368    /// Hop length in samples
369    pub hop_length: u32,
370}
371
372impl MelSpectrogram {
373    /// Create new mel spectrogram
374    pub fn new(data: Vec<Vec<f32>>, sample_rate: u32, hop_length: u32) -> Self {
375        let n_mels = data.len();
376        let n_frames = data.first().map_or(0, |row| row.len());
377
378        Self {
379            data,
380            n_mels,
381            n_frames,
382            sample_rate,
383            hop_length,
384        }
385    }
386
387    /// Get duration in seconds
388    pub fn duration(&self) -> f32 {
389        (self.n_frames as u32 * self.hop_length) as f32 / self.sample_rate as f32
390    }
391}
392
393/// Simple synthesis configuration for basic operations
394#[derive(Debug, Clone, Serialize, Deserialize)]
395pub struct SynthesisConfig {
396    /// Speaking rate multiplier (1.0 = normal)
397    pub speed: f32,
398    /// Pitch shift in semitones
399    pub pitch_shift: f32,
400    /// Energy/volume multiplier
401    pub energy: f32,
402    /// Speaker ID for multi-speaker models
403    pub speaker_id: Option<u32>,
404    /// Random seed for reproducible generation
405    pub seed: Option<u64>,
406    /// Emotion control configuration
407    pub emotion: Option<crate::speaker::EmotionConfig>,
408    /// Voice style control
409    pub voice_style: Option<crate::speaker::VoiceStyleControl>,
410}
411
412impl SynthesisConfig {
413    /// Create new synthesis configuration
414    pub fn new() -> Self {
415        Self {
416            speed: 1.0,
417            pitch_shift: 0.0,
418            energy: 1.0,
419            speaker_id: None,
420            seed: None,
421            emotion: None,
422            voice_style: None,
423        }
424    }
425
426    /// Set emotion for synthesis
427    pub fn with_emotion(mut self, emotion: crate::speaker::EmotionConfig) -> Self {
428        self.emotion = Some(emotion);
429        self
430    }
431
432    /// Set voice style for synthesis
433    pub fn with_voice_style(mut self, voice_style: crate::speaker::VoiceStyleControl) -> Self {
434        self.voice_style = Some(voice_style);
435        self
436    }
437}
438
439impl Default for SynthesisConfig {
440    fn default() -> Self {
441        Self::new()
442    }
443}
444
445impl std::hash::Hash for SynthesisConfig {
446    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
447        // Convert floats to bits for deterministic hashing
448        self.speed.to_bits().hash(state);
449        self.pitch_shift.to_bits().hash(state);
450        self.energy.to_bits().hash(state);
451        self.speaker_id.hash(state);
452        self.seed.hash(state);
453        // Note: emotion and voice_style are not hashed for simplicity
454        // This is acceptable for cache keys as they're less common
455    }
456}
457
458// Re-export main traits and types
459pub use backends::{Backend, BackendManager};
460pub use batch_processor::{
461    BatchProcessingStats, BatchProcessor, BatchProcessorConfig, BatchProcessorTrait, BatchRequest,
462    ErrorStats, MemoryStats, QueueStats, RequestPriority,
463};
464pub use config::*;
465pub use mel::*;
466pub use memory::{
467    lazy::{
468        ComponentRegistry, LazyComponent, MemmapFile, MemoryPressureHandler, MemoryPressureLevel,
469        MemoryPressureStatus, ProgressiveLoader,
470    },
471    AdvancedPerformanceProfiler, MemoryOptimizer, OperationTimer, PerformanceMetrics,
472    PerformanceMonitor, PerformanceReport, PerformanceSnapshot, PerformanceThresholds, PoolStats,
473    ResultCache, SystemInfo, SystemMemoryInfo, TensorMemoryPool,
474};
475pub use metrics::{
476    EvaluationConfig, EvaluationPreset, MetricStatistics, ObjectiveEvaluator, ObjectiveMetrics,
477    PerceptualEvaluator, PerceptualMetrics, ProsodyEvaluator, ProsodyFeatures, ProsodyMetrics,
478    QualityEvaluator, QualityMetrics, QualityStatistics, RhythmFeatures, WindowType,
479};
480pub use models::{DummyAcousticConfig, DummyAcousticModel, ModelLoader};
481pub use optimization::{
482    DistillationConfig, DistillationStrategy, HardwareOptimization, HardwareTarget, ModelOptimizer,
483    OptimizationConfig, OptimizationMetrics, OptimizationReport, OptimizationTargets,
484    PruningConfig, PruningStrategy, PruningType, QuantizationConfig as OptQuantizationConfig,
485    QuantizationMethod as OptQuantizationMethod, QuantizationPrecision as OptQuantizationPrecision,
486};
487pub use prosody::{
488    DurationConfig, EnergyConfig, EnergyContourPattern, IntonationPattern, PauseDurations,
489    PitchConfig, ProsodyAdjustment, ProsodyConfig, ProsodyController, RhythmPattern, VibratoConfig,
490    VoiceQualityConfig,
491};
492pub use quantization::{
493    ModelQuantizer, QuantizationBenchmark, QuantizationConfig, QuantizationMethod,
494    QuantizationParams, QuantizationPrecision, QuantizationStats, QuantizedTensor,
495};
496pub use simd::{
497    Complex, FftWindow, SimdAudioEffects, SimdAudioProcessor, SimdCapabilities, SimdDispatcher,
498    SimdFft, SimdLinearLayer, SimdMatrix, SimdMelComputer, SimdStft, StftWindow, WindowFunction,
499};
500pub use singing::{
501    ArticulationMarking, BreathControlConfig, DynamicsMarking, FormantAdjustment, KeySignature,
502    MusicalNote, MusicalPhrase, ResonanceConfig, SingingConfig, SingingTechnique,
503    SingingVibratoConfig, SingingVoiceSynthesizer, VocalRegister, VoiceType,
504};
505pub use speaker::{
506    Accent, AgeGroup, AudioFeatures, AudioReference, CloningQualityMetrics,
507    CrossLanguageSpeakerAdapter, EmotionConfig, EmotionModel, EmotionType,
508    FewShotSpeakerAdaptation, Gender, MultiSpeakerConfig, MultiSpeakerModel, PersonalityTrait,
509    SpeakerEmbedding, SpeakerId, SpeakerMetadata, SpeakerRegistry, SpeakerVerificationResult,
510    SpeakerVerifier, VoiceCharacteristics, VoiceCloningConfig, VoiceCloningQualityAssessor,
511    VoiceQuality,
512};
513pub use streaming::{
514    LatencyOptimizer, LatencyOptimizerConfig, LatencyStats, LatencyStrategy,
515    PerformanceMeasurement, PerformancePredictor, StreamingConfig, StreamingMetrics,
516    StreamingState, StreamingSynthesizer,
517};
518pub use traits::{AcousticModel, AcousticModelFeature, AcousticModelMetadata};
519pub use vits::{TextEncoder, TextEncoderConfig, VitsConfig, VitsModel, VitsStreamingState};
520
521// Advanced modules (0.1.0-alpha.3 additions)
522pub mod acoustic_utils;
523pub mod latency_optimizer;
524pub mod neural_codec;
525pub mod vad;
526
527// Re-export advanced features
528pub use latency_optimizer::{
529    ChunkStrategy, LatencyBudget, LatencyMeasurement, LatencyOptimizer as AdvancedLatencyOptimizer,
530    LatencyStatistics, ProcessingPriority,
531};
532pub use neural_codec::{CodecQualityMetrics, CodecType, NeuralCodec, NeuralCodecConfig};
533pub use vad::{VadConfig, VadSegment, VoiceActivity, VoiceActivityDetector};
534
535pub mod backends;
536pub mod batch_processor;
537pub mod batching;
538pub mod cache;
539pub mod conditioning;
540pub mod config;
541pub mod diagnostics;
542pub mod error;
543pub mod fastspeech;
544pub mod fastspeech2_trainer;
545pub mod fusion;
546pub mod mel;
547pub mod memory;
548pub mod metrics;
549pub mod model_manager;
550pub mod model_warmup;
551pub mod models;
552pub mod optimization;
553pub mod parallel_attention;
554pub mod performance_targets;
555pub mod production;
556pub mod production_monitoring;
557pub mod profiling;
558pub mod profiling_integration;
559pub mod prosody;
560pub mod quantization;
561pub mod scirs2_ops;
562pub mod simd;
563pub mod singing;
564pub mod speaker;
565pub mod streaming;
566pub mod synthesis_cache;
567pub mod traits;
568pub mod unified_conditioning;
569pub mod utils;
570pub mod vits;
571
572/// Prelude for convenient imports
573pub mod prelude {
574    pub use crate::batch_processor::{
575        BatchProcessingStats, BatchProcessor, BatchProcessorConfig, BatchProcessorTrait,
576        BatchRequest, ErrorStats, MemoryStats, QueueStats, RequestPriority,
577    };
578    pub use crate::batching::{
579        BatchStats, DynamicBatchConfig, DynamicBatcher, MemoryOptimization, PaddingStrategy,
580        PendingSequence, ProcessingBatch,
581    };
582    pub use crate::cache::{
583        AdaptiveCache, AdaptiveCacheStats, CacheStats, CacheStrategy, LfuCache, PredictiveCache,
584    };
585    pub use crate::error::{
586        ErrorCategory, ErrorContext, ErrorContextBuilder, ErrorSeverity, RecoverySuggestion,
587    };
588    pub use crate::model_manager::{ModelManager, ModelRegistry, TtsPipeline};
589    pub use crate::parallel_attention::{
590        AttentionCache, AttentionMemoryOptimization, AttentionStats, AttentionStrategy,
591        ParallelAttentionConfig, ParallelMultiHeadAttention,
592    };
593    pub use crate::production::{
594        CircuitBreaker, CircuitState, HealthChecker, HealthStatus, RateLimiter, ResourceLimits,
595        RetryPolicy,
596    };
597    pub use crate::{
598        AcousticError, AcousticModel, AcousticModelFeature, AcousticModelManager,
599        AcousticModelMetadata, LanguageCode, MelSpectrogram, Phoneme, Result, SynthesisConfig,
600    };
601    pub use async_trait::async_trait;
602}
603
604// Types are already public in the root module
605
606/// Acoustic model manager with multiple architecture support
607pub struct AcousticModelManager {
608    models: HashMap<String, Box<dyn AcousticModel>>,
609    default_model: Option<String>,
610}
611
612impl AcousticModelManager {
613    /// Create new acoustic model manager
614    pub fn new() -> Self {
615        Self {
616            models: HashMap::new(),
617            default_model: None,
618        }
619    }
620
621    /// Add acoustic model
622    pub fn add_model(&mut self, name: String, model: Box<dyn AcousticModel>) {
623        self.models.insert(name.clone(), model);
624
625        // Set as default if it's the first model
626        if self.default_model.is_none() {
627            self.default_model = Some(name);
628        }
629    }
630
631    /// Set default model
632    pub fn set_default_model(&mut self, name: String) {
633        if self.models.contains_key(&name) {
634            self.default_model = Some(name);
635        }
636    }
637
638    /// Get model by name
639    pub fn get_model(&self, name: &str) -> Result<&dyn AcousticModel> {
640        self.models
641            .get(name)
642            .map(|m| m.as_ref())
643            .ok_or_else(|| AcousticError::ModelError {
644                message: format!("Acoustic model '{name}' not found"),
645            })
646    }
647
648    /// Get default model
649    pub fn get_default_model(&self) -> Result<&dyn AcousticModel> {
650        let name = self
651            .default_model
652            .as_ref()
653            .ok_or_else(|| AcousticError::ConfigError {
654                message: "No default acoustic model set".to_string(),
655            })?;
656        self.get_model(name)
657    }
658
659    /// List available models
660    pub fn list_models(&self) -> Vec<&str> {
661        self.models.keys().map(|s| s.as_str()).collect()
662    }
663}
664
665impl Default for AcousticModelManager {
666    fn default() -> Self {
667        Self::new()
668    }
669}
670
671#[async_trait]
672impl AcousticModel for AcousticModelManager {
673    async fn synthesize(
674        &self,
675        phonemes: &[Phoneme],
676        config: Option<&SynthesisConfig>,
677    ) -> Result<MelSpectrogram> {
678        let model = self.get_default_model()?;
679        model.synthesize(phonemes, config).await
680    }
681
682    async fn synthesize_batch(
683        &self,
684        inputs: &[&[Phoneme]],
685        configs: Option<&[SynthesisConfig]>,
686    ) -> Result<Vec<MelSpectrogram>> {
687        let model = self.get_default_model()?;
688        model.synthesize_batch(inputs, configs).await
689    }
690
691    fn metadata(&self) -> AcousticModelMetadata {
692        if let Ok(model) = self.get_default_model() {
693            model.metadata()
694        } else {
695            AcousticModelMetadata {
696                name: "Acoustic Model Manager".to_string(),
697                version: env!("CARGO_PKG_VERSION").to_string(),
698                architecture: "Manager".to_string(),
699                supported_languages: vec![],
700                sample_rate: 22050,
701                mel_channels: 80,
702                is_multi_speaker: false,
703                speaker_count: None,
704            }
705        }
706    }
707
708    fn supports(&self, feature: AcousticModelFeature) -> bool {
709        if let Ok(model) = self.get_default_model() {
710            model.supports(feature)
711        } else {
712            false
713        }
714    }
715
716    async fn set_speaker(&mut self, speaker_id: Option<u32>) -> Result<()> {
717        // Forward speaker setting to the default model
718        let default_name =
719            self.default_model
720                .as_ref()
721                .ok_or_else(|| AcousticError::ConfigError {
722                    message: "No default acoustic model set".to_string(),
723                })?;
724
725        if let Some(model) = self.models.get_mut(default_name) {
726            model.set_speaker(speaker_id).await
727        } else {
728            Err(AcousticError::ModelError {
729                message: format!("Default acoustic model '{default_name}' not found"),
730            })
731        }
732    }
733}
734
735// Type conversions are handled at the SDK level to avoid circular dependencies
736
#[cfg(test)]
mod language_tests {
    use super::*;

    #[test]
    fn test_language_code_string_representation() {
        assert_eq!(LanguageCode::EnUs.as_str(), "en-US");
        assert_eq!(LanguageCode::PtBr.as_str(), "pt-BR");
        assert_eq!(LanguageCode::RuRu.as_str(), "ru-RU");
        assert_eq!(LanguageCode::ArSa.as_str(), "ar-SA");
    }

    #[test]
    fn test_language_code_parsing() {
        assert_eq!(LanguageCode::parse("en-US"), Some(LanguageCode::EnUs));
        assert_eq!(LanguageCode::parse("pt-BR"), Some(LanguageCode::PtBr));
        assert_eq!(LanguageCode::parse("ru-RU"), Some(LanguageCode::RuRu));
        assert_eq!(LanguageCode::parse("invalid"), None);
    }

    /// `parse` documents case-insensitivity; pin every casing listed in its docs.
    #[test]
    fn test_language_code_parsing_is_case_insensitive() {
        for input in &["en-US", "EN-US", "en-us", "En-Us"] {
            assert_eq!(
                LanguageCode::parse(input),
                Some(LanguageCode::EnUs),
                "input {input:?}"
            );
        }
        assert_eq!(LanguageCode::parse("PT-br"), Some(LanguageCode::PtBr));
    }

    /// Anything that is not exactly `<language>-<region>` for a supported pair is rejected.
    #[test]
    fn test_language_code_parsing_rejects_malformed_tags() {
        for input in &["", "en", "-", "en-", "-US", "en-US-x", "en_US", "xx-YY", " en-US"] {
            assert_eq!(LanguageCode::parse(input), None, "input {input:?}");
        }
    }

    #[test]
    fn test_language_names() {
        assert_eq!(LanguageCode::EnUs.language_name(), "English");
        assert_eq!(LanguageCode::PtBr.language_name(), "Portuguese");
        assert_eq!(LanguageCode::RuRu.language_name(), "Russian");
        assert_eq!(LanguageCode::ArSa.language_name(), "Arabic");
        assert_eq!(LanguageCode::HiIn.language_name(), "Hindi");
    }

    #[test]
    fn test_iso_language_codes() {
        assert_eq!(LanguageCode::EnUs.language_code(), "en");
        assert_eq!(LanguageCode::PtBr.language_code(), "pt");
        assert_eq!(LanguageCode::RuRu.language_code(), "ru");
        assert_eq!(LanguageCode::ArSa.language_code(), "ar");
    }

    #[test]
    fn test_all_languages() {
        let all = LanguageCode::all();
        assert_eq!(all.len(), 28); // Total number of supported languages
        assert!(all.contains(&LanguageCode::EnUs));
        assert!(all.contains(&LanguageCode::PtBr));
        assert!(all.contains(&LanguageCode::RuRu));
        assert!(all.contains(&LanguageCode::MsMy));
        // `all()` lists variants in declaration order, so it must be strictly increasing
        // under the derived `Ord` (which also rules out duplicates).
        assert!(all.windows(2).all(|pair| pair[0] < pair[1]));
        // Every variant needs a distinct tag, otherwise `parse` would be ambiguous.
        let tags: std::collections::HashSet<&str> = all.iter().map(LanguageCode::as_str).collect();
        assert_eq!(tags.len(), all.len());
    }

    #[test]
    fn test_language_code_roundtrip() {
        for &lang in LanguageCode::all() {
            let string_repr = lang.as_str();
            assert_eq!(
                LanguageCode::parse(string_repr),
                Some(lang),
                "Roundtrip failed for {:?}",
                lang
            );
            assert_eq!(
                LanguageCode::parse(&string_repr.to_ascii_uppercase()),
                Some(lang),
                "Uppercase roundtrip failed for {:?}",
                lang
            );
            assert_eq!(
                LanguageCode::parse(&string_repr.to_ascii_lowercase()),
                Some(lang),
                "Lowercase roundtrip failed for {:?}",
                lang
            );
            assert!(string_repr.starts_with(lang.language_code()));
        }
    }

    #[test]
    fn test_language_sorting() {
        let mut languages = vec![
            LanguageCode::ZhCn,
            LanguageCode::ArSa,
            LanguageCode::EnUs,
            LanguageCode::JaJp,
        ];
        languages.sort();
        // Sorted by enum declaration order, not alphabetically by tag.
        assert_eq!(
            languages,
            vec![
                LanguageCode::EnUs,
                LanguageCode::JaJp,
                LanguageCode::ZhCn,
                LanguageCode::ArSa,
            ]
        );
    }

    #[test]
    fn test_new_language_support() {
        // Test newly added languages
        let new_languages = vec![
            (LanguageCode::PtBr, "pt-BR", "Portuguese"),
            (LanguageCode::RuRu, "ru-RU", "Russian"),
            (LanguageCode::NlNl, "nl-NL", "Dutch"),
            (LanguageCode::PlPl, "pl-PL", "Polish"),
            (LanguageCode::TrTr, "tr-TR", "Turkish"),
            (LanguageCode::ArSa, "ar-SA", "Arabic"),
            (LanguageCode::HiIn, "hi-IN", "Hindi"),
            (LanguageCode::SvSe, "sv-SE", "Swedish"),
            (LanguageCode::NoNo, "no-NO", "Norwegian"),
            (LanguageCode::FiFi, "fi-FI", "Finnish"),
        ];

        for (code, expected_str, expected_name) in new_languages {
            assert_eq!(code.as_str(), expected_str);
            assert_eq!(code.language_name(), expected_name);
        }
    }
}