1use super::types_config::{
8 CrossModalStrategy, FusionStrategy, InterpolationMethod, Modality, SyncMethod,
9 TransformParameter,
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::{Duration, SystemTime};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MultiModalConfig {
18 pub enabled: bool,
20 pub modalities: Vec<Modality>,
22 pub fusion_strategy: FusionStrategy,
24 pub temporal_alignment: TemporalAlignmentConfig,
26 pub cross_modal_learning: CrossModalLearningConfig,
28 pub modality_weights: HashMap<Modality, f64>,
30 pub sync_requirements: SynchronizationRequirements,
32}
33
34impl Default for MultiModalConfig {
35 fn default() -> Self {
36 Self {
37 enabled: false,
38 modalities: vec![Modality::Visual],
39 fusion_strategy: FusionStrategy::LateFusion,
40 temporal_alignment: TemporalAlignmentConfig::default(),
41 cross_modal_learning: CrossModalLearningConfig::default(),
42 modality_weights: HashMap::new(),
43 sync_requirements: SynchronizationRequirements::default(),
44 }
45 }
46}
47
48impl MultiModalConfig {
49 #[must_use]
51 pub fn vision_audio() -> Self {
52 let mut weights = HashMap::new();
53 weights.insert(Modality::Visual, 0.7);
54 weights.insert(Modality::Audio, 0.3);
55
56 Self {
57 enabled: true,
58 modalities: vec![Modality::Visual, Modality::Audio],
59 fusion_strategy: FusionStrategy::EarlyFusion,
60 temporal_alignment: TemporalAlignmentConfig::strict(),
61 cross_modal_learning: CrossModalLearningConfig::contrastive(),
62 modality_weights: weights,
63 sync_requirements: SynchronizationRequirements::hardware(),
64 }
65 }
66
67 #[must_use]
69 pub fn vision_depth() -> Self {
70 let mut weights = HashMap::new();
71 weights.insert(Modality::Visual, 0.6);
72 weights.insert(Modality::Depth, 0.4);
73
74 Self {
75 enabled: true,
76 modalities: vec![Modality::Visual, Modality::Depth],
77 fusion_strategy: FusionStrategy::HybridFusion,
78 temporal_alignment: TemporalAlignmentConfig::relaxed(),
79 cross_modal_learning: CrossModalLearningConfig::shared_representation(),
80 modality_weights: weights,
81 sync_requirements: SynchronizationRequirements::software(),
82 }
83 }
84
85 #[must_use]
87 pub fn multi_sensor() -> Self {
88 let mut weights = HashMap::new();
89 weights.insert(Modality::Visual, 0.4);
90 weights.insert(Modality::LiDAR, 0.35);
91 weights.insert(Modality::Radar, 0.25);
92
93 Self {
94 enabled: true,
95 modalities: vec![Modality::Visual, Modality::LiDAR, Modality::Radar],
96 fusion_strategy: FusionStrategy::AttentionFusion,
97 temporal_alignment: TemporalAlignmentConfig::precise(),
98 cross_modal_learning: CrossModalLearningConfig::alignment(),
99 modality_weights: weights,
100 sync_requirements: SynchronizationRequirements::gps(),
101 }
102 }
103
104 #[must_use]
106 pub fn get_modality_weight(&self, modality: &Modality) -> f64 {
107 self.modality_weights.get(modality).copied().unwrap_or(1.0)
108 }
109
110 pub fn set_modality_weight(&mut self, modality: Modality, weight: f64) {
112 self.modality_weights.insert(modality, weight);
113 }
114
115 pub fn validate(&self) -> Result<(), MultiModalError> {
117 if self.enabled && self.modalities.is_empty() {
118 return Err(MultiModalError::NoModalities);
119 }
120
121 let total_weight: f64 = self.modality_weights.values().sum();
123 if (total_weight - 1.0).abs() > 0.1 {
124 return Err(MultiModalError::InvalidWeights {
125 total: total_weight,
126 });
127 }
128
129 Ok(())
130 }
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct TemporalAlignmentConfig {
136 pub enabled: bool,
138 pub sync_method: SyncMethod,
140 pub max_time_offset: Duration,
142 pub interpolation: InterpolationMethod,
144 pub buffer_size: usize,
146 pub alignment_tolerance: Duration,
148 pub predictive_alignment: bool,
150}
151
152impl Default for TemporalAlignmentConfig {
153 fn default() -> Self {
154 Self {
155 enabled: true,
156 sync_method: SyncMethod::Software,
157 max_time_offset: Duration::from_millis(100),
158 interpolation: InterpolationMethod::Linear,
159 buffer_size: 10,
160 alignment_tolerance: Duration::from_millis(50),
161 predictive_alignment: false,
162 }
163 }
164}
165
166impl TemporalAlignmentConfig {
167 #[must_use]
169 pub fn strict() -> Self {
170 Self {
171 enabled: true,
172 sync_method: SyncMethod::Hardware,
173 max_time_offset: Duration::from_millis(10),
174 interpolation: InterpolationMethod::Cubic,
175 buffer_size: 20,
176 alignment_tolerance: Duration::from_millis(5),
177 predictive_alignment: true,
178 }
179 }
180
181 #[must_use]
183 pub fn relaxed() -> Self {
184 Self {
185 enabled: true,
186 sync_method: SyncMethod::Software,
187 max_time_offset: Duration::from_millis(500),
188 interpolation: InterpolationMethod::Nearest,
189 buffer_size: 5,
190 alignment_tolerance: Duration::from_millis(200),
191 predictive_alignment: false,
192 }
193 }
194
195 #[must_use]
197 pub fn precise() -> Self {
198 Self {
199 enabled: true,
200 sync_method: SyncMethod::GPS,
201 max_time_offset: Duration::from_micros(100),
202 interpolation: InterpolationMethod::Spline,
203 buffer_size: 50,
204 alignment_tolerance: Duration::from_micros(50),
205 predictive_alignment: true,
206 }
207 }
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct CrossModalLearningConfig {
213 pub enabled: bool,
215 pub strategy: CrossModalStrategy,
217 pub contrastive_learning: ContrastiveLearningConfig,
219 pub distillation: DistillationConfig,
221 pub alignment_learning: AlignmentLearningConfig,
223 pub shared_representation_dim: usize,
225}
226
227impl Default for CrossModalLearningConfig {
228 fn default() -> Self {
229 Self {
230 enabled: false,
231 strategy: CrossModalStrategy::SharedRepresentation,
232 contrastive_learning: ContrastiveLearningConfig::default(),
233 distillation: DistillationConfig::default(),
234 alignment_learning: AlignmentLearningConfig::default(),
235 shared_representation_dim: 512,
236 }
237 }
238}
239
240impl CrossModalLearningConfig {
241 #[must_use]
243 pub fn contrastive() -> Self {
244 Self {
245 enabled: true,
246 strategy: CrossModalStrategy::Contrastive,
247 contrastive_learning: ContrastiveLearningConfig::strong(),
248 distillation: DistillationConfig::default(),
249 alignment_learning: AlignmentLearningConfig::default(),
250 shared_representation_dim: 256,
251 }
252 }
253
254 #[must_use]
256 pub fn shared_representation() -> Self {
257 Self {
258 enabled: true,
259 strategy: CrossModalStrategy::SharedRepresentation,
260 contrastive_learning: ContrastiveLearningConfig::default(),
261 distillation: DistillationConfig::default(),
262 alignment_learning: AlignmentLearningConfig::default(),
263 shared_representation_dim: 1024,
264 }
265 }
266
267 #[must_use]
269 pub fn alignment() -> Self {
270 Self {
271 enabled: true,
272 strategy: CrossModalStrategy::Alignment,
273 contrastive_learning: ContrastiveLearningConfig::default(),
274 distillation: DistillationConfig::default(),
275 alignment_learning: AlignmentLearningConfig::canonical_correlation(),
276 shared_representation_dim: 512,
277 }
278 }
279}
280
281#[derive(Debug, Clone, Serialize, Deserialize)]
283pub struct ContrastiveLearningConfig {
284 pub enabled: bool,
286 pub temperature: f64,
288 pub negative_samples: usize,
290 pub hard_negative_mining: bool,
292 pub momentum: f64,
294 pub queue_size: usize,
296 pub projection_dim: usize,
298}
299
300impl Default for ContrastiveLearningConfig {
301 fn default() -> Self {
302 Self {
303 enabled: false,
304 temperature: 0.07,
305 negative_samples: 64,
306 hard_negative_mining: false,
307 momentum: 0.999,
308 queue_size: 4096,
309 projection_dim: 128,
310 }
311 }
312}
313
314impl ContrastiveLearningConfig {
315 #[must_use]
317 pub fn strong() -> Self {
318 Self {
319 enabled: true,
320 temperature: 0.05,
321 negative_samples: 128,
322 hard_negative_mining: true,
323 momentum: 0.9999,
324 queue_size: 8192,
325 projection_dim: 256,
326 }
327 }
328}
329
330#[derive(Debug, Clone, Serialize, Deserialize)]
332pub struct DistillationConfig {
333 pub enabled: bool,
335 pub teacher_weight: f64,
337 pub student_weight: f64,
339 pub temperature: f64,
341 pub feature_matching_weight: f64,
343 pub attention_transfer_weight: f64,
345}
346
347impl Default for DistillationConfig {
348 fn default() -> Self {
349 Self {
350 enabled: false,
351 teacher_weight: 0.7,
352 student_weight: 0.3,
353 temperature: 4.0,
354 feature_matching_weight: 0.1,
355 attention_transfer_weight: 0.1,
356 }
357 }
358}
359
360#[derive(Debug, Clone, Serialize, Deserialize)]
362pub struct AlignmentLearningConfig {
363 pub enabled: bool,
365 pub alignment_weight: f64,
367 pub use_cca: bool,
369 pub cca_regularization: f64,
371 pub max_canonical_components: usize,
373 pub adversarial_alignment: bool,
375 pub adversarial_weight: f64,
377}
378
379impl Default for AlignmentLearningConfig {
380 fn default() -> Self {
381 Self {
382 enabled: false,
383 alignment_weight: 1.0,
384 use_cca: false,
385 cca_regularization: 1e-5,
386 max_canonical_components: 100,
387 adversarial_alignment: false,
388 adversarial_weight: 0.1,
389 }
390 }
391}
392
393impl AlignmentLearningConfig {
394 #[must_use]
396 pub fn canonical_correlation() -> Self {
397 Self {
398 enabled: true,
399 alignment_weight: 1.0,
400 use_cca: true,
401 cca_regularization: 1e-4,
402 max_canonical_components: 50,
403 adversarial_alignment: false,
404 adversarial_weight: 0.0,
405 }
406 }
407}
408
409#[derive(Debug, Clone, Serialize, Deserialize)]
411pub struct SynchronizationRequirements {
412 pub sync_accuracy: Duration,
414 pub sync_method: SyncMethod,
416 pub drift_correction: bool,
418 pub max_drift: Duration,
420 pub sync_check_interval: Duration,
422 pub fallback_method: Option<SyncMethod>,
424}
425
426impl Default for SynchronizationRequirements {
427 fn default() -> Self {
428 Self {
429 sync_accuracy: Duration::from_millis(100),
430 sync_method: SyncMethod::Software,
431 drift_correction: true,
432 max_drift: Duration::from_millis(500),
433 sync_check_interval: Duration::from_secs(10),
434 fallback_method: Some(SyncMethod::Software),
435 }
436 }
437}
438
439impl SynchronizationRequirements {
440 #[must_use]
442 pub fn hardware() -> Self {
443 Self {
444 sync_accuracy: Duration::from_micros(100),
445 sync_method: SyncMethod::Hardware,
446 drift_correction: true,
447 max_drift: Duration::from_millis(10),
448 sync_check_interval: Duration::from_secs(1),
449 fallback_method: Some(SyncMethod::Software),
450 }
451 }
452
453 #[must_use]
455 pub fn software() -> Self {
456 Self {
457 sync_accuracy: Duration::from_millis(50),
458 sync_method: SyncMethod::Software,
459 drift_correction: true,
460 max_drift: Duration::from_millis(200),
461 sync_check_interval: Duration::from_secs(5),
462 fallback_method: None,
463 }
464 }
465
466 #[must_use]
468 pub fn gps() -> Self {
469 Self {
470 sync_accuracy: Duration::from_micros(10),
471 sync_method: SyncMethod::GPS,
472 drift_correction: true,
473 max_drift: Duration::from_micros(100),
474 sync_check_interval: Duration::from_millis(100),
475 fallback_method: Some(SyncMethod::NTP),
476 }
477 }
478}
479
480#[derive(Debug, Clone)]
482pub struct MultiModalSample {
483 pub timestamp: SystemTime,
485 pub modality_data: HashMap<Modality, ModalityData>,
487 pub metadata: HashMap<String, String>,
489 pub sync_status: SyncStatus,
491}
492
493impl MultiModalSample {
494 #[must_use]
496 pub fn new(timestamp: SystemTime) -> Self {
497 Self {
498 timestamp,
499 modality_data: HashMap::new(),
500 metadata: HashMap::new(),
501 sync_status: SyncStatus::Unknown,
502 }
503 }
504
505 pub fn add_modality_data(&mut self, modality: Modality, data: ModalityData) {
507 self.modality_data.insert(modality, data);
508 }
509
510 #[must_use]
512 pub fn get_modality_data(&self, modality: &Modality) -> Option<&ModalityData> {
513 self.modality_data.get(modality)
514 }
515
516 #[must_use]
518 pub fn has_modalities(&self, required_modalities: &[Modality]) -> bool {
519 required_modalities
520 .iter()
521 .all(|m| self.modality_data.contains_key(m))
522 }
523
524 #[must_use]
526 pub fn age(&self) -> Duration {
527 SystemTime::now()
528 .duration_since(self.timestamp)
529 .unwrap_or(Duration::from_secs(0))
530 }
531}
532
533#[derive(Debug, Clone)]
535pub struct ModalityData {
536 pub data: Vec<u8>,
538 pub format: String,
540 pub metadata: HashMap<String, TransformParameter>,
542 pub quality_metrics: HashMap<String, f64>,
544}
545
546impl ModalityData {
547 #[must_use]
549 pub fn new(data: Vec<u8>, format: String) -> Self {
550 Self {
551 data,
552 format,
553 metadata: HashMap::new(),
554 quality_metrics: HashMap::new(),
555 }
556 }
557
558 #[must_use]
560 pub fn size(&self) -> usize {
561 self.data.len()
562 }
563
564 pub fn add_quality_metric(&mut self, name: String, value: f64) {
566 self.quality_metrics.insert(name, value);
567 }
568}
569
/// Synchronization state of a multi-modal sample.
///
/// `MultiModalSample::new` initialises samples as [`SyncStatus::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncStatus {
    /// Not yet determined.
    Unknown,
    /// Synchronized.
    Synchronized,
    /// Drifting.
    Drift,
    /// Out of synchronization.
    OutOfSync,
    /// Synchronization failed.
    Failed,
}
584
585#[derive(Debug, Clone, PartialEq)]
587pub enum MultiModalError {
588 NoModalities,
590 InvalidWeights { total: f64 },
592 MissingModality(Modality),
594 SyncFailure(String),
596 AlignmentFailure(String),
598 CrossModalError(String),
600 ConfigurationError(String),
602}
603
604impl std::fmt::Display for MultiModalError {
605 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
606 match self {
607 Self::NoModalities => write!(f, "No modalities configured for multi-modal processing"),
608 Self::InvalidWeights { total } => {
609 write!(
610 f,
611 "Invalid modality weights: total weight is {total}, should be ~1.0"
612 )
613 }
614 Self::MissingModality(modality) => {
615 write!(f, "Missing required modality: {modality:?}")
616 }
617 Self::SyncFailure(msg) => write!(f, "Synchronization failure: {msg}"),
618 Self::AlignmentFailure(msg) => write!(f, "Temporal alignment failure: {msg}"),
619 Self::CrossModalError(msg) => write!(f, "Cross-modal learning error: {msg}"),
620 Self::ConfigurationError(msg) => write!(f, "Configuration error: {msg}"),
621 }
622 }
623}
624
625impl std::error::Error for MultiModalError {}
626
// The former `#[allow(non_snake_case)]` was dropped: every item in this module is
// snake_case, so the allow had no effect and no justification.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multimodal_config_presets() {
        let vision_audio = MultiModalConfig::vision_audio();
        assert!(vision_audio.enabled);
        assert_eq!(vision_audio.modalities.len(), 2);
        assert!(vision_audio.modalities.contains(&Modality::Visual));
        assert!(vision_audio.modalities.contains(&Modality::Audio));

        let vision_depth = MultiModalConfig::vision_depth();
        assert_eq!(vision_depth.fusion_strategy, FusionStrategy::HybridFusion);

        let multi_sensor = MultiModalConfig::multi_sensor();
        assert_eq!(multi_sensor.modalities.len(), 3);
        assert_eq!(
            multi_sensor.fusion_strategy,
            FusionStrategy::AttentionFusion
        );
    }

    #[test]
    fn test_temporal_alignment_config() {
        let strict = TemporalAlignmentConfig::strict();
        assert_eq!(strict.sync_method, SyncMethod::Hardware);
        assert!(strict.predictive_alignment);

        let relaxed = TemporalAlignmentConfig::relaxed();
        assert_eq!(relaxed.sync_method, SyncMethod::Software);
        assert!(!relaxed.predictive_alignment);

        let precise = TemporalAlignmentConfig::precise();
        assert_eq!(precise.sync_method, SyncMethod::GPS);
        assert!(precise.predictive_alignment);
    }

    #[test]
    fn test_cross_modal_learning_config() {
        let contrastive = CrossModalLearningConfig::contrastive();
        assert_eq!(contrastive.strategy, CrossModalStrategy::Contrastive);
        assert!(contrastive.contrastive_learning.enabled);

        let shared_rep = CrossModalLearningConfig::shared_representation();
        assert_eq!(
            shared_rep.strategy,
            CrossModalStrategy::SharedRepresentation
        );
        assert_eq!(shared_rep.shared_representation_dim, 1024);

        let alignment = CrossModalLearningConfig::alignment();
        assert_eq!(alignment.strategy, CrossModalStrategy::Alignment);
        assert!(alignment.alignment_learning.use_cca);
    }

    #[test]
    fn test_synchronization_requirements() {
        let hardware = SynchronizationRequirements::hardware();
        assert_eq!(hardware.sync_method, SyncMethod::Hardware);
        assert_eq!(hardware.sync_accuracy, Duration::from_micros(100));

        let software = SynchronizationRequirements::software();
        assert_eq!(software.sync_method, SyncMethod::Software);
        assert_eq!(software.sync_accuracy, Duration::from_millis(50));

        let gps = SynchronizationRequirements::gps();
        assert_eq!(gps.sync_method, SyncMethod::GPS);
        assert_eq!(gps.sync_accuracy, Duration::from_micros(10));
    }

    #[test]
    fn test_multimodal_sample() {
        let mut sample = MultiModalSample::new(SystemTime::now());

        let visual_data = ModalityData::new(vec![1, 2, 3, 4], "jpeg".to_string());
        sample.add_modality_data(Modality::Visual, visual_data);

        let audio_data = ModalityData::new(vec![5, 6, 7, 8], "wav".to_string());
        sample.add_modality_data(Modality::Audio, audio_data);

        assert!(sample.has_modalities(&[Modality::Visual, Modality::Audio]));
        assert!(!sample.has_modalities(&[Modality::Visual, Modality::Audio, Modality::Depth]));

        assert!(sample.get_modality_data(&Modality::Visual).is_some());
        assert!(sample.get_modality_data(&Modality::Depth).is_none());
    }

    #[test]
    fn test_has_modalities_with_empty_requirement() {
        let sample = MultiModalSample::new(SystemTime::now());
        assert!(sample.has_modalities(&[]));
    }

    #[test]
    fn test_sample_age() {
        // A future timestamp saturates to zero instead of erroring.
        let future = SystemTime::now() + Duration::from_secs(3600);
        assert_eq!(MultiModalSample::new(future).age(), Duration::ZERO);

        let past = MultiModalSample::new(SystemTime::UNIX_EPOCH);
        assert!(past.age() > Duration::ZERO);
    }

    #[test]
    fn test_modality_data_size_and_metrics() {
        let mut data = ModalityData::new(vec![0; 16], "raw".to_string());
        assert_eq!(data.size(), 16);
        data.add_quality_metric("snr".to_string(), 12.5);
        assert_eq!(data.quality_metrics.get("snr"), Some(&12.5));
    }

    #[test]
    fn test_modality_weights() {
        let mut config = MultiModalConfig::default();
        config.set_modality_weight(Modality::Visual, 0.6);
        config.set_modality_weight(Modality::Audio, 0.4);

        assert_eq!(config.get_modality_weight(&Modality::Visual), 0.6);
        assert_eq!(config.get_modality_weight(&Modality::Audio), 0.4);
        assert_eq!(config.get_modality_weight(&Modality::Depth), 1.0);
    }

    #[test]
    fn test_validate_accepts_presets() {
        for config in [
            MultiModalConfig::vision_audio(),
            MultiModalConfig::vision_depth(),
            MultiModalConfig::multi_sensor(),
        ] {
            assert_eq!(config.validate(), Ok(()));
        }
    }

    #[test]
    fn test_validate_rejects_enabled_config_without_modalities() {
        let config = MultiModalConfig {
            enabled: true,
            modalities: Vec::new(),
            ..MultiModalConfig::default()
        };
        assert_eq!(config.validate(), Err(MultiModalError::NoModalities));
    }

    #[test]
    fn test_validate_rejects_weights_far_from_one() {
        let mut config = MultiModalConfig::default();
        config.set_modality_weight(Modality::Visual, 1.5);
        assert_eq!(
            config.validate(),
            Err(MultiModalError::InvalidWeights { total: 1.5 })
        );
    }

    #[test]
    fn test_multimodal_error_display() {
        let error = MultiModalError::InvalidWeights { total: 1.5 };
        let error_str = error.to_string();
        assert!(error_str.contains("Invalid modality weights"));
        assert!(error_str.contains("1.5"));

        let error = MultiModalError::MissingModality(Modality::Audio);
        let error_str = error.to_string();
        assert!(error_str.contains("Missing required modality"));
        assert!(error_str.contains("Audio"));
    }
}
739}