1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use voirs_sdk::types::SynthesisConfig;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub enum CloningMethod {
7 FineTuning,
8 SpeakerEmbedding,
9 ZeroShot,
10 FewShot,
11 Adaptive,
12 Neural,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct VoiceProfile {
17 pub id: String,
18 pub name: String,
19 pub embedding: Vec<f32>,
20 pub sample_rate: u32,
21 pub channels: u16,
22 pub duration_samples: u64,
23 pub quality_score: f32,
24 pub metadata: HashMap<String, String>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CloningConfig {
29 pub method: CloningMethod,
30 pub target_voice_profile: VoiceProfile,
31 pub similarity_threshold: f32,
32 pub adaptation_rate: f32,
33 pub quality_threshold: f32,
34 pub max_training_iterations: u32,
35 pub use_speaker_verification: bool,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct AdaptationConfig {
40 pub learning_rate: f32,
41 pub momentum: f32,
42 pub weight_decay: f32,
43 pub batch_size: u32,
44 pub gradient_clipping: f32,
45 pub convergence_threshold: f32,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct SpeakerEmbeddingConfig {
50 pub embedding_dimension: u32,
51 pub network_depth: u32,
52 pub attention_heads: u32,
53 pub dropout_rate: f32,
54 pub normalization: bool,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct VoiceCloningConfig {
59 pub base_config: SynthesisConfig,
60 pub cloning_config: CloningConfig,
61 pub adaptation_config: Option<AdaptationConfig>,
62 pub embedding_config: Option<SpeakerEmbeddingConfig>,
63 pub reference_audio_paths: Vec<String>,
64 pub output_quality_target: f32,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct CloningProgress {
69 pub current_iteration: u32,
70 pub total_iterations: u32,
71 pub current_loss: f32,
72 pub best_loss: f32,
73 pub similarity_score: f32,
74 pub quality_score: f32,
75 pub eta_seconds: u32,
76}
77
78pub struct VoiceCloner {
79 voice_profiles: HashMap<String, VoiceProfile>,
80 active_cloning_sessions: HashMap<String, CloningProgress>,
81 embedding_cache: HashMap<String, Vec<f32>>,
82 quality_assessor: QualityAssessor,
83}
84
85struct QualityAssessor {
86 similarity_threshold: f32,
87 quality_threshold: f32,
88}
89
90impl VoiceCloner {
91 pub fn new() -> Self {
92 Self {
93 voice_profiles: HashMap::new(),
94 active_cloning_sessions: HashMap::new(),
95 embedding_cache: HashMap::new(),
96 quality_assessor: QualityAssessor {
97 similarity_threshold: 0.8,
98 quality_threshold: 0.7,
99 },
100 }
101 }
102
103 pub fn add_voice_profile(&mut self, profile: VoiceProfile) -> Result<(), String> {
104 if profile.embedding.is_empty() {
105 return Err("Voice profile embedding cannot be empty".to_string());
106 }
107
108 if profile.quality_score < self.quality_assessor.quality_threshold {
109 return Err("Voice profile quality score below threshold".to_string());
110 }
111
112 self.voice_profiles.insert(profile.id.clone(), profile);
113 Ok(())
114 }
115
116 pub fn create_voice_profile_from_audio(
117 &mut self,
118 id: String,
119 name: String,
120 audio_data: &[f32],
121 sample_rate: u32,
122 channels: u16,
123 ) -> Result<VoiceProfile, String> {
124 if audio_data.is_empty() {
125 return Err("Audio data cannot be empty".to_string());
126 }
127
128 let embedding = self.extract_speaker_embedding(audio_data, sample_rate, channels)?;
129 let quality_score = self.assess_audio_quality(audio_data, sample_rate);
130
131 let profile = VoiceProfile {
132 id: id.clone(),
133 name,
134 embedding,
135 sample_rate,
136 channels,
137 duration_samples: audio_data.len() as u64,
138 quality_score,
139 metadata: HashMap::new(),
140 };
141
142 self.embedding_cache
143 .insert(id.clone(), profile.embedding.clone());
144 Ok(profile)
145 }
146
147 fn extract_speaker_embedding(
148 &self,
149 audio_data: &[f32],
150 sample_rate: u32,
151 channels: u16,
152 ) -> Result<Vec<f32>, String> {
153 let frame_size = (sample_rate as usize / 100) * channels as usize; let mut embeddings = Vec::new();
158
159 for chunk in audio_data.chunks(frame_size) {
160 let mean = chunk.iter().sum::<f32>() / chunk.len() as f32;
161 let variance =
162 chunk.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / chunk.len() as f32;
163 let energy = chunk.iter().map(|x| x.powi(2)).sum::<f32>();
164
165 embeddings.push(mean);
166 embeddings.push(variance.sqrt());
167 embeddings.push(energy.ln().max(-10.0));
168 }
169
170 if embeddings.len() < 128 {
171 embeddings.resize(128, 0.0);
172 } else {
173 embeddings.truncate(128);
174 }
175
176 Ok(embeddings)
177 }
178
179 fn assess_audio_quality(&self, audio_data: &[f32], _sample_rate: u32) -> f32 {
180 let rms =
181 (audio_data.iter().map(|x| x.powi(2)).sum::<f32>() / audio_data.len() as f32).sqrt();
182 let peak = audio_data.iter().map(|x| x.abs()).fold(0.0, f32::max);
183 let dynamic_range = if peak > 0.0 {
184 20.0 * (peak / rms).log10()
185 } else {
186 0.0
187 };
188
189 (dynamic_range / 60.0).clamp(0.0, 1.0)
190 }
191
192 pub fn calculate_voice_similarity(
193 &self,
194 profile1: &VoiceProfile,
195 profile2: &VoiceProfile,
196 ) -> f32 {
197 if profile1.embedding.len() != profile2.embedding.len() {
198 return 0.0;
199 }
200
201 let dot_product: f32 = profile1
202 .embedding
203 .iter()
204 .zip(profile2.embedding.iter())
205 .map(|(a, b)| a * b)
206 .sum();
207
208 let norm1: f32 = profile1
209 .embedding
210 .iter()
211 .map(|x| x.powi(2))
212 .sum::<f32>()
213 .sqrt();
214 let norm2: f32 = profile2
215 .embedding
216 .iter()
217 .map(|x| x.powi(2))
218 .sum::<f32>()
219 .sqrt();
220
221 if norm1 == 0.0 || norm2 == 0.0 {
222 return 0.0;
223 }
224
225 (dot_product / (norm1 * norm2)).clamp(-1.0, 1.0)
226 }
227
228 pub fn start_cloning_session(
229 &mut self,
230 session_id: String,
231 target_profile: &VoiceProfile,
232 config: &CloningConfig,
233 ) -> Result<(), String> {
234 let total_iterations = config.max_training_iterations;
235
236 let progress = CloningProgress {
237 current_iteration: 0,
238 total_iterations,
239 current_loss: 1.0,
240 best_loss: 1.0,
241 similarity_score: 0.0,
242 quality_score: 0.0,
243 eta_seconds: total_iterations * 10, };
245
246 self.active_cloning_sessions.insert(session_id, progress);
247 Ok(())
248 }
249
250 pub fn update_cloning_progress(
251 &mut self,
252 session_id: &str,
253 iteration: u32,
254 loss: f32,
255 similarity_score: f32,
256 quality_score: f32,
257 ) -> Result<(), String> {
258 if let Some(progress) = self.active_cloning_sessions.get_mut(session_id) {
259 progress.current_iteration = iteration;
260 progress.current_loss = loss;
261 progress.similarity_score = similarity_score;
262 progress.quality_score = quality_score;
263
264 if loss < progress.best_loss {
265 progress.best_loss = loss;
266 }
267
268 let remaining_iterations = progress.total_iterations.saturating_sub(iteration);
269 progress.eta_seconds = remaining_iterations * 10;
270
271 Ok(())
272 } else {
273 Err("Cloning session not found".to_string())
274 }
275 }
276
277 pub fn create_cloning_synthesis_config(
278 &self,
279 base_config: SynthesisConfig,
280 cloning_config: CloningConfig,
281 reference_audio_paths: Vec<String>,
282 ) -> VoiceCloningConfig {
283 VoiceCloningConfig {
284 base_config,
285 cloning_config,
286 adaptation_config: Some(AdaptationConfig::default()),
287 embedding_config: Some(SpeakerEmbeddingConfig::default()),
288 reference_audio_paths,
289 output_quality_target: 0.8,
290 }
291 }
292
293 pub fn get_cloning_progress(&self, session_id: &str) -> Option<&CloningProgress> {
294 self.active_cloning_sessions.get(session_id)
295 }
296
297 pub fn is_cloning_complete(&self, session_id: &str) -> bool {
298 if let Some(progress) = self.active_cloning_sessions.get(session_id) {
299 progress.current_iteration >= progress.total_iterations
300 || progress.similarity_score >= self.quality_assessor.similarity_threshold
301 } else {
302 false
303 }
304 }
305
306 pub fn get_voice_profile(&self, profile_id: &str) -> Option<&VoiceProfile> {
307 self.voice_profiles.get(profile_id)
308 }
309
310 pub fn list_voice_profiles(&self) -> Vec<&VoiceProfile> {
311 self.voice_profiles.values().collect()
312 }
313
314 pub fn remove_voice_profile(&mut self, profile_id: &str) -> Option<VoiceProfile> {
315 self.embedding_cache.remove(profile_id);
316 self.voice_profiles.remove(profile_id)
317 }
318
319 pub fn clear_completed_sessions(&mut self) {
320 let completed_sessions: Vec<String> = self
321 .active_cloning_sessions
322 .iter()
323 .filter(|(id, _)| self.is_cloning_complete(id))
324 .map(|(id, _)| id.clone())
325 .collect();
326
327 for session_id in completed_sessions {
328 self.active_cloning_sessions.remove(&session_id);
329 }
330 }
331}
332
333impl Default for VoiceCloner {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339impl Default for CloningConfig {
340 fn default() -> Self {
341 Self {
342 method: CloningMethod::SpeakerEmbedding,
343 target_voice_profile: VoiceProfile {
344 id: "default".to_string(),
345 name: "Default Voice".to_string(),
346 embedding: vec![0.0; 128],
347 sample_rate: 22050,
348 channels: 1,
349 duration_samples: 0,
350 quality_score: 0.5,
351 metadata: HashMap::new(),
352 },
353 similarity_threshold: 0.8,
354 adaptation_rate: 0.01,
355 quality_threshold: 0.7,
356 max_training_iterations: 100,
357 use_speaker_verification: true,
358 }
359 }
360}
361
362impl Default for AdaptationConfig {
363 fn default() -> Self {
364 Self {
365 learning_rate: 0.001,
366 momentum: 0.9,
367 weight_decay: 0.0001,
368 batch_size: 32,
369 gradient_clipping: 1.0,
370 convergence_threshold: 1e-6,
371 }
372 }
373}
374
375impl Default for SpeakerEmbeddingConfig {
376 fn default() -> Self {
377 Self {
378 embedding_dimension: 128,
379 network_depth: 4,
380 attention_heads: 8,
381 dropout_rate: 0.1,
382 normalization: true,
383 }
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a profile with the fixture values shared by every test here:
    /// sample rate 22050, one channel, 1000 samples, quality score 0.8 and
    /// no metadata.
    fn make_profile(id: &str, name: &str, embedding: Vec<f32>) -> VoiceProfile {
        VoiceProfile {
            id: id.to_string(),
            name: name.to_string(),
            embedding,
            sample_rate: 22050,
            channels: 1,
            duration_samples: 1000,
            quality_score: 0.8,
            metadata: HashMap::new(),
        }
    }

    /// Target profile with an all-zero 128-value embedding.
    fn target_profile() -> VoiceProfile {
        make_profile("target", "Target Voice", vec![0.0; 128])
    }

    /// A fresh cloner starts with no profiles and no sessions.
    #[test]
    fn test_voice_cloner_creation() {
        let cloner = VoiceCloner::new();
        assert!(cloner.voice_profiles.is_empty());
        assert!(cloner.active_cloning_sessions.is_empty());
    }

    /// Creating a profile from a few samples yields a 128-value embedding.
    #[test]
    fn test_voice_profile_creation() {
        let mut cloner = VoiceCloner::new();
        let audio_data = vec![0.1, 0.2, 0.3, 0.4, 0.5];

        let result = cloner.create_voice_profile_from_audio(
            "test_id".to_string(),
            "Test Voice".to_string(),
            &audio_data,
            22050,
            1,
        );
        assert!(result.is_ok());

        let profile = result.unwrap();
        assert_eq!(profile.id, "test_id");
        assert_eq!(profile.name, "Test Voice");
        assert_eq!(profile.embedding.len(), 128);
    }

    /// Identical embeddings have a cosine similarity of one.
    #[test]
    fn test_voice_similarity_calculation() {
        let cloner = VoiceCloner::new();
        let profile1 = make_profile("voice1", "Voice 1", vec![1.0, 0.0, 0.0, 1.0]);
        let profile2 = make_profile("voice2", "Voice 2", vec![1.0, 0.0, 0.0, 1.0]);

        let similarity = cloner.calculate_voice_similarity(&profile1, &profile2);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    /// A freshly started session is tracked and not yet complete.
    #[test]
    fn test_cloning_session_management() {
        let mut cloner = VoiceCloner::new();
        let profile = target_profile();
        let config = CloningConfig::default();
        let session_id = "session1".to_string();

        let started = cloner.start_cloning_session(session_id.clone(), &profile, &config);
        assert!(started.is_ok());
        assert!(cloner.get_cloning_progress(&session_id).is_some());
        assert!(!cloner.is_cloning_complete(&session_id));
    }

    /// Reported metrics are stored verbatim on the session's progress.
    #[test]
    fn test_cloning_progress_update() {
        let mut cloner = VoiceCloner::new();
        let profile = target_profile();
        let config = CloningConfig::default();
        let session_id = "session1".to_string();

        cloner
            .start_cloning_session(session_id.clone(), &profile, &config)
            .unwrap();

        let updated = cloner.update_cloning_progress(&session_id, 10, 0.5, 0.7, 0.8);
        assert!(updated.is_ok());

        let progress = cloner.get_cloning_progress(&session_id).unwrap();
        assert_eq!(progress.current_iteration, 10);
        assert_eq!(progress.current_loss, 0.5);
        assert_eq!(progress.similarity_score, 0.7);
        assert_eq!(progress.quality_score, 0.8);
    }

    /// A default config survives a JSON round trip.
    #[test]
    fn test_config_serialization() {
        let config = CloningConfig::default();
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: CloningConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(
            deserialized.similarity_threshold,
            config.similarity_threshold
        );
        assert_eq!(
            deserialized.max_training_iterations,
            config.max_training_iterations
        );
    }
}