1use crate::synthesis::SynthesisResult;
38use crate::types::{Expression, VoiceCharacteristics};
39use crate::{Error, Result};
40use chrono::{DateTime, Utc};
41use serde::{Deserialize, Serialize};
42use std::collections::HashMap;
43use std::sync::Arc;
44use tokio::sync::RwLock;
45use uuid::Uuid;
46
/// Tuning knobs for [`AdaptiveLearningSystem`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveLearningConfig {
    /// Step size for the per-feedback quality-weight nudges and for the
    /// multiplicative vibrato decay applied on negative feedback.
    pub learning_rate: f32,
    /// Sample count at which confidence saturates at 1.0; also the minimum
    /// number of samples before a fine-tuning pass is considered.
    pub min_samples: usize,
    /// NOTE(review): not read by any code in this module — confirm whether a
    /// caller elsewhere uses it.
    pub decay_factor: f32,
    /// When true, `add_feedback` checks after every sample whether a
    /// fine-tuning pass should run for that user.
    pub auto_finetune: bool,
    /// Minimum confidence required for personalised recommendations and for
    /// triggering fine-tuning.
    pub confidence_threshold: f32,
}
61
62impl Default for AdaptiveLearningConfig {
63 fn default() -> Self {
64 Self {
65 learning_rate: 0.01,
66 min_samples: 10,
67 decay_factor: 0.95,
68 auto_finetune: true,
69 confidence_threshold: 0.7,
70 }
71 }
72}
73
/// One piece of user feedback on a synthesised sample.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserFeedback {
    /// Unique id of this feedback record (random UUID assigned by `new`).
    pub id: String,
    /// Id of the user who gave the feedback; used as the key for learned
    /// preferences and for the per-user style adaptation.
    pub user_id: String,
    /// Id of the rated sample. NOTE(review): `new` generates a fresh random
    /// UUID, so it is not linked to any real sample id in this module.
    pub sample_id: String,
    /// Raw audio of the rated sample; excluded from (de)serialisation, so it
    /// is empty after deserialising.
    #[serde(skip)]
    pub audio_data: Vec<f32>,
    /// Overall rating; `new` clamps it into `0.0..=5.0`.
    pub rating: f32,
    /// Per-dimension ratings (defaults to 3.0 for every dimension).
    pub quality_ratings: QualityRatings,
    /// Optional free-text comment.
    pub comments: Option<String>,
    /// Creation time in UTC.
    pub timestamp: DateTime<Utc>,
}
95
96impl UserFeedback {
97 pub fn new(
99 user_id: impl Into<String>,
100 audio_data: Vec<f32>,
101 rating: f32,
102 comments: Option<impl Into<String>>,
103 ) -> Self {
104 Self {
105 id: Uuid::new_v4().to_string(),
106 user_id: user_id.into(),
107 sample_id: Uuid::new_v4().to_string(),
108 audio_data,
109 rating: rating.clamp(0.0, 5.0),
110 quality_ratings: QualityRatings::default(),
111 comments: comments.map(|c| c.into()),
112 timestamp: Utc::now(),
113 }
114 }
115
116 pub fn with_quality_ratings(mut self, ratings: QualityRatings) -> Self {
118 self.quality_ratings = ratings;
119 self
120 }
121}
122
/// Per-dimension quality ratings attached to a piece of feedback.
///
/// The preference update treats 3.0 as neutral: higher values raise the
/// matching weight and lower values reduce it. NOTE(review): presumably these
/// use the same 0–5 scale as `UserFeedback::rating`, but nothing enforces that
/// at construction time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityRatings {
    /// Rating of pitch accuracy.
    pub pitch_accuracy: f32,
    /// Rating of timing precision.
    pub timing_precision: f32,
    /// Rating of naturalness.
    pub naturalness: f32,
    /// Rating of expressiveness.
    pub expression: f32,
    /// Rating of overall voice quality.
    pub voice_quality: f32,
}
137
138impl Default for QualityRatings {
139 fn default() -> Self {
140 Self {
141 pitch_accuracy: 3.0,
142 timing_precision: 3.0,
143 naturalness: 3.0,
144 expression: 3.0,
145 voice_quality: 3.0,
146 }
147 }
148}
149
/// Preferences learned for a single user from their feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPreferences {
    /// Id of the user these preferences belong to.
    pub user_id: String,
    /// Preferred voice characteristics. NOTE(review): initialised to
    /// `VoiceCharacteristics::default()` and not modified anywhere in this
    /// module.
    pub voice_preferences: VoiceCharacteristics,
    /// Named expression parameters that are copied into recommendations.
    /// NOTE(review): nothing in this module populates this map.
    pub expression_preferences: HashMap<String, f32>,
    /// Learned relative importance of each quality dimension.
    pub quality_weights: QualityWeights,
    /// Number of feedback samples folded into these preferences.
    pub sample_count: usize,
    /// `sample_count / min_samples`, capped at 1.0.
    pub confidence: f32,
    /// UTC time of the last update.
    pub last_updated: DateTime<Utc>,
}
168
/// Relative importance of each quality dimension for a user.
///
/// After every update the five weights are rescaled so that their mean is
/// 1.0; the defaults are all 1.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityWeights {
    /// Weight for pitch accuracy.
    pub pitch_weight: f32,
    /// Weight for timing precision.
    pub timing_weight: f32,
    /// Weight for naturalness.
    pub naturalness_weight: f32,
    /// Weight for expressiveness.
    pub expression_weight: f32,
    /// Weight for overall voice quality.
    pub voice_quality_weight: f32,
}
183
184impl Default for QualityWeights {
185 fn default() -> Self {
186 Self {
187 pitch_weight: 1.0,
188 timing_weight: 1.0,
189 naturalness_weight: 1.0,
190 expression_weight: 1.0,
191 voice_quality_weight: 1.0,
192 }
193 }
194}
195
/// Learned singing-style parameters.
///
/// `AdaptiveLearningSystem` keeps one of these per user under the key
/// `"user_{user_id}_style"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StyleAdaptation {
    /// Key of this adaptation (e.g. `"user_user1_style"`).
    pub style_id: String,
    /// Vibrato settings; shrunk on negative feedback.
    pub vibrato_params: VibratoParams,
    /// Dynamics settings (only initialised in this module, never updated).
    pub dynamics_params: DynamicsParams,
    /// Articulation settings (only initialised in this module, never updated).
    pub articulation_params: ArticulationParams,
    /// Number of strongly positive (rating > 4.0) feedback samples seen.
    pub example_count: usize,
    /// `example_count / min_samples`, capped at 1.0.
    pub confidence: f32,
}
212
/// Vibrato parameters of a style adaptation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VibratoParams {
    /// Vibrato rate (initialised to 5.0; presumably Hz — TODO confirm units).
    pub rate: f32,
    /// Vibrato depth (initialised to 0.5; units not established here).
    pub depth: f32,
    /// Delay before vibrato starts (initialised to 0.1; presumably seconds —
    /// TODO confirm).
    pub onset_delay: f32,
}
223
/// Loudness-dynamics parameters of a style adaptation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynamicsParams {
    /// Average level (initialised to 0.7; presumably normalised — confirm).
    pub average_level: f32,
    /// Dynamic range (initialised to 0.4).
    pub dynamic_range: f32,
    /// Crescendo rate (initialised to 0.1).
    pub crescendo_rate: f32,
}
234
/// Articulation parameters of a style adaptation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArticulationParams {
    /// Amount of legato (initialised to 0.7).
    pub legato_amount: f32,
    /// Amount of staccato (initialised to 0.3).
    pub staccato_amount: f32,
    /// Accent strength (initialised to 0.5).
    pub accent_strength: f32,
}
245
/// Synthesis recommendations derived from a user's learned preferences.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizedRecommendations {
    /// The user's preferred voice characteristics.
    pub voice_characteristics: VoiceCharacteristics,
    /// Named parameter values: `vibrato_frequency`, `vibrato_depth`, plus
    /// every entry of the user's expression preferences.
    pub parameter_adjustments: HashMap<String, f32>,
    /// Names of suggested techniques (e.g. `"enhanced_expression"`).
    pub techniques: Vec<String>,
    /// Confidence of the underlying preferences.
    pub confidence: f32,
}
258
/// Record of one fine-tuning pass.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelImprovement {
    /// Zero-based index of this pass in the improvement history.
    pub iteration: usize,
    /// Average rating of the user's feedback minus the neutral rating (3.0).
    pub quality_delta: f32,
    /// `quality_delta` scaled by 0.2.
    pub satisfaction_delta: f32,
    /// Number of feedback samples the pass was computed from.
    pub training_samples: usize,
    /// UTC time the pass was recorded.
    pub timestamp: DateTime<Utc>,
}
273
/// Learns per-user preferences and style adaptations from feedback.
///
/// All mutable state lives behind `Arc<tokio::sync::RwLock<_>>`, so the
/// methods take `&self` and are async.
pub struct AdaptiveLearningSystem {
    /// Tuning parameters.
    config: AdaptiveLearningConfig,
    /// Learned preferences keyed by user id.
    user_preferences: Arc<RwLock<HashMap<String, UserPreferences>>>,
    /// Style adaptations keyed by style id (`"user_{user_id}_style"`).
    style_adaptations: Arc<RwLock<HashMap<String, StyleAdaptation>>>,
    /// Every feedback record received, in arrival order (unbounded).
    feedback_history: Arc<RwLock<Vec<UserFeedback>>>,
    /// Every fine-tuning pass recorded, in order.
    improvement_history: Arc<RwLock<Vec<ModelImprovement>>>,
}
282
283impl AdaptiveLearningSystem {
284 pub fn new(config: AdaptiveLearningConfig) -> Self {
286 Self {
287 config,
288 user_preferences: Arc::new(RwLock::new(HashMap::new())),
289 style_adaptations: Arc::new(RwLock::new(HashMap::new())),
290 feedback_history: Arc::new(RwLock::new(Vec::new())),
291 improvement_history: Arc::new(RwLock::new(Vec::new())),
292 }
293 }
294
295 pub async fn add_feedback(&self, feedback: UserFeedback) -> Result<()> {
297 let user_id = feedback.user_id.clone();
298
299 self.feedback_history.write().await.push(feedback.clone());
301
302 self.update_user_preferences(&user_id, &feedback).await?;
304
305 self.update_style_adaptations(&feedback).await?;
307
308 if self.config.auto_finetune {
310 self.check_and_finetune(&user_id).await?;
311 }
312
313 Ok(())
314 }
315
316 pub async fn get_recommendations(&self, user_id: &str) -> Result<PersonalizedRecommendations> {
318 let prefs = self.user_preferences.read().await;
319
320 let user_pref = prefs.get(user_id).ok_or_else(|| {
321 Error::Processing(format!("No preferences found for user: {}", user_id))
322 })?;
323
324 if user_pref.confidence < self.config.confidence_threshold {
326 return Err(Error::Processing(
327 "Insufficient data for personalized recommendations".to_string(),
328 ));
329 }
330
331 let mut parameter_adjustments = HashMap::new();
333
334 parameter_adjustments.insert(
336 "vibrato_frequency".to_string(),
337 user_pref.voice_preferences.vibrato_frequency,
338 );
339 parameter_adjustments.insert(
340 "vibrato_depth".to_string(),
341 user_pref.voice_preferences.vibrato_depth,
342 );
343
344 for (key, value) in &user_pref.expression_preferences {
346 parameter_adjustments.insert(key.clone(), *value);
347 }
348
349 let mut techniques = Vec::new();
351 if user_pref.quality_weights.expression_weight > 1.2 {
352 techniques.push("enhanced_expression".to_string());
353 }
354 if user_pref.quality_weights.naturalness_weight > 1.2 {
355 techniques.push("natural_breath_patterns".to_string());
356 }
357
358 Ok(PersonalizedRecommendations {
359 voice_characteristics: user_pref.voice_preferences.clone(),
360 parameter_adjustments,
361 techniques,
362 confidence: user_pref.confidence,
363 })
364 }
365
366 pub async fn get_user_preferences(&self, user_id: &str) -> Option<UserPreferences> {
368 self.user_preferences.read().await.get(user_id).cloned()
369 }
370
371 pub async fn get_style_adaptation(&self, style_id: &str) -> Option<StyleAdaptation> {
373 self.style_adaptations.read().await.get(style_id).cloned()
374 }
375
376 pub async fn get_improvement_history(&self) -> Vec<ModelImprovement> {
378 self.improvement_history.read().await.clone()
379 }
380
381 async fn update_user_preferences(&self, user_id: &str, feedback: &UserFeedback) -> Result<()> {
383 let mut prefs = self.user_preferences.write().await;
384
385 let user_pref = prefs
386 .entry(user_id.to_string())
387 .or_insert_with(|| UserPreferences {
388 user_id: user_id.to_string(),
389 voice_preferences: VoiceCharacteristics::default(),
390 expression_preferences: HashMap::new(),
391 quality_weights: QualityWeights::default(),
392 sample_count: 0,
393 confidence: 0.0,
394 last_updated: Utc::now(),
395 });
396
397 user_pref.sample_count += 1;
399
400 let lr = self.config.learning_rate;
402 user_pref.quality_weights.pitch_weight +=
403 lr * (feedback.quality_ratings.pitch_accuracy - 3.0);
404 user_pref.quality_weights.timing_weight +=
405 lr * (feedback.quality_ratings.timing_precision - 3.0);
406 user_pref.quality_weights.naturalness_weight +=
407 lr * (feedback.quality_ratings.naturalness - 3.0);
408 user_pref.quality_weights.expression_weight +=
409 lr * (feedback.quality_ratings.expression - 3.0);
410 user_pref.quality_weights.voice_quality_weight +=
411 lr * (feedback.quality_ratings.voice_quality - 3.0);
412
413 let weight_sum = user_pref.quality_weights.pitch_weight
415 + user_pref.quality_weights.timing_weight
416 + user_pref.quality_weights.naturalness_weight
417 + user_pref.quality_weights.expression_weight
418 + user_pref.quality_weights.voice_quality_weight;
419
420 if weight_sum > 0.0 {
421 user_pref.quality_weights.pitch_weight /= weight_sum / 5.0;
422 user_pref.quality_weights.timing_weight /= weight_sum / 5.0;
423 user_pref.quality_weights.naturalness_weight /= weight_sum / 5.0;
424 user_pref.quality_weights.expression_weight /= weight_sum / 5.0;
425 user_pref.quality_weights.voice_quality_weight /= weight_sum / 5.0;
426 }
427
428 user_pref.confidence =
430 (user_pref.sample_count as f32 / self.config.min_samples as f32).min(1.0);
431
432 user_pref.last_updated = Utc::now();
433
434 Ok(())
435 }
436
437 async fn update_style_adaptations(&self, feedback: &UserFeedback) -> Result<()> {
439 let style_id = format!("user_{}_style", feedback.user_id);
441
442 let mut adaptations = self.style_adaptations.write().await;
443
444 let adaptation = adaptations
445 .entry(style_id.clone())
446 .or_insert_with(|| StyleAdaptation {
447 style_id,
448 vibrato_params: VibratoParams {
449 rate: 5.0,
450 depth: 0.5,
451 onset_delay: 0.1,
452 },
453 dynamics_params: DynamicsParams {
454 average_level: 0.7,
455 dynamic_range: 0.4,
456 crescendo_rate: 0.1,
457 },
458 articulation_params: ArticulationParams {
459 legato_amount: 0.7,
460 staccato_amount: 0.3,
461 accent_strength: 0.5,
462 },
463 example_count: 0,
464 confidence: 0.0,
465 });
466
467 let lr = self.config.learning_rate;
469 if feedback.rating > 4.0 {
470 adaptation.example_count += 1;
472 } else if feedback.rating < 3.0 {
473 adaptation.vibrato_params.rate *= 1.0 - lr;
475 adaptation.vibrato_params.depth *= 1.0 - lr;
476 }
477
478 adaptation.confidence =
480 (adaptation.example_count as f32 / self.config.min_samples as f32).min(1.0);
481
482 Ok(())
483 }
484
485 async fn check_and_finetune(&self, user_id: &str) -> Result<()> {
487 let prefs = self.user_preferences.read().await;
488
489 if let Some(user_pref) = prefs.get(user_id) {
490 if user_pref.sample_count >= self.config.min_samples
491 && user_pref.confidence >= self.config.confidence_threshold
492 {
493 drop(prefs); self.perform_finetuning(user_id).await?;
496 }
497 }
498
499 Ok(())
500 }
501
502 async fn perform_finetuning(&self, user_id: &str) -> Result<()> {
504 let feedback = self.feedback_history.read().await;
506 let user_feedback: Vec<_> = feedback.iter().filter(|f| f.user_id == user_id).collect();
507
508 if user_feedback.is_empty() {
509 return Ok(());
510 }
511
512 let avg_rating: f32 =
514 user_feedback.iter().map(|f| f.rating).sum::<f32>() / user_feedback.len() as f32;
515 let baseline_rating = 3.0;
516 let quality_delta = avg_rating - baseline_rating;
517
518 let improvement = ModelImprovement {
519 iteration: self.improvement_history.read().await.len(),
520 quality_delta,
521 satisfaction_delta: quality_delta * 0.2, training_samples: user_feedback.len(),
523 timestamp: Utc::now(),
524 };
525
526 self.improvement_history.write().await.push(improvement);
527
528 Ok(())
529 }
530
531 pub async fn get_statistics(&self) -> LearningStatistics {
533 let prefs = self.user_preferences.read().await;
534 let feedback = self.feedback_history.read().await;
535 let improvements = self.improvement_history.read().await;
536
537 let avg_rating = if !feedback.is_empty() {
538 feedback.iter().map(|f| f.rating).sum::<f32>() / feedback.len() as f32
539 } else {
540 0.0
541 };
542
543 let total_improvement = improvements.iter().map(|i| i.quality_delta).sum::<f32>();
544
545 LearningStatistics {
546 total_users: prefs.len(),
547 total_feedback: feedback.len(),
548 average_rating: avg_rating,
549 total_improvements: improvements.len(),
550 cumulative_improvement: total_improvement,
551 }
552 }
553}
554
/// Aggregate statistics over the whole learning system.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningStatistics {
    /// Number of users with learned preferences.
    pub total_users: usize,
    /// Number of feedback records received.
    pub total_feedback: usize,
    /// Mean overall rating across all feedback (0.0 when there is none).
    pub average_rating: f32,
    /// Number of fine-tuning passes recorded.
    pub total_improvements: usize,
    /// Sum of `quality_delta` over all recorded passes.
    pub cumulative_improvement: f32,
}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// A single feedback entry is recorded and creates exactly one user profile.
    #[tokio::test]
    async fn test_feedback_collection() {
        let config = AdaptiveLearningConfig::default();
        let system = AdaptiveLearningSystem::new(config);

        let feedback = UserFeedback::new(
            "user1",
            vec![0.0; 1000],
            4.5,
            Some("Great quality".to_string()),
        );

        system.add_feedback(feedback).await.unwrap();

        let stats = system.get_statistics().await;
        assert_eq!(stats.total_feedback, 1);
        assert_eq!(stats.total_users, 1);
    }

    /// Three samples with `min_samples = 2` give a saturated confidence and
    /// are all counted.
    #[tokio::test]
    async fn test_preference_learning() {
        let config = AdaptiveLearningConfig {
            min_samples: 2,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        for i in 0..3 {
            // Ratings 4.0, 4.2, 4.4 with above-neutral pitch and naturalness.
            let mut feedback = UserFeedback::new(
                "user1",
                vec![0.0; 1000],
                4.0 + i as f32 * 0.2,
                None::<String>,
            );
            feedback.quality_ratings.pitch_accuracy = 4.5;
            feedback.quality_ratings.naturalness = 4.0;
            system.add_feedback(feedback).await.unwrap();
        }

        let prefs = system.get_user_preferences("user1").await;
        assert!(prefs.is_some());

        let prefs = prefs.unwrap();
        assert!(prefs.confidence > 0.5);
        assert_eq!(prefs.sample_count, 3);
    }

    /// Once confidence passes the threshold, recommendations are produced and
    /// always include the vibrato parameter adjustments.
    #[tokio::test]
    async fn test_personalized_recommendations() {
        let config = AdaptiveLearningConfig {
            min_samples: 2,
            confidence_threshold: 0.5,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        for _ in 0..3 {
            let feedback = UserFeedback::new("user1", vec![0.0; 1000], 4.5, None::<String>);
            system.add_feedback(feedback).await.unwrap();
        }

        let recommendations = system.get_recommendations("user1").await;
        assert!(recommendations.is_ok());

        let recs = recommendations.unwrap();
        assert!(recs.confidence >= 0.5);
        assert!(!recs.parameter_adjustments.is_empty());
    }

    /// Feedback creates a style adaptation under the `user_{id}_style` key.
    #[tokio::test]
    async fn test_style_adaptation() {
        let config = AdaptiveLearningConfig::default();
        let system = AdaptiveLearningSystem::new(config);

        let feedback = UserFeedback::new(
            "user1",
            vec![0.0; 1000],
            4.8,
            Some("Excellent style".to_string()),
        );

        system.add_feedback(feedback).await.unwrap();

        let style_id = "user_user1_style";
        let adaptation = system.get_style_adaptation(style_id).await;
        assert!(adaptation.is_some());
    }

    /// An above-neutral rating raises its weight; a below-neutral rating
    /// lowers its weight, even after renormalisation.
    #[tokio::test]
    async fn test_quality_weight_updates() {
        let config = AdaptiveLearningConfig {
            learning_rate: 0.1,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        let mut feedback = UserFeedback::new("user1", vec![0.0; 1000], 4.0, None::<String>);
        feedback.quality_ratings.pitch_accuracy = 5.0;
        feedback.quality_ratings.timing_precision = 2.0;

        system.add_feedback(feedback).await.unwrap();

        let prefs = system.get_user_preferences("user1").await.unwrap();

        assert!(prefs.quality_weights.pitch_weight > 1.0);
        assert!(prefs.quality_weights.timing_weight < 1.0);
    }

    /// With auto fine-tuning on, reaching `min_samples` records at least one
    /// improvement pass.
    #[tokio::test]
    async fn test_model_improvement_tracking() {
        let config = AdaptiveLearningConfig {
            min_samples: 2,
            auto_finetune: true,
            confidence_threshold: 0.5,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        for _ in 0..3 {
            let feedback = UserFeedback::new("user1", vec![0.0; 1000], 4.5, None::<String>);
            system.add_feedback(feedback).await.unwrap();
        }

        let history = system.get_improvement_history().await;
        assert!(!history.is_empty());
    }

    /// Confidence is `sample_count / min_samples`: 5 of 10 gives 0.5.
    #[tokio::test]
    async fn test_confidence_calculation() {
        let config = AdaptiveLearningConfig {
            min_samples: 10,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        for _ in 0..5 {
            let feedback = UserFeedback::new("user1", vec![0.0; 1000], 4.0, None::<String>);
            system.add_feedback(feedback).await.unwrap();
        }

        let prefs = system.get_user_preferences("user1").await.unwrap();
        assert!((prefs.confidence - 0.5).abs() < 0.01);
    }

    /// Statistics aggregate across users: 3 users × 2 samples, all rated 4.0.
    #[tokio::test]
    async fn test_learning_statistics() {
        let config = AdaptiveLearningConfig::default();
        let system = AdaptiveLearningSystem::new(config);

        for user_id in &["user1", "user2", "user3"] {
            for _ in 0..2 {
                let feedback = UserFeedback::new(*user_id, vec![0.0; 1000], 4.0, None::<String>);
                system.add_feedback(feedback).await.unwrap();
            }
        }

        let stats = system.get_statistics().await;
        assert_eq!(stats.total_users, 3);
        assert_eq!(stats.total_feedback, 6);
        assert!((stats.average_rating - 4.0).abs() < 0.01);
    }

    /// One sample against `min_samples = 10` gives confidence 0.1, which is
    /// below the 0.8 threshold, so recommendations are refused.
    #[tokio::test]
    async fn test_insufficient_data_error() {
        let config = AdaptiveLearningConfig {
            min_samples: 10,
            confidence_threshold: 0.8,
            ..Default::default()
        };
        let system = AdaptiveLearningSystem::new(config);

        let feedback = UserFeedback::new("user1", vec![0.0; 1000], 4.0, None::<String>);
        system.add_feedback(feedback).await.unwrap();

        let result = system.get_recommendations("user1").await;
        assert!(result.is_err());
    }
}