use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use voirs_sdk::types::SynthesisConfig;

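/// A kind of input channel that can contribute to multimodal synthesis.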
#[derive(Debug, Clone, Serialize, Deserialize, Hash, Eq, PartialEq)]
pub enum ModalityType {
    Text,
    Audio,
    Visual,
    Gesture,
    Facial,
    Prosody,
    Contextual,
}

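/// Per-modality settings: blend weight, synchronization offset, and an
/// optional fixed duration.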
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityConfig {
    pub modality_type: ModalityType,
    pub weight: f32,
    pub synchronization_offset_ms: i32,
    pub duration_ms: Option<u32>,
    pub adaptive_weighting: bool,
}

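/// Pairwise alignment scores between modalities, plus the threshold below
/// which a pair is considered out of alignment.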
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossModalAlignment {
    pub text_audio_alignment: f32,
    pub visual_audio_alignment: f32,
    pub gesture_audio_alignment: f32,
    pub prosody_text_alignment: f32,
    pub alignment_threshold: f32,
}

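/// Parameters governing how modality weights are adjusted in response to
/// alignment scores.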
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveWeighting {
    pub confidence_threshold: f32,
    pub weight_adjustment_factor: f32,
    pub minimum_weight: f32,
    pub maximum_weight: f32,
    pub adaptation_speed: f32,
}

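/// Complete configuration for a multimodal synthesis request, layered on top
/// of the base `SynthesisConfig`.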
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalSynthesisConfig {
    pub base_config: SynthesisConfig,
    pub modality_configs: Vec<ModalityConfig>,
    pub cross_modal_alignment: CrossModalAlignment,
    pub adaptive_weighting: Option<AdaptiveWeighting>,
    pub synchronization_tolerance_ms: u32,
    pub fallback_modality: ModalityType,
}

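/// A timestamped, confidence-scored payload for a single modality.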
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModalityData {
    pub modality_type: ModalityType,
    pub data: Vec<u8>,
    pub timestamp_ms: u64,
    pub confidence: f32,
    pub metadata: HashMap<String, String>,
}

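/// Buffers per-modality data streams, tracks their blend weights, and derives
/// cross-modal alignment scores for synthesis.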
pub struct MultimodalSynthesizer {
    current_modalities: HashMap<ModalityType, ModalityData>,
    modality_weights: HashMap<ModalityType, f32>,
    alignment_history: Vec<CrossModalAlignment>,
    adaptive_config: Option<AdaptiveWeighting>,
}

impl MultimodalSynthesizer {
    pub fn new() -> Self {
        Self {
            current_modalities: HashMap::new(),
            modality_weights: HashMap::new(),
            alignment_history: Vec::new(),
            adaptive_config: None,
        }
    }

    pub fn with_adaptive_weighting(mut self, config: AdaptiveWeighting) -> Self {
        self.adaptive_config = Some(config);
        self
    }

    pub fn add_modality_data(&mut self, modality_data: ModalityData) {
        let modality_type = modality_data.modality_type.clone();
        self.current_modalities.insert(modality_type, modality_data);
    }

    pub fn update_modality_weights(&mut self, weights: HashMap<ModalityType, f32>) {
        for (modality, weight) in weights {
            let clamped_weight = weight.clamp(0.0, 1.0);
            self.modality_weights.insert(modality, clamped_weight);
        }
    }

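    /// Scores how well the currently buffered modalities line up with one
    /// another, using audio (or text, for prosody) as the reference.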
    pub fn calculate_cross_modal_alignment(&self) -> CrossModalAlignment {
        let text_audio = self.calculate_alignment_score(&ModalityType::Text, &ModalityType::Audio);
        let visual_audio =
            self.calculate_alignment_score(&ModalityType::Visual, &ModalityType::Audio);
        let gesture_audio =
            self.calculate_alignment_score(&ModalityType::Gesture, &ModalityType::Audio);
        let prosody_text =
            self.calculate_alignment_score(&ModalityType::Prosody, &ModalityType::Text);

        CrossModalAlignment {
            text_audio_alignment: text_audio,
            visual_audio_alignment: visual_audio,
            gesture_audio_alignment: gesture_audio,
            prosody_text_alignment: prosody_text,
            alignment_threshold: 0.7,
        }
    }

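    /// Alignment between two modalities: the product of their confidences,
    /// discounted linearly by timestamp distance (zero at one second apart or
    /// more). Returns 0.0 if either modality has no buffered data.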
    fn calculate_alignment_score(&self, modality1: &ModalityType, modality2: &ModalityType) -> f32 {
        let data1 = self.current_modalities.get(modality1);
        let data2 = self.current_modalities.get(modality2);

        match (data1, data2) {
            (Some(d1), Some(d2)) => {
                let time_diff = (d1.timestamp_ms as i64 - d2.timestamp_ms as i64).abs();
                let confidence_product = d1.confidence * d2.confidence;
                // Full credit at zero offset, falling to zero at >= 1000 ms apart.
                let time_score = 1.0 - (time_diff as f32 / 1000.0).min(1.0);

                time_score * confidence_product
            }
            _ => 0.0,
        }
    }

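    /// Nudges each modality's weight up when its alignment score clears the
    /// confidence threshold and down when it does not, keeping the result
    /// within the configured weight bounds.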
    pub fn apply_adaptive_weighting(&mut self) {
        if let Some(config) = &self.adaptive_config {
            let alignment = self.calculate_cross_modal_alignment();

            for (modality_type, current_weight) in self.modality_weights.iter_mut() {
                let alignment_score = match modality_type {
                    ModalityType::Text => alignment.text_audio_alignment,
                    // Audio is the reference modality, so it is always fully aligned.
                    ModalityType::Audio => 1.0,
                    ModalityType::Visual => alignment.visual_audio_alignment,
                    ModalityType::Gesture => alignment.gesture_audio_alignment,
                    ModalityType::Prosody => alignment.prosody_text_alignment,
                    // Facial and Contextual have no dedicated alignment score yet.
                    _ => 0.5,
                };

                if alignment_score > config.confidence_threshold {
                    *current_weight = (*current_weight
                        + config.weight_adjustment_factor * alignment_score)
                        .clamp(config.minimum_weight, config.maximum_weight);
                } else {
                    *current_weight = (*current_weight
                        - config.weight_adjustment_factor * (1.0 - alignment_score))
                        .clamp(config.minimum_weight, config.maximum_weight);
                }
            }
        }
    }

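    /// Builds a `MultimodalSynthesisConfig` snapshot from the current
    /// alignment state and the supplied per-modality configs.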
    pub fn create_multimodal_synthesis_config(
        &self,
        base_config: SynthesisConfig,
        modality_configs: Vec<ModalityConfig>,
    ) -> MultimodalSynthesisConfig {
        let alignment = self.calculate_cross_modal_alignment();

        MultimodalSynthesisConfig {
            base_config,
            modality_configs,
            cross_modal_alignment: alignment,
            adaptive_weighting: self.adaptive_config.clone(),
            synchronization_tolerance_ms: 100,
            fallback_modality: ModalityType::Text,
        }
    }

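    /// Snaps any modality whose timestamp drifts more than the tolerance from
    /// `target_timestamp_ms` onto the target.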
    pub fn synchronize_modalities(&mut self, target_timestamp_ms: u64) {
        let tolerance_ms = 100;

        for modality_data in self.current_modalities.values_mut() {
            let time_diff = (modality_data.timestamp_ms as i64 - target_timestamp_ms as i64).abs();

            // Out-of-tolerance modalities are snapped to the target timestamp.
            if time_diff > tolerance_ms {
                modality_data.timestamp_ms = target_timestamp_ms;
            }
        }
    }

    pub fn get_active_modalities(&self) -> Vec<ModalityType> {
        self.current_modalities.keys().cloned().collect()
    }

    pub fn get_modality_weight(&self, modality: &ModalityType) -> f32 {
        self.modality_weights.get(modality).copied().unwrap_or(1.0)
    }

    pub fn clear_modality_data(&mut self) {
        self.current_modalities.clear();
    }
}

impl Default for MultimodalSynthesizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Default for ModalityConfig {
    fn default() -> Self {
        Self {
            modality_type: ModalityType::Text,
            weight: 1.0,
            synchronization_offset_ms: 0,
            duration_ms: None,
            adaptive_weighting: false,
        }
    }
}

impl Default for CrossModalAlignment {
    fn default() -> Self {
        Self {
            text_audio_alignment: 0.8,
            visual_audio_alignment: 0.7,
            gesture_audio_alignment: 0.6,
            prosody_text_alignment: 0.9,
            alignment_threshold: 0.7,
        }
    }
}

impl Default for AdaptiveWeighting {
    fn default() -> Self {
        Self {
            confidence_threshold: 0.7,
            weight_adjustment_factor: 0.1,
            minimum_weight: 0.1,
            maximum_weight: 1.0,
            adaptation_speed: 0.05,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multimodal_synthesizer_creation() {
        let synthesizer = MultimodalSynthesizer::new();
        assert!(synthesizer.current_modalities.is_empty());
        assert!(synthesizer.modality_weights.is_empty());
    }

    #[test]
    fn test_modality_data_addition() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello world".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        assert_eq!(synthesizer.current_modalities.len(), 1);
        assert!(synthesizer
            .current_modalities
            .contains_key(&ModalityType::Text));
    }

    #[test]
    fn test_modality_weight_update() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let mut weights = HashMap::new();
        weights.insert(ModalityType::Text, 0.8);
        weights.insert(ModalityType::Audio, 1.0);

        synthesizer.update_modality_weights(weights);
        assert_eq!(synthesizer.get_modality_weight(&ModalityType::Text), 0.8);
        assert_eq!(synthesizer.get_modality_weight(&ModalityType::Audio), 1.0);
    }

    #[test]
    fn test_cross_modal_alignment() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        let audio_data = ModalityData {
            modality_type: ModalityType::Audio,
            data: vec![0u8; 1024],
            timestamp_ms: 1050,
            confidence: 0.8,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        synthesizer.add_modality_data(audio_data);

        let alignment = synthesizer.calculate_cross_modal_alignment();
        assert!(alignment.text_audio_alignment > 0.0);
        assert!(alignment.alignment_threshold > 0.0);
    }

    #[test]
    fn test_synchronization() {
        let mut synthesizer = MultimodalSynthesizer::new();

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.9,
            metadata: HashMap::new(),
        };

        synthesizer.add_modality_data(text_data);
        synthesizer.synchronize_modalities(1500);

        let text_modality = synthesizer
            .current_modalities
            .get(&ModalityType::Text)
            .unwrap();
        assert_eq!(text_modality.timestamp_ms, 1500);
    }

    #[test]
    fn test_config_serialization() {
        let config = ModalityConfig {
            modality_type: ModalityType::Audio,
            weight: 0.8,
            synchronization_offset_ms: 100,
            duration_ms: Some(5000),
            adaptive_weighting: true,
        };

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: ModalityConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.weight, 0.8);
        assert_eq!(deserialized.synchronization_offset_ms, 100);
        assert!(deserialized.adaptive_weighting);
    }
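
    // Sketch test for adaptive weighting (the starting weight, timestamps,
    // and confidences below are illustrative): well-aligned, high-confidence
    // text and audio should pull the text weight upward.
    #[test]
    fn test_adaptive_weighting_increases_aligned_weight() {
        let mut synthesizer =
            MultimodalSynthesizer::new().with_adaptive_weighting(AdaptiveWeighting::default());

        let text_data = ModalityData {
            modality_type: ModalityType::Text,
            data: b"Hello".to_vec(),
            timestamp_ms: 1000,
            confidence: 0.95,
            metadata: HashMap::new(),
        };
        let audio_data = ModalityData {
            modality_type: ModalityType::Audio,
            data: vec![0u8; 1024],
            timestamp_ms: 1000,
            confidence: 0.95,
            metadata: HashMap::new(),
        };
        synthesizer.add_modality_data(text_data);
        synthesizer.add_modality_data(audio_data);

        let mut weights = HashMap::new();
        weights.insert(ModalityType::Text, 0.5);
        synthesizer.update_modality_weights(weights);

        synthesizer.apply_adaptive_weighting();

        // Identical timestamps give a time score of 1.0, so the alignment
        // score is 0.95 * 0.95 = 0.9025 > 0.7, and the weight is increased.
        assert!(synthesizer.get_modality_weight(&ModalityType::Text) > 0.5);
    }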
}