
voirs_cli/synthesis/multimodal.rs
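
//! Multimodal synthesis support: tracks the most recent data sample per
//! modality (text, audio, visual, gesture, facial, prosody, contextual),
//! scores how well pairs of modalities line up in time, and optionally
//! adapts per-modality weights from those alignment scores.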

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use voirs_sdk::types::SynthesisConfig;

#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ModalityType {
    Text,
    Audio,
    Visual,
    Gesture,
    Facial,
    Prosody,
    Contextual,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityConfig {
    pub modality_type: ModalityType,
    pub weight: f32, // 0.0 to 1.0
    pub synchronization_offset_ms: i32,
    pub duration_ms: Option<u32>,
    pub adaptive_weighting: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalAlignment {
    pub text_audio_alignment: f32,
    pub visual_audio_alignment: f32,
    pub gesture_audio_alignment: f32,
    pub prosody_text_alignment: f32,
    pub alignment_threshold: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveWeighting {
    pub confidence_threshold: f32,
    pub weight_adjustment_factor: f32,
    pub minimum_weight: f32,
    pub maximum_weight: f32,
    pub adaptation_speed: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalSynthesisConfig {
    pub base_config: SynthesisConfig,
    pub modality_configs: Vec<ModalityConfig>,
    pub cross_modal_alignment: CrossModalAlignment,
    pub adaptive_weighting: Option<AdaptiveWeighting>,
    pub synchronization_tolerance_ms: u32,
    pub fallback_modality: ModalityType,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityData {
    pub modality_type: ModalityType,
    pub data: Vec<u8>,
    pub timestamp_ms: u64,
    pub confidence: f32,
    pub metadata: HashMap<String, String>,
}

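/// Aggregates per-modality input for a synthesis pass: the most recent
/// `ModalityData` per modality, the current per-modality weights, a history
/// of computed alignments, and an optional adaptive-weighting configuration.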
pub struct MultimodalSynthesizer {
    current_modalities: HashMap<ModalityType, ModalityData>,
    modality_weights: HashMap<ModalityType, f32>,
    alignment_history: Vec<CrossModalAlignment>,
    adaptive_config: Option<AdaptiveWeighting>,
}

impl MultimodalSynthesizer {
    pub fn new() -> Self {
        Self {
            current_modalities: HashMap::new(),
            modality_weights: HashMap::new(),
            alignment_history: Vec::new(),
            adaptive_config: None,
        }
    }

    pub fn with_adaptive_weighting(mut self, config: AdaptiveWeighting) -> Self {
        self.adaptive_config = Some(config);
        self
    }

    pub fn add_modality_data(&mut self, modality_data: ModalityData) {
        let modality_type = modality_data.modality_type.clone();
        self.current_modalities.insert(modality_type, modality_data);
    }

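    /// Replaces the stored weight for each listed modality, clamping each
    /// value to the 0.0 to 1.0 range documented on `ModalityConfig::weight`.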
    pub fn update_modality_weights(&mut self, weights: HashMap<ModalityType, f32>) {
        for (modality, weight) in weights {
            let clamped_weight = weight.clamp(0.0, 1.0);
            self.modality_weights.insert(modality, clamped_weight);
        }
    }

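    /// Scores the standard modality pairs (text/audio, visual/audio,
    /// gesture/audio, prosody/text) against the currently loaded data.
    /// The 0.7 threshold matches the `CrossModalAlignment` default.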
    pub fn calculate_cross_modal_alignment(&self) -> CrossModalAlignment {
        let text_audio = self.calculate_alignment_score(&ModalityType::Text, &ModalityType::Audio);
        let visual_audio =
            self.calculate_alignment_score(&ModalityType::Visual, &ModalityType::Audio);
        let gesture_audio =
            self.calculate_alignment_score(&ModalityType::Gesture, &ModalityType::Audio);
        let prosody_text =
            self.calculate_alignment_score(&ModalityType::Prosody, &ModalityType::Text);

        CrossModalAlignment {
            text_audio_alignment: text_audio,
            visual_audio_alignment: visual_audio,
            gesture_audio_alignment: gesture_audio,
            prosody_text_alignment: prosody_text,
            alignment_threshold: 0.7,
        }
    }

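    /// Combines temporal proximity with confidence: the timestamp gap maps
    /// linearly onto [0, 1] (reaching zero at 1000 ms or more apart) and is
    /// scaled by the product of the two confidences. For example, samples
    /// 50 ms apart with confidences 0.9 and 0.8 score
    /// (1.0 - 50/1000) * 0.9 * 0.8 = 0.684. Returns 0.0 when either
    /// modality has no data loaded.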
    fn calculate_alignment_score(&self, modality1: &ModalityType, modality2: &ModalityType) -> f32 {
        let data1 = self.current_modalities.get(modality1);
        let data2 = self.current_modalities.get(modality2);

        match (data1, data2) {
            (Some(d1), Some(d2)) => {
                let time_diff = (d1.timestamp_ms as i64 - d2.timestamp_ms as i64).abs();
                let confidence_product = d1.confidence * d2.confidence;
                let time_score = 1.0 - (time_diff as f32 / 1000.0).min(1.0);

                time_score * confidence_product
            }
            _ => 0.0,
        }
    }

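    /// Nudges each tracked weight toward its alignment score: a score above
    /// `confidence_threshold` raises the weight by
    /// `weight_adjustment_factor * score`; otherwise the weight drops by
    /// `weight_adjustment_factor * (1.0 - score)`. Results are clamped to
    /// [`minimum_weight`, `maximum_weight`]. Does nothing when no adaptive
    /// configuration has been set.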
    pub fn apply_adaptive_weighting(&mut self) {
        if let Some(config) = &self.adaptive_config {
            let alignment = self.calculate_cross_modal_alignment();

            for (modality_type, current_weight) in self.modality_weights.iter_mut() {
                let alignment_score = match modality_type {
                    ModalityType::Text => alignment.text_audio_alignment,
                    ModalityType::Audio => 1.0, // Audio is the reference modality
                    ModalityType::Visual => alignment.visual_audio_alignment,
                    ModalityType::Gesture => alignment.gesture_audio_alignment,
                    ModalityType::Prosody => alignment.prosody_text_alignment,
                    _ => 0.5,
                };

                if alignment_score > config.confidence_threshold {
                    *current_weight = (*current_weight
                        + config.weight_adjustment_factor * alignment_score)
                        .clamp(config.minimum_weight, config.maximum_weight);
                } else {
                    *current_weight = (*current_weight
                        - config.weight_adjustment_factor * (1.0 - alignment_score))
                        .clamp(config.minimum_weight, config.maximum_weight);
                }
            }

            // Record the alignment snapshot that drove this weight update.
            self.alignment_history.push(alignment);
        }
    }

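    /// Bundles the base synthesis config and per-modality configs with a
    /// snapshot of the current cross-modal alignment, a 100 ms
    /// synchronization tolerance, and text as the fallback modality.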
    pub fn create_multimodal_synthesis_config(
        &self,
        base_config: SynthesisConfig,
        modality_configs: Vec<ModalityConfig>,
    ) -> MultimodalSynthesisConfig {
        let alignment = self.calculate_cross_modal_alignment();

        MultimodalSynthesisConfig {
            base_config,
            modality_configs,
            cross_modal_alignment: alignment,
            adaptive_weighting: self.adaptive_config.clone(),
            synchronization_tolerance_ms: 100,
            fallback_modality: ModalityType::Text,
        }
    }

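    /// Snaps every modality whose timestamp drifts more than 100 ms from the
    /// target back onto the target timestamp. The tolerance mirrors the
    /// default `synchronization_tolerance_ms` used above.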
    pub fn synchronize_modalities(&mut self, target_timestamp_ms: u64) {
        let tolerance_ms: i64 = 100;

        for modality_data in self.current_modalities.values_mut() {
            let time_diff = (modality_data.timestamp_ms as i64 - target_timestamp_ms as i64).abs();

            if time_diff > tolerance_ms {
                modality_data.timestamp_ms = target_timestamp_ms;
            }
        }
    }

    pub fn get_active_modalities(&self) -> Vec<ModalityType> {
        self.current_modalities.keys().cloned().collect()
    }

    pub fn get_modality_weight(&self, modality: &ModalityType) -> f32 {
        self.modality_weights.get(modality).copied().unwrap_or(1.0)
    }

    pub fn clear_modality_data(&mut self) {
        self.current_modalities.clear();
    }
}

impl Default for MultimodalSynthesizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for ModalityConfig {
    fn default() -> Self {
        Self {
            modality_type: ModalityType::Text,
            weight: 1.0,
            synchronization_offset_ms: 0,
            duration_ms: None,
            adaptive_weighting: false,
        }
    }
}

impl Default for CrossModalAlignment {
    fn default() -> Self {
        Self {
            text_audio_alignment: 0.8,
            visual_audio_alignment: 0.7,
            gesture_audio_alignment: 0.6,
            prosody_text_alignment: 0.9,
            alignment_threshold: 0.7,
        }
    }
}

impl Default for AdaptiveWeighting {
    fn default() -> Self {
        Self {
            confidence_threshold: 0.7,
            weight_adjustment_factor: 0.1,
            minimum_weight: 0.1,
            maximum_weight: 1.0,
            adaptation_speed: 0.05,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multimodal_synthesizer_creation() {
        let synthesizer = MultimodalSynthesizer::new();
        assert!(synthesizer.current_modalities.is_empty());
        assert!(synthesizer.modality_weights.is_empty());
    }

    #[test]
    fn test_modality_data_addition() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello world".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        assert_eq!(synthesizer.current_modalities.len(), 1);
        assert!(synthesizer
            .current_modalities
            .contains_key(&ModalityType::Text));
    }

    #[test]
    fn test_modality_weight_update() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let mut weights = HashMap::new();
        weights.insert(ModalityType::Text, 0.8);
        weights.insert(ModalityType::Audio, 1.0);

        synthesizer.update_modality_weights(weights);
        assert_eq!(synthesizer.get_modality_weight(&ModalityType::Text), 0.8);
        assert_eq!(synthesizer.get_modality_weight(&ModalityType::Audio), 1.0);
    }

    #[test]
    fn test_cross_modal_alignment() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        let audio_data = ModalityData {
            modality_type: ModalityType::Audio,
            data: vec![0u8; 1024],
            timestamp_ms: 1050,
            confidence: 0.8,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        synthesizer.add_modality_data(audio_data);

        let alignment = synthesizer.calculate_cross_modal_alignment();
        assert!(alignment.text_audio_alignment > 0.0);
        assert!(alignment.alignment_threshold > 0.0);
    }

    #[test]
    fn test_synchronization() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        synthesizer.synchronize_modalities(1500);

        let text_modality = synthesizer
            .current_modalities
            .get(&ModalityType::Text)
            .unwrap();
        assert_eq!(text_modality.timestamp_ms, 1500);
    }

    #[test]
    fn test_config_serialization() {
        let config = ModalityConfig {
            modality_type: ModalityType::Audio,
            weight: 0.8,
            synchronization_offset_ms: 100,
            duration_ms: Some(5000),
            adaptive_weighting: true,
        };

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: ModalityConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.weight, 0.8);
        assert_eq!(deserialized.synchronization_offset_ms, 100);
        assert!(deserialized.adaptive_weighting);
    }
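
    // Illustrative end-to-end sketch of the adaptive-weighting path: with
    // perfectly time-aligned, high-confidence text and audio samples, the
    // default `AdaptiveWeighting` settings should raise the text weight.
    #[test]
    fn test_adaptive_weighting_raises_aligned_weight() {
        let mut synthesizer =
            MultimodalSynthesizer::new().with_adaptive_weighting(AdaptiveWeighting::default());

        for (modality_type, data) in [
            (ModalityType::Text, b"Hello".to_vec()),
            (ModalityType::Audio, vec![0u8; 1024]),
        ] {
            synthesizer.add_modality_data(ModalityData {
                modality_type,
                data,
                timestamp_ms: 1000,
                confidence: 0.95,
                metadata: HashMap::new(),
            });
        }

        let mut weights = HashMap::new();
        weights.insert(ModalityType::Text, 0.5);
        synthesizer.update_modality_weights(weights);

        // Text/audio alignment = (1.0 - 0/1000) * 0.95 * 0.95 ≈ 0.90, above
        // the 0.7 threshold, so the weight rises by about 0.1 * 0.90 ≈ 0.09.
        synthesizer.apply_adaptive_weighting();
        assert!(synthesizer.get_modality_weight(&ModalityType::Text) > 0.5);
    }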
}