Skip to main content

voirs_conversion/zero_shot/
database.rs

//! Reference voice database for zero-shot learning
2
3use crate::types::VoiceCharacteristics;
4use crate::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::time::Instant;
8
9/// Speaker embedding for voice identification and conversion
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct SpeakerEmbedding {
12    /// Embedding vector data
13    pub data: Vec<f32>,
14    /// Confidence score (0.0 to 1.0)
15    pub confidence: f32,
16}
17
18/// Reference voice database for zero-shot learning
19pub struct ReferenceVoiceDatabase {
20    /// Voice entries indexed by speaker ID
21    voices: HashMap<String, ReferenceVoice>,
22
23    /// Voice embeddings for fast similarity search
24    embeddings: HashMap<String, SpeakerEmbedding>,
25
26    /// Voice characteristics
27    characteristics: HashMap<String, VoiceCharacteristics>,
28
29    /// Usage statistics
30    usage_stats: HashMap<String, UsageStatistics>,
31
32    /// Database metadata
33    metadata: DatabaseMetadata,
34}
35
36/// Reference voice entry
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReferenceVoice {
39    /// Speaker identifier
40    pub speaker_id: String,
41
42    /// Voice name/description
43    pub name: String,
44
45    /// Audio samples
46    pub audio_samples: Vec<AudioSample>,
47
48    /// Speaker embedding
49    pub embedding: SpeakerEmbedding,
50
51    /// Voice characteristics
52    pub characteristics: VoiceCharacteristics,
53
54    /// Quality scores
55    pub quality_scores: QualityScores,
56
57    /// Metadata
58    pub metadata: VoiceMetadata,
59
60    /// Last used timestamp
61    #[serde(skip)]
62    pub last_used: Option<Instant>,
63}
64
65/// Audio sample for reference voice
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct AudioSample {
68    /// Sample identifier
69    pub id: String,
70
71    /// Audio data (placeholder - would contain actual audio)
72    #[serde(skip)]
73    pub audio_data: Vec<f32>,
74
75    /// Sample rate
76    pub sample_rate: u32,
77
78    /// Duration in seconds
79    pub duration: f32,
80
81    /// Transcription
82    pub transcription: Option<String>,
83
84    /// Quality score
85    pub quality_score: f32,
86
87    /// Phonetic content analysis
88    pub phonetic_content: PhoneticAnalysis,
89}
90
91/// Quality scores for reference voice
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct QualityScores {
94    /// Overall quality score (0.0 to 1.0)
95    pub overall: f32,
96
97    /// Clarity score
98    pub clarity: f32,
99
100    /// Naturalness score
101    pub naturalness: f32,
102
103    /// Consistency score
104    pub consistency: f32,
105
106    /// Recording quality score
107    pub recording_quality: f32,
108
109    /// Prosody quality score
110    pub prosody_quality: f32,
111}
112
113/// Voice metadata
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct VoiceMetadata {
116    /// Language
117    pub language: String,
118
119    /// Accent/dialect
120    pub accent: Option<String>,
121
122    /// Gender
123    pub gender: Option<String>,
124
125    /// Age group
126    pub age_group: Option<String>,
127
128    /// Recording environment
129    pub recording_environment: Option<String>,
130
131    /// Tags
132    pub tags: Vec<String>,
133
134    /// Creation timestamp
135    #[serde(
136        skip_serializing,
137        skip_deserializing,
138        default = "std::time::Instant::now"
139    )]
140    pub created: Instant,
141
142    /// Last modified timestamp
143    #[serde(skip_serializing, skip_deserializing, default)]
144    pub modified: Option<Instant>,
145}
146
147/// Phonetic analysis of audio sample
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct PhoneticAnalysis {
150    /// Phoneme distribution
151    pub phoneme_distribution: HashMap<String, f32>,
152
153    /// Phonetic diversity score
154    pub diversity_score: f32,
155
156    /// Vowel-consonant ratio
157    pub vowel_consonant_ratio: f32,
158
159    /// Prosodic features
160    pub prosodic_features: ProsodicFeatures,
161}
162
163/// Prosodic features
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct ProsodicFeatures {
166    /// Mean F0
167    pub mean_f0: f32,
168
169    /// F0 range
170    pub f0_range: (f32, f32),
171
172    /// Speaking rate (syllables per second)
173    pub speaking_rate: f32,
174
175    /// Pause patterns
176    pub pause_patterns: Vec<f32>,
177
178    /// Stress patterns
179    pub stress_patterns: Vec<f32>,
180}
181
182/// Usage statistics for reference voices
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct UsageStatistics {
185    /// Number of times used
186    pub usage_count: u64,
187
188    /// Average similarity scores
189    pub avg_similarity: f32,
190
191    /// Success rate
192    pub success_rate: f32,
193
194    /// Last used timestamp
195    #[serde(skip)]
196    pub last_used: Option<Instant>,
197
198    /// Preferred contexts
199    pub preferred_contexts: Vec<String>,
200}
201
202/// Database metadata
203#[derive(Debug, Clone)]
204pub struct DatabaseMetadata {
205    /// Total number of voices
206    pub total_voices: usize,
207
208    /// Total audio duration (seconds)
209    pub total_duration: f32,
210
211    /// Languages represented
212    pub languages: Vec<String>,
213
214    /// Last updated timestamp
215    pub last_updated: Instant,
216
217    /// Database version
218    pub version: String,
219
220    /// Index statistics
221    pub index_stats: IndexStatistics,
222}
223
224/// Index statistics
225#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct IndexStatistics {
227    /// Embedding index size
228    pub embedding_index_size: usize,
229
230    /// Characteristic index size
231    pub characteristic_index_size: usize,
232
233    /// Search performance metrics
234    pub search_performance: SearchPerformanceMetrics,
235}
236
237/// Search performance metrics
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct SearchPerformanceMetrics {
240    /// Average search time (ms)
241    pub avg_search_time: f32,
242
243    /// Cache hit rate
244    pub cache_hit_rate: f32,
245
246    /// Index efficiency
247    pub index_efficiency: f32,
248}
249
250impl Default for ReferenceVoiceDatabase {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
256impl ReferenceVoiceDatabase {
257    /// Creates a new empty reference voice database.
258    ///
259    /// # Returns
260    ///
261    /// A new `ReferenceVoiceDatabase` instance with empty collections and initialized metadata.
262    pub fn new() -> Self {
263        Self {
264            voices: HashMap::new(),
265            embeddings: HashMap::new(),
266            characteristics: HashMap::new(),
267            usage_stats: HashMap::new(),
268            metadata: DatabaseMetadata {
269                total_voices: 0,
270                total_duration: 0.0,
271                languages: Vec::new(),
272                last_updated: Instant::now(),
273                version: "1.0.0".to_string(),
274                index_stats: IndexStatistics {
275                    embedding_index_size: 0,
276                    characteristic_index_size: 0,
277                    search_performance: SearchPerformanceMetrics {
278                        avg_search_time: 0.0,
279                        cache_hit_rate: 0.0,
280                        index_efficiency: 1.0,
281                    },
282                },
283            },
284        }
285    }
286
287    /// Adds a reference voice to the database.
288    ///
289    /// # Arguments
290    ///
291    /// * `voice` - The reference voice to add with speaker ID, embedding, and characteristics
292    ///
293    /// # Returns
294    ///
295    /// `Ok(())` if the voice was successfully added.
296    ///
297    /// # Errors
298    ///
299    /// Currently does not return errors, but returns `Result` for future error handling.
300    pub fn add_voice(&mut self, voice: ReferenceVoice) -> Result<()> {
301        let speaker_id = voice.speaker_id.clone();
302        self.embeddings
303            .insert(speaker_id.clone(), voice.embedding.clone());
304        self.characteristics
305            .insert(speaker_id.clone(), voice.characteristics.clone());
306        self.usage_stats.insert(
307            speaker_id.clone(),
308            UsageStatistics {
309                usage_count: 0,
310                avg_similarity: 0.0,
311                success_rate: 0.0,
312                last_used: None,
313                preferred_contexts: Vec::new(),
314            },
315        );
316        self.voices.insert(speaker_id, voice);
317        self.metadata.total_voices += 1;
318        Ok(())
319    }
320
321    /// Removes a reference voice from the database by speaker ID.
322    ///
323    /// # Arguments
324    ///
325    /// * `speaker_id` - The unique identifier of the speaker to remove
326    ///
327    /// # Returns
328    ///
329    /// `Ok(())` if the voice was successfully removed or if it didn't exist.
330    ///
331    /// # Errors
332    ///
333    /// Currently does not return errors, but returns `Result` for future error handling.
334    pub fn remove_voice(&mut self, speaker_id: &str) -> Result<()> {
335        self.voices.remove(speaker_id);
336        self.embeddings.remove(speaker_id);
337        self.characteristics.remove(speaker_id);
338        self.usage_stats.remove(speaker_id);
339        if self.metadata.total_voices > 0 {
340            self.metadata.total_voices -= 1;
341        }
342        Ok(())
343    }
344
345    /// Finds the most similar reference voices based on target characteristics.
346    ///
347    /// # Arguments
348    ///
349    /// * `target_characteristics` - The target voice characteristics to match against
350    /// * `max_voices` - Maximum number of similar voices to return
351    ///
352    /// # Returns
353    ///
354    /// A vector of reference voices sorted by similarity (most similar first), limited to `max_voices`.
355    ///
356    /// # Errors
357    ///
358    /// Currently does not return errors, but returns `Result` for future error handling.
359    pub fn find_similar_voices(
360        &self,
361        target_characteristics: &VoiceCharacteristics,
362        max_voices: usize,
363    ) -> Result<Vec<ReferenceVoice>> {
364        let mut similarities = Vec::new();
365
366        for voice in self.voices.values() {
367            let similarity =
368                self.calculate_similarity(&voice.characteristics, target_characteristics);
369            similarities.push((similarity, voice.clone()));
370        }
371
372        // Sort by similarity (descending)
373        similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
374
375        // Take top matches
376        Ok(similarities
377            .into_iter()
378            .take(max_voices)
379            .map(|(_, voice)| voice)
380            .collect())
381    }
382
383    /// Returns a reference to the database metadata.
384    ///
385    /// Provides access to database statistics, version information, and index performance metrics.
386    ///
387    /// # Returns
388    ///
389    /// A reference to the [`DatabaseMetadata`] containing information about the database state,
390    /// including total voices, duration, supported languages, and search performance metrics.
391    pub fn metadata(&self) -> &DatabaseMetadata {
392        &self.metadata
393    }
394
395    fn calculate_similarity(
396        &self,
397        voice1: &VoiceCharacteristics,
398        voice2: &VoiceCharacteristics,
399    ) -> f32 {
400        // Enhanced multi-dimensional similarity calculation
401        let mut similarity = 0.0;
402        let mut total_weight: f32 = 0.0;
403
404        // Gender similarity (weight: 0.25)
405        if voice1.gender == voice2.gender {
406            similarity += 0.25;
407        } else if voice1.gender.is_some() && voice2.gender.is_some() {
408            similarity += 0.1; // Partial credit for different but specified genders
409        }
410        total_weight += 0.25;
411
412        // Age group similarity (weight: 0.15)
413        if voice1.age_group == voice2.age_group {
414            similarity += 0.15;
415        } else if voice1.age_group.is_some() && voice2.age_group.is_some() {
416            // Age groups have some similarity (adult vs young_adult vs senior)
417            let age_similarity = match (voice1.age_group, voice2.age_group) {
418                (Some(a1), Some(a2)) => {
419                    use crate::types::AgeGroup;
420                    match (a1, a2) {
421                        (AgeGroup::YoungAdult, AgeGroup::MiddleAged)
422                        | (AgeGroup::MiddleAged, AgeGroup::YoungAdult) => 0.8,
423                        (AgeGroup::MiddleAged, AgeGroup::Senior)
424                        | (AgeGroup::Senior, AgeGroup::MiddleAged) => 0.6,
425                        (AgeGroup::Child, AgeGroup::YoungAdult)
426                        | (AgeGroup::YoungAdult, AgeGroup::Child) => 0.4,
427                        _ => 0.2,
428                    }
429                }
430                _ => 0.05,
431            };
432            similarity += 0.15 * age_similarity;
433        }
434        total_weight += 0.15;
435
436        // Accent similarity (weight: 0.2)
437        if voice1.accent == voice2.accent {
438            similarity += 0.2;
439        } else if voice1.accent.is_some() && voice2.accent.is_some() {
440            // Some accents are more similar than others
441            let accent_similarity =
442                if let (Some(ref a1), Some(ref a2)) = (&voice1.accent, &voice2.accent) {
443                    if (a1.contains("american") && a2.contains("canadian"))
444                        || (a2.contains("american") && a1.contains("canadian"))
445                    {
446                        0.8
447                    } else if (a1.contains("british") && a2.contains("australian"))
448                        || (a2.contains("british") && a1.contains("australian"))
449                    {
450                        0.7
451                    } else {
452                        0.3
453                    }
454                } else {
455                    0.1
456                };
457            similarity += 0.2 * accent_similarity;
458        }
459        total_weight += 0.2;
460
461        // Pitch similarity (weight: 0.2)
462        let pitch_diff = (voice1.pitch.mean_f0 - voice2.pitch.mean_f0).abs();
463        let pitch_similarity = if pitch_diff < 10.0 {
464            1.0 // Very similar pitch
465        } else if pitch_diff < 50.0 {
466            1.0 - (pitch_diff - 10.0) / 40.0 // Linear decay from 10-50 Hz
467        } else {
468            (1.0 - (pitch_diff / 200.0)).max(0.0) // Slower decay above 50 Hz
469        };
470        similarity += pitch_similarity * 0.2;
471        total_weight += 0.2;
472
473        // Spectral similarity (weight: 0.1)
474        let formant_diff = (voice1.spectral.formant_shift - voice2.spectral.formant_shift).abs();
475        let spectral_similarity = 1.0 - formant_diff.min(1.0);
476        similarity += spectral_similarity * 0.1;
477        total_weight += 0.1;
478
479        // Quality similarity (weight: 0.1)
480        let breathiness_diff = (voice1.quality.breathiness - voice2.quality.breathiness).abs();
481        let roughness_diff = (voice1.quality.roughness - voice2.quality.roughness).abs();
482        let quality_similarity = 1.0 - ((breathiness_diff + roughness_diff) / 2.0).min(1.0);
483        similarity += quality_similarity * 0.1;
484        total_weight += 0.1;
485
486        // Normalize by total weight
487        similarity / total_weight.max(1e-10)
488    }
489}
490
491impl Default for VoiceMetadata {
492    fn default() -> Self {
493        Self {
494            language: String::new(),
495            accent: None,
496            gender: None,
497            age_group: None,
498            recording_environment: None,
499            tags: Vec::new(),
500            created: Instant::now(),
501            modified: None,
502        }
503    }
504}
505
506impl Default for DatabaseMetadata {
507    fn default() -> Self {
508        Self {
509            total_voices: 0,
510            total_duration: 0.0,
511            languages: Vec::new(),
512            last_updated: Instant::now(),
513            version: "1.0.0".to_string(),
514            index_stats: IndexStatistics::default(),
515        }
516    }
517}
518
519impl Default for IndexStatistics {
520    fn default() -> Self {
521        Self {
522            embedding_index_size: 0,
523            characteristic_index_size: 0,
524            search_performance: SearchPerformanceMetrics::default(),
525        }
526    }
527}
528
529impl Default for SearchPerformanceMetrics {
530    fn default() -> Self {
531        Self {
532            avg_search_time: 0.0,
533            cache_hit_rate: 0.0,
534            index_efficiency: 1.0,
535        }
536    }
537}