1use scirs2_core::ndarray_ext::{Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::rand_prelude::*;
9use scirs2_core::random::Random;
10use sklears_core::error::SklearsError;
11use std::collections::HashMap;
12
/// Configuration for learning a single affinity graph from several feature
/// views of the same samples.
///
/// Each view is turned into a symmetric k-nearest-neighbour graph and the
/// per-view graphs are merged according to `combination_method`.
#[derive(Debug, Clone)]
pub struct MultiViewGraphLearning {
    /// Number of nearest neighbours kept per sample when building each
    /// per-view graph (default `5`).
    pub k_neighbors: usize,
    /// Per-view weights used by the `"weighted"` combination. When empty,
    /// every view receives the same weight `1 / n_views` (default: empty).
    pub view_weights: Vec<f64>,
    /// How per-view graphs are merged: `"weighted"`, `"union"`,
    /// `"intersection"` or `"adaptive"` (default `"weighted"`).
    pub combination_method: String,
    /// Regularization strength (default `0.1`).
    ///
    /// NOTE(review): stored by the builder but not read by any method of this
    /// type in this file.
    pub regularization: f64,
    /// Maximum number of weight-update iterations of the `"adaptive"`
    /// combination (default `100`).
    pub max_iter: usize,
    /// The `"adaptive"` combination stops once the summed absolute change of
    /// the view weights drops below this value (default `1e-6`).
    pub tolerance: f64,
    /// Optional random seed (default `None`).
    ///
    /// NOTE(review): stored by the builder but not read by any method of this
    /// type in this file.
    pub random_state: Option<u64>,
}
31
32impl MultiViewGraphLearning {
33 pub fn new() -> Self {
35 Self {
36 k_neighbors: 5,
37 view_weights: vec![],
38 combination_method: "weighted".to_string(),
39 regularization: 0.1,
40 max_iter: 100,
41 tolerance: 1e-6,
42 random_state: None,
43 }
44 }
45
46 pub fn k_neighbors(mut self, k: usize) -> Self {
48 self.k_neighbors = k;
49 self
50 }
51
52 pub fn view_weights(mut self, weights: Vec<f64>) -> Self {
54 self.view_weights = weights;
55 self
56 }
57
58 pub fn combination_method(mut self, method: String) -> Self {
60 self.combination_method = method;
61 self
62 }
63
64 pub fn regularization(mut self, reg: f64) -> Self {
66 self.regularization = reg;
67 self
68 }
69
70 pub fn max_iter(mut self, max_iter: usize) -> Self {
72 self.max_iter = max_iter;
73 self
74 }
75
76 pub fn tolerance(mut self, tol: f64) -> Self {
78 self.tolerance = tol;
79 self
80 }
81
82 pub fn random_state(mut self, seed: u64) -> Self {
84 self.random_state = Some(seed);
85 self
86 }
87
88 pub fn fit(&self, views: &[ArrayView2<f64>]) -> Result<Array2<f64>, SklearsError> {
90 if views.is_empty() {
91 return Err(SklearsError::InvalidInput("No views provided".to_string()));
92 }
93
94 let n_samples = views[0].nrows();
95
96 for view in views.iter() {
98 if view.nrows() != n_samples {
99 return Err(SklearsError::ShapeMismatch {
100 expected: format!("All views should have {} samples", n_samples),
101 actual: format!("View has {} samples", view.nrows()),
102 });
103 }
104 }
105
106 let view_graphs = self.construct_view_graphs(views)?;
108
109 let combined_graph = self.combine_graphs(&view_graphs)?;
111
112 Ok(combined_graph)
113 }
114
115 fn construct_view_graphs(
117 &self,
118 views: &[ArrayView2<f64>],
119 ) -> Result<Vec<Array2<f64>>, SklearsError> {
120 let mut graphs = Vec::new();
121
122 for view in views.iter() {
123 let graph = self.construct_knn_graph(view)?;
124 graphs.push(graph);
125 }
126
127 Ok(graphs)
128 }
129
130 fn construct_knn_graph(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
132 let n_samples = X.nrows();
133 let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
134
135 for i in 0..n_samples {
136 let mut distances: Vec<(f64, usize)> = Vec::new();
137
138 for j in 0..n_samples {
139 if i != j {
140 let dist = self.euclidean_distance(&X.row(i), &X.row(j));
141 distances.push((dist, j));
142 }
143 }
144
145 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
147
148 for (dist, j) in distances.iter().take(self.k_neighbors.min(distances.len())) {
149 let weight = (-dist.powi(2) / 2.0).exp(); graph[[i, *j]] = weight;
151 }
152 }
153
154 for i in 0..n_samples {
156 for j in i + 1..n_samples {
157 let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
158 graph[[i, j]] = avg_weight;
159 graph[[j, i]] = avg_weight;
160 }
161 }
162
163 Ok(graph)
164 }
165
166 fn combine_graphs(&self, graphs: &[Array2<f64>]) -> Result<Array2<f64>, SklearsError> {
168 if graphs.is_empty() {
169 return Err(SklearsError::InvalidInput(
170 "No graphs to combine".to_string(),
171 ));
172 }
173
174 let n_samples = graphs[0].nrows();
175 let mut combined = Array2::<f64>::zeros((n_samples, n_samples));
176
177 match self.combination_method.as_str() {
178 "weighted" => {
179 let weights = if self.view_weights.is_empty() {
180 vec![1.0 / graphs.len() as f64; graphs.len()]
181 } else {
182 self.view_weights.clone()
183 };
184
185 if weights.len() != graphs.len() {
186 return Err(SklearsError::InvalidInput(
187 "Number of weights must match number of views".to_string(),
188 ));
189 }
190
191 for (i, graph) in graphs.iter().enumerate() {
192 combined += &(graph * weights[i]);
193 }
194 }
195 "union" => {
196 for graph in graphs.iter() {
197 for i in 0..n_samples {
198 for j in 0..n_samples {
199 combined[[i, j]] = combined[[i, j]].max(graph[[i, j]]);
200 }
201 }
202 }
203 }
204 "intersection" => {
205 combined = graphs[0].clone();
206 for graph in graphs.iter().skip(1) {
207 for i in 0..n_samples {
208 for j in 0..n_samples {
209 combined[[i, j]] = combined[[i, j]].min(graph[[i, j]]);
210 }
211 }
212 }
213 }
214 "adaptive" => {
215 combined = self.adaptive_combination(graphs)?;
216 }
217 _ => {
218 return Err(SklearsError::InvalidInput(format!(
219 "Unknown combination method: {}",
220 self.combination_method
221 )));
222 }
223 }
224
225 Ok(combined)
226 }
227
228 fn adaptive_combination(&self, graphs: &[Array2<f64>]) -> Result<Array2<f64>, SklearsError> {
230 let n_views = graphs.len();
231 let n_samples = graphs[0].nrows();
232
233 let mut weights = vec![1.0 / n_views as f64; n_views];
235
236 for _iter in 0..self.max_iter {
237 let old_weights = weights.clone();
238
239 let mut combined = Array2::<f64>::zeros((n_samples, n_samples));
241 for (i, graph) in graphs.iter().enumerate() {
242 combined += &(graph * weights[i]);
243 }
244
245 for i in 0..n_views {
247 let agreement = self.compute_graph_agreement(&graphs[i], &combined);
248 weights[i] = agreement;
249 }
250
251 let weight_sum: f64 = weights.iter().sum();
253 if weight_sum > 0.0 {
254 for w in weights.iter_mut() {
255 *w /= weight_sum;
256 }
257 }
258
259 let weight_change: f64 = weights
261 .iter()
262 .zip(old_weights.iter())
263 .map(|(w1, w2)| (w1 - w2).abs())
264 .sum();
265
266 if weight_change < self.tolerance {
267 break;
268 }
269 }
270
271 let mut combined = Array2::<f64>::zeros((n_samples, n_samples));
273 for (i, graph) in graphs.iter().enumerate() {
274 combined += &(graph * weights[i]);
275 }
276
277 Ok(combined)
278 }
279
280 fn compute_graph_agreement(&self, graph1: &Array2<f64>, graph2: &Array2<f64>) -> f64 {
282 let mut agreement = 0.0;
283 let mut total = 0.0;
284
285 for i in 0..graph1.nrows() {
286 for j in 0..graph1.ncols() {
287 let diff = (graph1[[i, j]] - graph2[[i, j]]).abs();
288 agreement += 1.0 / (1.0 + diff);
289 total += 1.0;
290 }
291 }
292
293 if total > 0.0 {
294 agreement / total
295 } else {
296 0.0
297 }
298 }
299
300 fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
302 x1.iter()
303 .zip(x2.iter())
304 .map(|(a, b)| (a - b).powi(2))
305 .sum::<f64>()
306 .sqrt()
307 }
308}
309
310impl Default for MultiViewGraphLearning {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Configuration for learning per-node-type embeddings on a heterogeneous
/// graph (a graph with several kinds of nodes and edges).
#[derive(Debug, Clone)]
pub struct HeterogeneousGraphLearning {
    /// Names of the node types (default: empty).
    pub node_types: Vec<String>,
    /// Edge types, each given as a pair of node-type names (default: empty).
    pub edge_types: Vec<(String, String)>,
    /// Weight per edge type, keyed by the pair of node-type names
    /// (default: empty).
    pub edge_weights: HashMap<(String, String), f64>,
    /// Embedding dimension per node type. `fit` falls back to `64` for node
    /// types without an entry (default: empty).
    pub embedding_dims: HashMap<String, usize>,
    /// Neighbour count per edge type (default: empty).
    ///
    /// NOTE(review): no builder method sets this field and no method of this
    /// type in this file reads it.
    pub k_neighbors: HashMap<(String, String), usize>,
    /// Optional random seed for embedding initialization (default `None`).
    pub random_state: Option<u64>,
}
332
333impl HeterogeneousGraphLearning {
334 pub fn new() -> Self {
336 Self {
337 node_types: vec![],
338 edge_types: vec![],
339 edge_weights: HashMap::new(),
340 embedding_dims: HashMap::new(),
341 k_neighbors: HashMap::new(),
342 random_state: None,
343 }
344 }
345
346 pub fn node_types(mut self, types: Vec<String>) -> Self {
348 self.node_types = types;
349 self
350 }
351
352 pub fn edge_types(mut self, types: Vec<(String, String)>) -> Self {
354 self.edge_types = types;
355 self
356 }
357
358 pub fn edge_weights(mut self, weights: HashMap<(String, String), f64>) -> Self {
360 self.edge_weights = weights;
361 self
362 }
363
364 pub fn embedding_dims(mut self, dims: HashMap<String, usize>) -> Self {
366 self.embedding_dims = dims;
367 self
368 }
369
370 pub fn random_state(mut self, seed: u64) -> Self {
372 self.random_state = Some(seed);
373 self
374 }
375
376 pub fn fit(
378 &self,
379 data: &HashMap<String, ArrayView2<f64>>,
380 ) -> Result<HashMap<String, Array2<f64>>, SklearsError> {
381 if data.is_empty() {
382 return Err(SklearsError::InvalidInput("No data provided".to_string()));
383 }
384
385 let mut embeddings = HashMap::new();
386 let mut rng = if let Some(seed) = self.random_state {
387 Random::seed(42)
388 } else {
389 Random::seed(42) };
391
392 for (node_type, node_data) in data.iter() {
394 let embed_dim = self.embedding_dims.get(node_type).unwrap_or(&64);
395 let n_nodes = node_data.nrows();
396
397 let mut embedding = Array2::<f64>::zeros((n_nodes, *embed_dim));
399 for i in 0..n_nodes {
400 for j in 0..*embed_dim {
401 embedding[[i, j]] = rng.random_range(-1.0..1.0);
402 }
403 }
404
405 embeddings.insert(node_type.clone(), embedding);
406 }
407
408 for (node_type, node_data) in data.iter() {
411 let features = node_data.to_owned();
412 embeddings.insert(node_type.clone(), features);
413 }
414
415 Ok(embeddings)
416 }
417}
418
419impl Default for HeterogeneousGraphLearning {
420 fn default() -> Self {
421 Self::new()
422 }
423}
424
/// Configuration for learning a single affinity graph from a sequence of
/// snapshots of the same samples.
///
/// Each snapshot is turned into a symmetric k-nearest-neighbour graph and the
/// per-snapshot graphs are merged according to `aggregation_method`.
#[derive(Debug, Clone)]
pub struct TemporalGraphLearning {
    /// Temporal window size (default `5`).
    ///
    /// NOTE(review): stored by the builder but not read by any method of this
    /// type in this file; `fit` uses every snapshot it is given.
    pub window_size: usize,
    /// Base of the geometric weights used by the `"weighted"` aggregation:
    /// the snapshot at index `i` gets a weight proportional to
    /// `temporal_decay^i` (default `0.9`).
    pub temporal_decay: f64,
    /// How per-snapshot graphs are merged: `"mean"`, `"weighted"` or
    /// `"attention"` (default `"weighted"`).
    pub aggregation_method: String,
    /// Number of nearest neighbours kept per sample when building each
    /// per-snapshot graph (default `5`).
    pub k_neighbors: usize,
    /// Optional random seed (default `None`).
    ///
    /// NOTE(review): stored by the builder but not read by any method of this
    /// type in this file.
    pub random_state: Option<u64>,
}
439
440impl TemporalGraphLearning {
441 pub fn new() -> Self {
443 Self {
444 window_size: 5,
445 temporal_decay: 0.9,
446 aggregation_method: "weighted".to_string(),
447 k_neighbors: 5,
448 random_state: None,
449 }
450 }
451
452 pub fn window_size(mut self, size: usize) -> Self {
454 self.window_size = size;
455 self
456 }
457
458 pub fn temporal_decay(mut self, decay: f64) -> Self {
460 self.temporal_decay = decay;
461 self
462 }
463
464 pub fn aggregation_method(mut self, method: String) -> Self {
466 self.aggregation_method = method;
467 self
468 }
469
470 pub fn k_neighbors(mut self, k: usize) -> Self {
472 self.k_neighbors = k;
473 self
474 }
475
476 pub fn random_state(mut self, seed: u64) -> Self {
478 self.random_state = Some(seed);
479 self
480 }
481
482 pub fn fit(&self, snapshots: &[ArrayView2<f64>]) -> Result<Array2<f64>, SklearsError> {
484 if snapshots.is_empty() {
485 return Err(SklearsError::InvalidInput(
486 "No snapshots provided".to_string(),
487 ));
488 }
489
490 let n_samples = snapshots[0].nrows();
491
492 for snapshot in snapshots.iter() {
494 if snapshot.nrows() != n_samples {
495 return Err(SklearsError::ShapeMismatch {
496 expected: format!("All snapshots should have {} samples", n_samples),
497 actual: format!("Snapshot has {} samples", snapshot.nrows()),
498 });
499 }
500 }
501
502 let graphs = self.construct_temporal_graphs(snapshots)?;
504
505 let aggregated_graph = self.aggregate_temporal_graphs(&graphs)?;
507
508 Ok(aggregated_graph)
509 }
510
511 fn construct_temporal_graphs(
513 &self,
514 snapshots: &[ArrayView2<f64>],
515 ) -> Result<Vec<Array2<f64>>, SklearsError> {
516 let mut graphs = Vec::new();
517
518 for snapshot in snapshots.iter() {
519 let graph = self.construct_knn_graph(snapshot)?;
520 graphs.push(graph);
521 }
522
523 Ok(graphs)
524 }
525
526 fn construct_knn_graph(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
528 let n_samples = X.nrows();
529 let mut graph = Array2::<f64>::zeros((n_samples, n_samples));
530
531 for i in 0..n_samples {
532 let mut distances: Vec<(f64, usize)> = Vec::new();
533
534 for j in 0..n_samples {
535 if i != j {
536 let dist = self.euclidean_distance(&X.row(i), &X.row(j));
537 distances.push((dist, j));
538 }
539 }
540
541 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
543
544 for (dist, j) in distances.iter().take(self.k_neighbors.min(distances.len())) {
545 let weight = (-dist.powi(2) / 2.0).exp(); graph[[i, *j]] = weight;
547 }
548 }
549
550 for i in 0..n_samples {
552 for j in i + 1..n_samples {
553 let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
554 graph[[i, j]] = avg_weight;
555 graph[[j, i]] = avg_weight;
556 }
557 }
558
559 Ok(graph)
560 }
561
562 fn aggregate_temporal_graphs(
564 &self,
565 graphs: &[Array2<f64>],
566 ) -> Result<Array2<f64>, SklearsError> {
567 if graphs.is_empty() {
568 return Err(SklearsError::InvalidInput(
569 "No graphs to aggregate".to_string(),
570 ));
571 }
572
573 let n_samples = graphs[0].nrows();
574 let mut aggregated = Array2::<f64>::zeros((n_samples, n_samples));
575
576 match self.aggregation_method.as_str() {
577 "mean" => {
578 for graph in graphs.iter() {
579 aggregated += graph;
580 }
581 aggregated /= graphs.len() as f64;
582 }
583 "weighted" => {
584 let total_weight: f64 = (0..graphs.len())
585 .map(|i| self.temporal_decay.powi(i as i32))
586 .sum();
587
588 for (i, graph) in graphs.iter().enumerate() {
589 let weight = self.temporal_decay.powi(i as i32) / total_weight;
590 aggregated += &(graph * weight);
591 }
592 }
593 "attention" => {
594 let weights = self.compute_attention_weights(graphs)?;
596 for (i, graph) in graphs.iter().enumerate() {
597 aggregated += &(graph * weights[i]);
598 }
599 }
600 _ => {
601 return Err(SklearsError::InvalidInput(format!(
602 "Unknown aggregation method: {}",
603 self.aggregation_method
604 )));
605 }
606 }
607
608 Ok(aggregated)
609 }
610
611 fn compute_attention_weights(&self, graphs: &[Array2<f64>]) -> Result<Vec<f64>, SklearsError> {
613 let n_graphs = graphs.len();
614 let mut weights = vec![1.0 / n_graphs as f64; n_graphs];
615
616 let mut densities = Vec::new();
618 for graph in graphs.iter() {
619 let density = graph.iter().filter(|&&x| x > 0.0).count() as f64 / (graph.len() as f64);
620 densities.push(density);
621 }
622
623 let total_density: f64 = densities.iter().sum();
624 if total_density > 0.0 {
625 for (i, density) in densities.iter().enumerate() {
626 weights[i] = density / total_density;
627 }
628 }
629
630 Ok(weights)
631 }
632
633 fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
635 x1.iter()
636 .zip(x2.iter())
637 .map(|(a, b)| (a - b).powi(2))
638 .sum::<f64>()
639 .sqrt()
640 }
641}
642
643impl Default for TemporalGraphLearning {
644 fn default() -> Self {
645 Self::new()
646 }
647}
648
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    /// Two views of the same three samples, two features each.
    fn three_sample_views() -> (Array2<f64>, Array2<f64>) {
        (
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
            array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]],
        )
    }

    /// Three slowly drifting snapshots of the same three samples.
    fn three_sample_snapshots() -> [Array2<f64>; 3] {
        [
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
            array![[1.1, 2.1], [2.1, 3.1], [3.1, 4.1]],
            array![[1.2, 2.2], [2.2, 3.2], [3.2, 4.2]],
        ]
    }

    /// A 2-sample and a 3-sample matrix, used to trigger shape errors.
    fn mismatched_pair() -> (Array2<f64>, Array2<f64>) {
        (
            array![[1.0, 2.0], [2.0, 3.0]],
            array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]],
        )
    }

    #[test]
    fn test_multi_view_graph_learning() {
        let (first, second) = three_sample_views();
        let views = vec![first.view(), second.view()];

        let learner = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("weighted".to_string());

        let outcome = learner.fit(&views);
        assert!(outcome.is_ok());

        let graph = outcome.unwrap();
        assert_eq!(graph.dim(), (3, 3));

        // No self-loops.
        for d in 0..3 {
            assert_eq!(graph[[d, d]], 0.0);
        }

        // The combined graph is symmetric.
        for (r, c) in [(0, 1), (0, 2), (1, 2)] {
            assert_abs_diff_eq!(graph[[r, c]], graph[[c, r]], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_multi_view_graph_union() {
        let (first, second) = three_sample_views();
        let views = vec![first.view(), second.view()];

        let learner = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("union".to_string());

        let outcome = learner.fit(&views);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dim(), (3, 3));
    }

    #[test]
    fn test_multi_view_graph_adaptive() {
        let (first, second) = three_sample_views();
        let views = vec![first.view(), second.view()];

        let learner = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("adaptive".to_string())
            .max_iter(10)
            .tolerance(1e-4);

        let outcome = learner.fit(&views);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dim(), (3, 3));
    }

    #[test]
    fn test_heterogeneous_graph_learning() {
        let first_type = array![[1.0, 2.0], [2.0, 3.0]];
        let second_type = array![[3.0, 4.0], [4.0, 5.0]];

        let mut data = HashMap::new();
        data.insert("type1".to_string(), first_type.view());
        data.insert("type2".to_string(), second_type.view());

        let learner = HeterogeneousGraphLearning::new()
            .node_types(vec!["type1".to_string(), "type2".to_string()]);

        let outcome = learner.fit(&data);
        assert!(outcome.is_ok());

        let embeddings = outcome.unwrap();
        for key in ["type1", "type2"] {
            assert!(embeddings.contains_key(key));
        }
        assert_eq!(embeddings["type1"].dim(), (2, 2));
        assert_eq!(embeddings["type2"].dim(), (2, 2));
    }

    #[test]
    fn test_temporal_graph_learning() {
        let frames = three_sample_snapshots();
        let snapshots = vec![frames[0].view(), frames[1].view(), frames[2].view()];

        let learner = TemporalGraphLearning::new()
            .window_size(3)
            .temporal_decay(0.9)
            .aggregation_method("weighted".to_string())
            .k_neighbors(2);

        let outcome = learner.fit(&snapshots);
        assert!(outcome.is_ok());

        let graph = outcome.unwrap();
        assert_eq!(graph.dim(), (3, 3));

        // No self-loops.
        for d in 0..3 {
            assert_eq!(graph[[d, d]], 0.0);
        }
    }

    #[test]
    fn test_temporal_graph_attention() {
        let frames = three_sample_snapshots();
        let snapshots = vec![frames[0].view(), frames[1].view()];

        let learner = TemporalGraphLearning::new()
            .aggregation_method("attention".to_string())
            .k_neighbors(2);

        let outcome = learner.fit(&snapshots);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dim(), (3, 3));
    }

    #[test]
    fn test_multi_view_graph_error_cases() {
        let learner = MultiViewGraphLearning::new();

        // An empty list of views is rejected.
        assert!(learner.fit(&[]).is_err());

        // Views with different sample counts are rejected.
        let (short, long) = mismatched_pair();
        let views = vec![short.view(), long.view()];
        assert!(learner.fit(&views).is_err());
    }

    #[test]
    fn test_temporal_graph_error_cases() {
        let learner = TemporalGraphLearning::new();

        // An empty list of snapshots is rejected.
        assert!(learner.fit(&[]).is_err());

        // Snapshots with different sample counts are rejected.
        let (short, long) = mismatched_pair();
        let snapshots = vec![short.view(), long.view()];
        assert!(learner.fit(&snapshots).is_err());
    }
}