1#![allow(dead_code)]
17use crate::parameter::Parameter;
18use crate::{GraphData, GraphLayer};
19use std::collections::{HashMap, HashSet};
20use torsh_tensor::{
21 creation::{from_vec, ones, randn, zeros},
22 Tensor,
23};
24
/// The kind of data a node can carry in a multi-modal graph.
///
/// Derives `Copy + Eq + Hash` so variants can be used directly as
/// `HashMap`/`HashSet` keys throughout this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    Text,
    Image,
    Audio,
    Tabular,
    /// Structural/graph-native node features.
    Graph,
    Video,
    TimeSeries,
}
36
/// Per-node container holding one feature tensor per available modality.
#[derive(Debug, Clone)]
pub struct MultiModalNodeData {
    /// Feature tensor for each modality present on this node.
    pub modalities: HashMap<Modality, Tensor>,
    /// Index of this node inside the owning graph.
    pub node_id: usize,
    /// Optional ground-truth labels for this node.
    pub labels: Option<Tensor>,
}
44
45impl MultiModalNodeData {
46 pub fn new(node_id: usize) -> Self {
48 Self {
49 modalities: HashMap::new(),
50 node_id,
51 labels: None,
52 }
53 }
54
55 pub fn add_modality(mut self, modality: Modality, data: Tensor) -> Self {
57 self.modalities.insert(modality, data);
58 self
59 }
60
61 pub fn with_labels(mut self, labels: Tensor) -> Self {
63 self.labels = Some(labels);
64 self
65 }
66
67 pub fn available_modalities(&self) -> Vec<Modality> {
69 self.modalities.keys().copied().collect()
70 }
71
72 pub fn has_modality(&self, modality: Modality) -> bool {
74 self.modalities.contains_key(&modality)
75 }
76}
77
/// A graph plus per-node multi-modal payloads and bookkeeping about which
/// modalities are present and how wide their features are.
#[derive(Debug, Clone)]
pub struct MultiModalGraphData {
    /// Underlying graph (topology and base node features).
    pub graph: GraphData,
    /// Multi-modal payloads keyed by node id.
    pub node_data: HashMap<usize, MultiModalNodeData>,
    /// Every modality observed on at least one node.
    pub available_modalities: HashSet<Modality>,
    /// Flattened feature dimension recorded per modality
    /// (the most recently added node wins if nodes disagree).
    pub modality_dims: HashMap<Modality, usize>,
}
90
91impl MultiModalGraphData {
92 pub fn new(graph: GraphData) -> Self {
94 Self {
95 graph,
96 node_data: HashMap::new(),
97 available_modalities: HashSet::new(),
98 modality_dims: HashMap::new(),
99 }
100 }
101
102 pub fn add_node_data(&mut self, node_data: MultiModalNodeData) {
104 let node_id = node_data.node_id;
105
106 for modality in node_data.available_modalities() {
108 self.available_modalities.insert(modality);
109
110 if let Some(data) = node_data.modalities.get(&modality) {
112 let dim = data.shape().dims().iter().product::<usize>();
113 self.modality_dims.insert(modality, dim);
114 }
115 }
116
117 self.node_data.insert(node_id, node_data);
118 }
119
120 pub fn get_modality_data(&self, modality: Modality) -> Vec<(usize, &Tensor)> {
122 self.node_data
123 .iter()
124 .filter_map(|(&node_id, data)| {
125 data.modalities
126 .get(&modality)
127 .map(|tensor| (node_id, tensor))
128 })
129 .collect()
130 }
131
132 pub fn get_complete_nodes(&self, modalities: &[Modality]) -> Vec<usize> {
134 self.node_data
135 .iter()
136 .filter(|(_, data)| {
137 modalities
138 .iter()
139 .all(|&modality| data.has_modality(modality))
140 })
141 .map(|(&node_id, _)| node_id)
142 .collect()
143 }
144
145 pub fn modality_statistics(&self) -> HashMap<Modality, f32> {
147 let total_nodes = self.graph.num_nodes;
148 let mut stats = HashMap::new();
149
150 for &modality in &self.available_modalities {
151 let count = self
152 .node_data
153 .values()
154 .filter(|data| data.has_modality(modality))
155 .count();
156
157 let coverage = count as f32 / total_nodes as f32;
158 stats.insert(modality, coverage);
159 }
160
161 stats
162 }
163}
164
/// Single-head cross-modal attention over the nodes of a multi-modal graph.
///
/// Each modality is first projected into a shared `feature_dim` space; one
/// modality then acts as the query source and attends over the others.
#[derive(Debug)]
pub struct CrossModalGraphAttention {
    /// Modalities this layer knows how to project.
    modalities: Vec<Modality>,
    /// Shared feature dimension after per-modality projection.
    feature_dim: usize,
    /// Dimension of the query/key/value space.
    attention_dim: usize,
    // NOTE(review): stored but not used by the current single-head forward
    // pass — confirm whether multi-head support is planned.
    num_heads: usize,

    /// Per-modality input projection: raw input dim -> feature_dim.
    modality_projections: HashMap<Modality, Parameter>,

    query_weights: Parameter,
    key_weights: Parameter,
    value_weights: Parameter,

    /// Maps attended values back from attention_dim to feature_dim.
    output_projection: Parameter,

    /// Affine parameters of the final layer normalization.
    layer_norm_weight: Parameter,
    layer_norm_bias: Parameter,

    // NOTE(review): stored but never applied anywhere in this impl.
    dropout: f32,
}
190
191impl CrossModalGraphAttention {
192 pub fn new(
194 modalities: Vec<Modality>,
195 modality_dims: HashMap<Modality, usize>,
196 feature_dim: usize,
197 attention_dim: usize,
198 num_heads: usize,
199 dropout: f32,
200 ) -> Self {
201 let mut modality_projections = HashMap::new();
202
203 for modality in &modalities {
205 let input_dim = modality_dims.get(modality).copied().unwrap_or(feature_dim);
206 modality_projections.insert(
207 *modality,
208 Parameter::new(
209 randn(&[input_dim, feature_dim])
210 .expect("failed to create modality projection tensor"),
211 ),
212 );
213 }
214
215 let query_weights = Parameter::new(
216 randn(&[feature_dim, attention_dim]).expect("failed to create query_weights tensor"),
217 );
218 let key_weights = Parameter::new(
219 randn(&[feature_dim, attention_dim]).expect("failed to create key_weights tensor"),
220 );
221 let value_weights = Parameter::new(
222 randn(&[feature_dim, attention_dim]).expect("failed to create value_weights tensor"),
223 );
224 let output_projection = Parameter::new(
225 randn(&[attention_dim, feature_dim])
226 .expect("failed to create output_projection tensor"),
227 );
228
229 let layer_norm_weight = Parameter::new(
230 ones(&[feature_dim]).expect("failed to create layer_norm_weight tensor"),
231 );
232 let layer_norm_bias = Parameter::new(
233 zeros::<f32>(&[feature_dim]).expect("failed to create layer_norm_bias tensor"),
234 );
235
236 Self {
237 modalities,
238 feature_dim,
239 attention_dim,
240 num_heads,
241 modality_projections,
242 query_weights,
243 key_weights,
244 value_weights,
245 output_projection,
246 layer_norm_weight,
247 layer_norm_bias,
248 dropout,
249 }
250 }
251
252 pub fn forward(&self, mm_graph: &MultiModalGraphData) -> Tensor {
254 let num_nodes = mm_graph.graph.num_nodes;
255
256 let mut modality_features = HashMap::new();
258
259 for &modality in &self.modalities {
260 let projection = &self.modality_projections[&modality];
261 let modality_data = mm_graph.get_modality_data(modality);
262
263 if !modality_data.is_empty() {
264 let features =
265 self.project_modality_features(&modality_data, projection, num_nodes);
266 modality_features.insert(modality, features);
267 }
268 }
269
270 let attended_features = self.apply_cross_modal_attention(&modality_features);
272
273 self.layer_norm(&attended_features)
275 }
276
277 fn project_modality_features(
279 &self,
280 modality_data: &[(usize, &Tensor)],
281 projection: &Parameter,
282 num_nodes: usize,
283 ) -> Tensor {
284 let mut projected_data = vec![0.0f32; num_nodes * self.feature_dim];
285
286 for &(node_id, features) in modality_data {
287 if node_id < num_nodes {
288 let feature_data = features.to_vec().expect("conversion should succeed");
289 let input_tensor = from_vec(
290 feature_data,
291 &[1, features.shape().dims().iter().product::<usize>()],
292 torsh_core::device::DeviceType::Cpu,
293 )
294 .expect("input tensor creation should succeed");
295
296 let projected = input_tensor
297 .matmul(&projection.clone_data())
298 .expect("operation should succeed");
299 let projected_data_vec = projected.to_vec().expect("conversion should succeed");
300
301 for (i, &val) in projected_data_vec.iter().enumerate() {
302 if i < self.feature_dim {
303 projected_data[node_id * self.feature_dim + i] = val;
304 }
305 }
306 }
307 }
308
309 from_vec(
310 projected_data,
311 &[num_nodes, self.feature_dim],
312 torsh_core::device::DeviceType::Cpu,
313 )
314 .expect("projected features tensor creation should succeed")
315 }
316
317 fn apply_cross_modal_attention(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
319 if modality_features.is_empty() {
320 return zeros::<f32>(&[1, self.feature_dim])
321 .expect("empty attention features tensor creation should succeed");
322 }
323
324 let first_modality = modality_features
326 .keys()
327 .next()
328 .expect("modality_features should not be empty");
329 let base_features = &modality_features[first_modality];
330 let _num_nodes = base_features.shape().dims()[0];
331
332 let queries = base_features
334 .matmul(&self.query_weights.clone_data())
335 .expect("operation should succeed");
336 let _keys = base_features
337 .matmul(&self.key_weights.clone_data())
338 .expect("operation should succeed");
339 let values = base_features
340 .matmul(&self.value_weights.clone_data())
341 .expect("operation should succeed");
342
343 let mut attended_values = values.clone();
345
346 for (modality, features) in modality_features {
347 if *modality != *first_modality {
348 let modal_keys = features
349 .matmul(&self.key_weights.clone_data())
350 .expect("operation should succeed");
351 let modal_values = features
352 .matmul(&self.value_weights.clone_data())
353 .expect("operation should succeed");
354
355 let attention_scores = queries
357 .matmul(&modal_keys.t().expect("operation should succeed"))
358 .expect("operation should succeed");
359 let attention_weights = self.softmax(&attention_scores);
360 let attended = attention_weights
361 .matmul(&modal_values)
362 .expect("operation should succeed");
363
364 attended_values = attended_values
365 .add(&attended)
366 .expect("operation should succeed");
367 }
368 }
369
370 attended_values
372 .matmul(&self.output_projection.clone_data())
373 .expect("operation should succeed")
374 }
375
376 fn softmax(&self, x: &Tensor) -> Tensor {
378 let data = x.to_vec().expect("conversion should succeed");
379 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
380
381 let exp_data: Vec<f32> = data.iter().map(|&val| (val - max_val).exp()).collect();
382 let sum_exp: f32 = exp_data.iter().sum();
383
384 let softmax_data: Vec<f32> = exp_data.iter().map(|&val| val / sum_exp).collect();
385
386 from_vec(
387 softmax_data,
388 x.shape().dims(),
389 torsh_core::device::DeviceType::Cpu,
390 )
391 .expect("softmax tensor creation should succeed")
392 }
393
394 fn layer_norm(&self, x: &Tensor) -> Tensor {
396 let data = x.to_vec().expect("conversion should succeed");
397 let num_features = self.feature_dim;
398 let num_samples = data.len() / num_features;
399
400 let mut normalized_data = Vec::new();
401
402 for sample in 0..num_samples {
403 let start_idx = sample * num_features;
404 let end_idx = start_idx + num_features;
405 let sample_data = &data[start_idx..end_idx];
406
407 let mean: f32 = sample_data.iter().sum::<f32>() / num_features as f32;
409 let variance: f32 =
410 sample_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_features as f32;
411 let std = (variance + 1e-5).sqrt();
412
413 for &val in sample_data {
415 let normalized = (val - mean) / std;
416 normalized_data.push(normalized);
417 }
418 }
419
420 let normalized_tensor = from_vec(
421 normalized_data,
422 x.shape().dims(),
423 torsh_core::device::DeviceType::Cpu,
424 )
425 .expect("normalized tensor creation should succeed");
426
427 normalized_tensor
429 .mul(&self.layer_norm_weight.clone_data())
430 .expect("operation should succeed")
431 .add(&self.layer_norm_bias.clone_data())
432 .expect("operation should succeed")
433 }
434}
435
436impl GraphLayer for CrossModalGraphAttention {
437 fn forward(&self, graph: &GraphData) -> GraphData {
438 let mut mm_graph = MultiModalGraphData::new(graph.clone());
440
441 for node_id in 0..graph.num_nodes {
442 let node_features = graph
443 .x
444 .slice_tensor(0, node_id, node_id + 1)
445 .expect("node feature slice should succeed");
446 let node_data =
447 MultiModalNodeData::new(node_id).add_modality(Modality::Graph, node_features);
448 mm_graph.add_node_data(node_data);
449 }
450
451 let output_features = self.forward(&mm_graph);
452
453 let mut output_graph = graph.clone();
454 output_graph.x = output_features;
455 output_graph
456 }
457
458 fn parameters(&self) -> Vec<Tensor> {
459 let mut params = vec![
460 self.query_weights.clone_data(),
461 self.key_weights.clone_data(),
462 self.value_weights.clone_data(),
463 self.output_projection.clone_data(),
464 self.layer_norm_weight.clone_data(),
465 self.layer_norm_bias.clone_data(),
466 ];
467
468 for projection in self.modality_projections.values() {
469 params.push(projection.clone_data());
470 }
471
472 params
473 }
474}
475
/// Combines per-modality node features into a single tensor using a
/// configurable strategy.
#[derive(Debug)]
pub struct MultiModalFusion {
    /// Which fusion algorithm `fuse_features` dispatches to.
    fusion_strategy: FusionStrategy,
    /// Modalities to fuse, in a fixed order.
    modalities: Vec<Modality>,
    /// Expected per-modality feature width.
    feature_dim: usize,
    /// Per-modality scalar weights; `Some` only for `WeightedSum`.
    fusion_weights: Option<Parameter>,
    /// One gate vector per modality; `Some` only for `GatedFusion`.
    gating_network: Option<Vec<Parameter>>,
}
485
/// How multiple modality feature tensors are merged into one.
#[derive(Debug, Clone, Copy)]
pub enum FusionStrategy {
    /// Concatenate along the feature axis.
    Concatenation,
    /// Plain element-wise addition.
    ElementwiseSum,
    /// Element-wise sum scaled by learnable per-modality weights.
    WeightedSum,
    /// Norm-based soft weighting of each modality.
    AttentionFusion,
    /// Sigmoid-gated combination with learnable gates.
    GatedFusion,
}
494
495impl MultiModalFusion {
496 pub fn new(
498 fusion_strategy: FusionStrategy,
499 modalities: Vec<Modality>,
500 feature_dim: usize,
501 ) -> Self {
502 let fusion_weights = match fusion_strategy {
503 FusionStrategy::WeightedSum => Some(Parameter::new(
504 ones(&[modalities.len()]).expect("failed to create fusion_weights tensor"),
505 )),
506 _ => None,
507 };
508
509 let gating_network = match fusion_strategy {
510 FusionStrategy::GatedFusion => {
511 let mut gates = Vec::new();
512 for _ in 0..modalities.len() {
513 gates.push(Parameter::new(
514 randn(&[feature_dim, 1]).expect("failed to create gate tensor"),
515 ));
516 }
517 Some(gates)
518 }
519 _ => None,
520 };
521
522 Self {
523 fusion_strategy,
524 modalities,
525 feature_dim,
526 fusion_weights,
527 gating_network,
528 }
529 }
530
531 pub fn fuse_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
533 match self.fusion_strategy {
534 FusionStrategy::Concatenation => self.concatenate_features(modality_features),
535 FusionStrategy::ElementwiseSum => self.elementwise_sum_features(modality_features),
536 FusionStrategy::WeightedSum => self.weighted_sum_features(modality_features),
537 FusionStrategy::AttentionFusion => self.attention_fusion_features(modality_features),
538 FusionStrategy::GatedFusion => self.gated_fusion_features(modality_features),
539 }
540 }
541
542 fn concatenate_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
544 let mut concatenated_data = Vec::new();
545
546 for &modality in &self.modalities {
547 if let Some(features) = modality_features.get(&modality) {
548 concatenated_data.extend(features.to_vec().expect("conversion should succeed"));
549 } else {
550 concatenated_data.extend(vec![0.0f32; self.feature_dim]);
552 }
553 }
554
555 let num_nodes = modality_features
556 .values()
557 .next()
558 .map(|t| t.shape().dims()[0])
559 .unwrap_or(1);
560
561 from_vec(
562 concatenated_data,
563 &[num_nodes, self.modalities.len() * self.feature_dim],
564 torsh_core::device::DeviceType::Cpu,
565 )
566 .expect("concatenated features tensor creation should succeed")
567 }
568
569 fn elementwise_sum_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
571 let mut sum_features: Option<Tensor> = None;
572
573 for &modality in &self.modalities {
574 if let Some(features) = modality_features.get(&modality) {
575 if let Some(ref sum) = sum_features {
576 sum_features = Some(sum.add(features).expect("operation should succeed"));
577 } else {
578 sum_features = Some(features.clone());
579 }
580 }
581 }
582
583 sum_features.unwrap_or_else(|| {
584 zeros::<f32>(&[1, self.feature_dim])
585 .expect("fallback sum features tensor creation should succeed")
586 })
587 }
588
589 fn weighted_sum_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
591 let weights = self
592 .fusion_weights
593 .as_ref()
594 .expect("fusion weights should be present for weighted sum")
595 .clone_data()
596 .to_vec()
597 .expect("fusion weights conversion should succeed");
598 let mut weighted_sum: Option<Tensor> = None;
599
600 for (i, &modality) in self.modalities.iter().enumerate() {
601 if let Some(features) = modality_features.get(&modality) {
602 let weight = weights.get(i).copied().unwrap_or(1.0);
603 let weighted_features = features
604 .mul_scalar(weight)
605 .expect("operation should succeed");
606
607 if let Some(ref sum) = weighted_sum {
608 weighted_sum = Some(
609 sum.add(&weighted_features)
610 .expect("operation should succeed"),
611 );
612 } else {
613 weighted_sum = Some(weighted_features);
614 }
615 }
616 }
617
618 weighted_sum.unwrap_or_else(|| {
619 zeros::<f32>(&[1, self.feature_dim])
620 .expect("fallback weighted sum tensor creation should succeed")
621 })
622 }
623
624 fn attention_fusion_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
626 let available_features: Vec<&Tensor> = modality_features.values().collect();
628
629 if available_features.is_empty() {
630 return zeros::<f32>(&[1, self.feature_dim])
631 .expect("empty attention fusion tensor creation should succeed");
632 }
633
634 let mut attention_weights = Vec::new();
636 let mut total_norm = 0.0;
637
638 for features in &available_features {
639 let data = features.to_vec().expect("conversion should succeed");
640 let norm: f32 = data.iter().map(|&x| x * x).sum::<f32>().sqrt();
641 attention_weights.push(norm);
642 total_norm += norm;
643 }
644
645 if total_norm > 0.0 {
647 for weight in &mut attention_weights {
648 *weight /= total_norm;
649 }
650 }
651
652 let mut attended_features: Option<Tensor> = None;
654 for (features, &weight) in available_features.iter().zip(attention_weights.iter()) {
655 let weighted = features
656 .mul_scalar(weight)
657 .expect("operation should succeed");
658
659 if let Some(ref sum) = attended_features {
660 attended_features = Some(sum.add(&weighted).expect("operation should succeed"));
661 } else {
662 attended_features = Some(weighted);
663 }
664 }
665
666 attended_features.unwrap_or_else(|| {
667 zeros::<f32>(&[1, self.feature_dim])
668 .expect("fallback attended features tensor creation should succeed")
669 })
670 }
671
672 fn gated_fusion_features(&self, modality_features: &HashMap<Modality, Tensor>) -> Tensor {
674 let gates = self
675 .gating_network
676 .as_ref()
677 .expect("gating network should be present for gated fusion");
678 let mut gated_sum: Option<Tensor> = None;
679
680 for (i, &modality) in self.modalities.iter().enumerate() {
681 if let Some(features) = modality_features.get(&modality) {
682 let gate = &gates[i];
683 let gate_values = features
684 .matmul(&gate.clone_data())
685 .expect("operation should succeed");
686 let gate_probs = self.sigmoid(&gate_values);
687
688 let gated_features = features.mul(&gate_probs).expect("operation should succeed");
690
691 if let Some(ref sum) = gated_sum {
692 gated_sum = Some(sum.add(&gated_features).expect("operation should succeed"));
693 } else {
694 gated_sum = Some(gated_features);
695 }
696 }
697 }
698
699 gated_sum.unwrap_or_else(|| {
700 zeros::<f32>(&[1, self.feature_dim])
701 .expect("fallback gated sum tensor creation should succeed")
702 })
703 }
704
705 fn sigmoid(&self, x: &Tensor) -> Tensor {
707 let data = x.to_vec().expect("conversion should succeed");
708 let sigmoid_data: Vec<f32> = data.iter().map(|&val| 1.0 / (1.0 + (-val).exp())).collect();
709
710 from_vec(
711 sigmoid_data,
712 x.shape().dims(),
713 torsh_core::device::DeviceType::Cpu,
714 )
715 .expect("sigmoid tensor creation should succeed")
716 }
717}
718
/// Contrastive alignment of modality pairs via learnable projectors into a
/// shared space.
#[derive(Debug)]
pub struct MultiModalContrastiveLearning {
    /// Softmax temperature for the similarity scores.
    temperature: f32,
    /// Dimension of the shared contrastive space.
    projection_dim: usize,
    /// One learnable projector per modality: raw dim -> projection_dim.
    modality_projectors: HashMap<Modality, Parameter>,
}
726
727impl MultiModalContrastiveLearning {
728 pub fn new(
730 modalities: Vec<Modality>,
731 modality_dims: HashMap<Modality, usize>,
732 projection_dim: usize,
733 temperature: f32,
734 ) -> Self {
735 let mut modality_projectors = HashMap::new();
736
737 for modality in modalities {
738 let input_dim = modality_dims.get(&modality).copied().unwrap_or(128);
739 modality_projectors.insert(
740 modality,
741 Parameter::new(
742 randn(&[input_dim, projection_dim])
743 .expect("failed to create modality projector tensor"),
744 ),
745 );
746 }
747
748 Self {
749 temperature,
750 projection_dim,
751 modality_projectors,
752 }
753 }
754
755 pub fn contrastive_loss(
757 &self,
758 modality1: Modality,
759 features1: &Tensor,
760 modality2: Modality,
761 features2: &Tensor,
762 ) -> f32 {
763 let proj1 = features1
765 .matmul(&self.modality_projectors[&modality1].clone_data())
766 .expect("operation should succeed");
767 let proj2 = features2
768 .matmul(&self.modality_projectors[&modality2].clone_data())
769 .expect("operation should succeed");
770
771 let similarity = proj1
773 .matmul(&proj2.t().expect("operation should succeed"))
774 .expect("operation should succeed");
775 let scaled_similarity = similarity
776 .div_scalar(self.temperature)
777 .expect("operation should succeed");
778
779 let sim_data = scaled_similarity
781 .to_vec()
782 .expect("conversion should succeed");
783 let max_sim = sim_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
784 let exp_sims: Vec<f32> = sim_data.iter().map(|&x| (x - max_sim).exp()).collect();
785 let sum_exp: f32 = exp_sims.iter().sum();
786
787 let num_samples = proj1.shape().dims()[0];
789 let mut loss = 0.0;
790
791 for i in 0..num_samples {
792 let positive_sim = exp_sims[i * num_samples + i];
793 loss -= (positive_sim / sum_exp).ln();
794 }
795
796 loss / num_samples as f32
797 }
798
799 pub fn generate_contrastive_pairs(
801 &self,
802 mm_graph: &MultiModalGraphData,
803 modality1: Modality,
804 modality2: Modality,
805 ) -> Vec<(Tensor, Tensor, bool)> {
806 let mut pairs = Vec::new();
807
808 let data1 = mm_graph.get_modality_data(modality1);
809 let data2 = mm_graph.get_modality_data(modality2);
810
811 for (node_id, features1) in &data1 {
813 if let Some((_, features2)) = data2.iter().find(|(id, _)| id == node_id) {
814 pairs.push(((*features1).clone(), (*features2).clone(), true));
815 }
816 }
817
818 for (node_id1, features1) in &data1 {
820 for (node_id2, features2) in &data2 {
821 if node_id1 != node_id2 {
822 pairs.push(((*features1).clone(), (*features2).clone(), false));
823
824 if pairs.len() > 1000 {
826 break;
827 }
828 }
829 }
830 if pairs.len() > 1000 {
831 break;
832 }
833 }
834
835 pairs
836 }
837}
838
/// Helpers for building synthetic multi-modal graphs and evaluating learned
/// representations.
pub mod utils {
    use super::*;

    /// Builds a random multi-modal graph for testing/benchmarking.
    ///
    /// Creates `num_nodes` nodes with random base features and
    /// `2 * num_nodes` random edges, then attaches each requested modality
    /// to each node with probability 0.8 using a typical embedding size per
    /// modality.
    pub fn create_synthetic_multimodal_graph(
        num_nodes: usize,
        base_feature_dim: usize,
        modalities: Vec<Modality>,
    ) -> MultiModalGraphData {
        let mut rng = scirs2_core::random::thread_rng();

        let base_features = randn(&[num_nodes, base_feature_dim])
            .expect("base features tensor creation should succeed");
        let mut edge_data = Vec::new();

        // Random endpoints; node indices are stored as f32 values.
        for _ in 0..(num_nodes * 2) {
            let src = rng.gen_range(0..num_nodes) as f32;
            let dst = rng.gen_range(0..num_nodes) as f32;
            edge_data.push(src);
            edge_data.push(dst);
        }

        // NOTE(review): edge_data interleaves (src, dst) pairs, but a
        // [2, num_edges] tensor read row-major expects all sources first and
        // then all destinations — confirm the layout GraphData expects.
        let edge_index = from_vec(
            edge_data,
            &[2, num_nodes * 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("edge index tensor creation should succeed");

        let graph = GraphData::new(base_features, edge_index);
        let mut mm_graph = MultiModalGraphData::new(graph);

        for node_id in 0..num_nodes {
            let mut node_data = MultiModalNodeData::new(node_id);

            for &modality in &modalities {
                // Typical embedding dimensionality for each modality.
                let feature_dim = match modality {
                    Modality::Text => 768,
                    Modality::Image => 2048,
                    Modality::Audio => 128,
                    Modality::Tabular => 64,
                    Modality::Graph => base_feature_dim,
                    Modality::Video => 1024,
                    Modality::TimeSeries => 256,
                };

                // Each node carries each modality with 80% probability.
                if rng.gen_range(0.0..1.0) < 0.8 {
                    let features = randn(&[feature_dim])
                        .expect("modality features tensor creation should succeed");
                    node_data = node_data.add_modality(modality, features);
                }
            }

            mm_graph.add_node_data(node_data);
        }

        mm_graph
    }

    /// Computes simple quality metrics for learned multi-modal
    /// representations: per-modality coverage, mean, variance, and pairwise
    /// cross-modal cosine consistency.
    ///
    /// NOTE(review): the mean/variance computation assumes non-empty
    /// tensors; an empty tensor yields NaN metrics.
    pub fn evaluate_multimodal_quality(
        mm_graph: &MultiModalGraphData,
        representations: &HashMap<Modality, Tensor>,
    ) -> HashMap<String, f32> {
        let mut metrics = HashMap::new();

        // Fraction of nodes carrying each modality.
        let modality_stats = mm_graph.modality_statistics();
        for (modality, coverage) in modality_stats {
            metrics.insert(format!("{:?}_coverage", modality), coverage);
        }

        // First and second moments of each representation tensor.
        for (modality, tensor) in representations {
            let data = tensor.to_vec().expect("conversion should succeed");
            let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
            let variance: f32 =
                data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;

            metrics.insert(format!("{:?}_mean", modality), mean);
            metrics.insert(format!("{:?}_variance", modality), variance);
        }

        // Pairwise cosine similarity between modality representations.
        if representations.len() > 1 {
            let modalities: Vec<_> = representations.keys().collect();
            for i in 0..modalities.len() {
                for j in (i + 1)..modalities.len() {
                    let rep1 = &representations[modalities[i]];
                    let rep2 = &representations[modalities[j]];

                    let consistency = compute_tensor_similarity(rep1, rep2);
                    metrics.insert(
                        format!("{:?}_{:?}_consistency", modalities[i], modalities[j]),
                        consistency,
                    );
                }
            }
        }

        metrics
    }

    /// Cosine similarity between two flattened tensors; 0.0 when lengths
    /// differ or either tensor has zero norm.
    fn compute_tensor_similarity(tensor1: &Tensor, tensor2: &Tensor) -> f32 {
        let data1 = tensor1.to_vec().expect("conversion should succeed");
        let data2 = tensor2.to_vec().expect("conversion should succeed");

        if data1.len() != data2.len() {
            return 0.0;
        }

        let dot_product: f32 = data1.iter().zip(data2.iter()).map(|(&a, &b)| a * b).sum();
        let norm1: f32 = data1.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = data2.iter().map(|&x| x * x).sum::<f32>().sqrt();

        if norm1 > 0.0 && norm2 > 0.0 {
            dot_product / (norm1 * norm2)
        } else {
            0.0
        }
    }

    /// Samples up to `num_tasks` `(node_id, source, target)` alignment
    /// triples from nodes carrying both modalities.
    ///
    /// NOTE(review): nodes are sampled with replacement, so the result may
    /// contain duplicate tasks — confirm whether distinct tasks are needed.
    pub fn generate_alignment_tasks(
        mm_graph: &MultiModalGraphData,
        source_modality: Modality,
        target_modality: Modality,
        num_tasks: usize,
    ) -> Vec<(usize, Tensor, Tensor)> {
        let source_data = mm_graph.get_modality_data(source_modality);
        let target_data = mm_graph.get_modality_data(target_modality);

        let mut tasks = Vec::new();
        let mut rng = scirs2_core::random::thread_rng();

        // Nodes that appear in both modalities.
        let common_nodes: Vec<usize> = source_data
            .iter()
            .filter_map(|&(node_id, _)| {
                if target_data.iter().any(|&(id, _)| id == node_id) {
                    Some(node_id)
                } else {
                    None
                }
            })
            .collect();

        for _ in 0..num_tasks.min(common_nodes.len()) {
            let &node_id = common_nodes
                .choose(&mut rng)
                .expect("collection should not be empty");

            let source_features = source_data
                .iter()
                .find(|&&(id, _)| id == node_id)
                .map(|(_, tensor)| (*tensor).clone())
                .expect("value should be present");

            let target_features = target_data
                .iter()
                .find(|&&(id, _)| id == node_id)
                .map(|(_, tensor)| (*tensor).clone())
                .expect("value should be present");

            tasks.push((node_id, source_features, target_features));
        }

        tasks
    }
}
1016
/// Minimal stand-in for rand's `SliceRandom::choose`, implemented against
/// the scirs2 RNG types used in this crate.
trait RandomChoice<T> {
    /// Returns a uniformly random element, or `None` if the collection is
    /// empty.
    fn choose(
        &self,
        rng: &mut scirs2_core::random::CoreRandom<scirs2_core::rngs::ThreadRng>,
    ) -> Option<&T>;
}
1024
1025impl<T> RandomChoice<T> for Vec<T> {
1026 fn choose(
1027 &self,
1028 rng: &mut scirs2_core::random::CoreRandom<scirs2_core::rngs::ThreadRng>,
1029 ) -> Option<&T> {
1030 if self.is_empty() {
1031 None
1032 } else {
1033 let index = rng.gen_range(0..self.len());
1034 self.get(index)
1035 }
1036 }
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    // The builder API attaches modalities and reports them back correctly.
    #[test]
    fn test_multimodal_node_data_creation() {
        let text_features = randn(&[768]).unwrap();
        let image_features = randn(&[2048]).unwrap();

        let node_data = MultiModalNodeData::new(0)
            .add_modality(Modality::Text, text_features)
            .add_modality(Modality::Image, image_features);

        assert_eq!(node_data.node_id, 0);
        assert!(node_data.has_modality(Modality::Text));
        assert!(node_data.has_modality(Modality::Image));
        assert!(!node_data.has_modality(Modality::Audio));
        assert_eq!(node_data.available_modalities().len(), 2);
    }

    // Bookkeeping (available modalities, per-modality data, completeness,
    // coverage statistics) over a small fully-populated graph.
    #[test]
    fn test_multimodal_graph_data() {
        let features = randn(&[3, 4]).unwrap();
        let edges = vec![0.0, 1.0, 1.0, 2.0];
        let edge_index = from_vec(edges, &[2, 2], DeviceType::Cpu).unwrap();
        let graph = GraphData::new(features, edge_index);

        let mut mm_graph = MultiModalGraphData::new(graph);

        // Give every node both Text and Image features.
        for i in 0..3 {
            let node_data = MultiModalNodeData::new(i)
                .add_modality(Modality::Text, randn(&[768]).unwrap())
                .add_modality(Modality::Image, randn(&[2048]).unwrap());
            mm_graph.add_node_data(node_data);
        }

        assert_eq!(mm_graph.available_modalities.len(), 2);
        assert_eq!(mm_graph.get_modality_data(Modality::Text).len(), 3);
        assert_eq!(
            mm_graph
                .get_complete_nodes(&[Modality::Text, Modality::Image])
                .len(),
            3
        );

        // Every node has both modalities, so coverage is 100%.
        let stats = mm_graph.modality_statistics();
        assert_eq!(stats[&Modality::Text], 1.0);
        assert_eq!(stats[&Modality::Image], 1.0);
    }

    // The constructor stores the configured dimensions unchanged.
    #[test]
    fn test_cross_modal_attention() {
        let modalities = vec![Modality::Text, Modality::Image];
        let mut modality_dims = HashMap::new();
        modality_dims.insert(Modality::Text, 768);
        modality_dims.insert(Modality::Image, 2048);

        let attention = CrossModalGraphAttention::new(
            modalities,
            modality_dims,
            256, // feature_dim
            128, // attention_dim
            4,   // num_heads
            0.1, // dropout
        );

        assert_eq!(attention.feature_dim, 256);
        assert_eq!(attention.attention_dim, 128);
        assert_eq!(attention.num_heads, 4);
    }

    // Weighted-sum fusion preserves the [nodes, feature_dim] shape.
    #[test]
    fn test_multimodal_fusion() {
        let modalities = vec![Modality::Text, Modality::Image];
        let fusion = MultiModalFusion::new(FusionStrategy::WeightedSum, modalities, 128);

        let mut modality_features = HashMap::new();
        modality_features.insert(Modality::Text, randn(&[3, 128]).unwrap());
        modality_features.insert(Modality::Image, randn(&[3, 128]).unwrap());

        let fused = fusion.fuse_features(&modality_features);
        assert_eq!(fused.shape().dims(), &[3, 128]);
    }

    // Contrastive loss on random projections is strictly positive.
    #[test]
    fn test_contrastive_learning() {
        let modalities = vec![Modality::Text, Modality::Image];
        let mut modality_dims = HashMap::new();
        modality_dims.insert(Modality::Text, 768);
        modality_dims.insert(Modality::Image, 2048);

        let contrastive = MultiModalContrastiveLearning::new(
            modalities,
            modality_dims,
            256,  // projection_dim
            0.07, // temperature
        );

        let text_features = randn(&[4, 768]).unwrap();
        let image_features = randn(&[4, 2048]).unwrap();

        let loss = contrastive.contrastive_loss(
            Modality::Text,
            &text_features,
            Modality::Image,
            &image_features,
        );

        assert!(loss > 0.0);
    }

    // The synthetic generator respects node count and probabilistic
    // modality coverage.
    #[test]
    fn test_synthetic_multimodal_graph() {
        let modalities = vec![Modality::Text, Modality::Image, Modality::Audio];
        let mm_graph = utils::create_synthetic_multimodal_graph(5, 64, modalities);

        assert_eq!(mm_graph.graph.num_nodes, 5);
        // Each modality appears with probability 0.8 per node, so the set
        // may contain fewer than 3 entries.
        assert!(mm_graph.available_modalities.len() <= 3);

        assert!(!mm_graph.node_data.is_empty());

        let stats = mm_graph.modality_statistics();
        for coverage in stats.values() {
            assert!(*coverage >= 0.0 && *coverage <= 1.0);
        }
    }

    // Quality metrics include per-modality moments and cross-modal
    // consistency keys.
    #[test]
    fn test_multimodal_quality_evaluation() {
        let modalities = vec![Modality::Text, Modality::Image];
        let mm_graph = utils::create_synthetic_multimodal_graph(4, 32, modalities);

        let mut representations = HashMap::new();
        representations.insert(Modality::Text, randn(&[4, 128]).unwrap());
        representations.insert(Modality::Image, randn(&[4, 128]).unwrap());

        let metrics = utils::evaluate_multimodal_quality(&mm_graph, &representations);

        assert!(metrics.contains_key("Text_mean"));
        assert!(metrics.contains_key("Image_variance"));

        let consistency_keys: Vec<_> = metrics
            .keys()
            .filter(|k| k.contains("consistency"))
            .collect();
        assert!(!consistency_keys.is_empty());
    }

    // Alignment tasks reference valid node ids and non-empty feature
    // tensors.
    #[test]
    fn test_alignment_task_generation() {
        let modalities = vec![Modality::Text, Modality::Image];
        let mm_graph = utils::create_synthetic_multimodal_graph(3, 32, modalities);

        let tasks = utils::generate_alignment_tasks(&mm_graph, Modality::Text, Modality::Image, 5);

        assert!(tasks.len() <= 5);

        for (node_id, source, target) in &tasks {
            assert!(*node_id < 3);
            assert!(!source
                .to_vec()
                .expect("conversion should succeed")
                .is_empty());
            assert!(!target
                .to_vec()
                .expect("conversion should succeed")
                .is_empty());
        }
    }
}