sklears_semi_supervised/
multi_view_graph.rs

//! Multi-view graph learning methods for semi-supervised learning
//!
//! This module provides advanced graph learning algorithms that can handle
//! multiple views or modalities of data, enabling more robust semi-supervised
//! learning on complex, multi-modal datasets.
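//!
//! A minimal usage sketch (illustrative only: the import paths below assume the
//! crate re-exports these items at its root and may need adjusting to the real
//! crate layout):
//!
//! ```text
//! use scirs2_core::array;
//! use sklears_semi_supervised::MultiViewGraphLearning;
//!
//! let view1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
//! let view2 = array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]];
//!
//! let graph = MultiViewGraphLearning::new()
//!     .k_neighbors(2)
//!     .combination_method("weighted".to_string())
//!     .fit(&[view1.view(), view2.view()])?;
//! // `graph` is a symmetric 3 x 3 affinity matrix combining both views.
//! ```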

use scirs2_core::ndarray_ext::{Array2, ArrayView1, ArrayView2};
use scirs2_core::random::rand_prelude::*;
use scirs2_core::random::Random;
use sklears_core::error::SklearsError;
use std::collections::HashMap;

/// Multi-view graph learning that constructs graphs from multiple data views
#[derive(Clone)]
pub struct MultiViewGraphLearning {
    /// Number of neighbors for k-NN graph construction
    pub k_neighbors: usize,
    /// Weights for combining different views
    pub view_weights: Vec<f64>,
    /// Method for combining views: "weighted", "union", "intersection", "adaptive"
    pub combination_method: String,
    /// Regularization parameter for graph structure learning
    pub regularization: f64,
    /// Maximum iterations for optimization
    pub max_iter: usize,
    /// Convergence tolerance
    pub tolerance: f64,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl MultiViewGraphLearning {
    /// Create a new multi-view graph learning instance
    pub fn new() -> Self {
        Self {
            k_neighbors: 5,
            view_weights: vec![],
            combination_method: "weighted".to_string(),
            regularization: 0.1,
            max_iter: 100,
            tolerance: 1e-6,
            random_state: None,
        }
    }

    /// Set the number of neighbors for k-NN graph construction
    pub fn k_neighbors(mut self, k: usize) -> Self {
        self.k_neighbors = k;
        self
    }

    /// Set the weights for combining different views
    pub fn view_weights(mut self, weights: Vec<f64>) -> Self {
        self.view_weights = weights;
        self
    }

    /// Set the method for combining views
    pub fn combination_method(mut self, method: String) -> Self {
        self.combination_method = method;
        self
    }

    /// Set the regularization parameter
    pub fn regularization(mut self, reg: f64) -> Self {
        self.regularization = reg;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Learn a unified graph from multiple views of data
    pub fn fit(&self, views: &[ArrayView2<f64>]) -> Result<Array2<f64>, SklearsError> {
        if views.is_empty() {
            return Err(SklearsError::InvalidInput("No views provided".to_string()));
        }

        let n_samples = views[0].nrows();

        // Validate that all views have the same number of samples
        for view in views.iter() {
            if view.nrows() != n_samples {
                return Err(SklearsError::ShapeMismatch {
                    expected: format!("All views should have {} samples", n_samples),
                    actual: format!("View has {} samples", view.nrows()),
                });
            }
        }

        // Construct graphs for each view
        let view_graphs = self.construct_view_graphs(views)?;

        // Combine graphs according to the specified method
        let combined_graph = self.combine_graphs(&view_graphs)?;

        Ok(combined_graph)
    }

    /// Construct k-NN graphs for each view
    fn construct_view_graphs(
        &self,
        views: &[ArrayView2<f64>],
    ) -> Result<Vec<Array2<f64>>, SklearsError> {
        let mut graphs = Vec::new();

        for view in views.iter() {
            let graph = self.construct_knn_graph(view)?;
            graphs.push(graph);
        }

        Ok(graphs)
    }

    /// Construct a k-NN graph from a single view
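    ///
    /// Each sample is linked to its `k_neighbors` nearest neighbors and each
    /// edge carries an RBF weight with unit bandwidth, `w_ij = exp(-d_ij^2 / 2)`,
    /// where `d_ij` is the Euclidean distance; the adjacency matrix is then
    /// symmetrized by averaging `w_ij` and `w_ji`.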
    #[allow(non_snake_case)]
    fn construct_knn_graph(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            let mut distances: Vec<(f64, usize)> = Vec::new();

            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push((dist, j));
                }
            }

            // Sort by distance and take the k nearest neighbors
            distances.sort_by(|a, b| a.0.total_cmp(&b.0));

            for (dist, j) in distances.iter().take(self.k_neighbors.min(distances.len())) {
                let weight = (-dist.powi(2) / 2.0).exp(); // RBF kernel with unit bandwidth
                graph[[i, *j]] = weight;
            }
        }

        // Make the graph symmetric by averaging w_ij and w_ji
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                graph[[i, j]] = avg_weight;
                graph[[j, i]] = avg_weight;
            }
        }

        Ok(graph)
    }

    /// Combine multiple view graphs into a unified graph
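    ///
    /// The behaviour of each method on a single edge whose per-view weights are,
    /// say, 0.8 and 0.4 (values chosen purely for illustration):
    /// - "weighted": convex combination with `view_weights` (uniform by default),
    ///   e.g. 0.5 * 0.8 + 0.5 * 0.4 = 0.6
    /// - "union": element-wise maximum, e.g. max(0.8, 0.4) = 0.8
    /// - "intersection": element-wise minimum, e.g. min(0.8, 0.4) = 0.4
    /// - "adaptive": weighted combination with weights learned from graph agreement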
    fn combine_graphs(&self, graphs: &[Array2<f64>]) -> Result<Array2<f64>, SklearsError> {
        if graphs.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No graphs to combine".to_string(),
            ));
        }

        let n_samples = graphs[0].nrows();
        let mut combined = Array2::<f64>::zeros((n_samples, n_samples));

        match self.combination_method.as_str() {
            "weighted" => {
                let weights = if self.view_weights.is_empty() {
                    vec![1.0 / graphs.len() as f64; graphs.len()]
                } else {
                    self.view_weights.clone()
                };

                if weights.len() != graphs.len() {
                    return Err(SklearsError::InvalidInput(
                        "Number of weights must match number of views".to_string(),
                    ));
                }

                for (i, graph) in graphs.iter().enumerate() {
                    combined += &(graph * weights[i]);
                }
            }
            "union" => {
                for graph in graphs.iter() {
                    for i in 0..n_samples {
                        for j in 0..n_samples {
                            combined[[i, j]] = combined[[i, j]].max(graph[[i, j]]);
                        }
                    }
                }
            }
            "intersection" => {
                combined = graphs[0].clone();
                for graph in graphs.iter().skip(1) {
                    for i in 0..n_samples {
                        for j in 0..n_samples {
                            combined[[i, j]] = combined[[i, j]].min(graph[[i, j]]);
                        }
                    }
                }
            }
            "adaptive" => {
                combined = self.adaptive_combination(graphs)?;
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown combination method: {}",
                    self.combination_method
                )));
            }
        }

        Ok(combined)
    }

    /// Adaptive combination that learns optimal weights for views
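    ///
    /// Starting from uniform weights, each iteration recomputes the combined
    /// graph, resets every view's weight to its agreement with that combined
    /// graph (see `compute_graph_agreement`), renormalizes the weights to sum
    /// to one, and stops once the total weight change drops below `tolerance`.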
    fn adaptive_combination(&self, graphs: &[Array2<f64>]) -> Result<Array2<f64>, SklearsError> {
        let n_views = graphs.len();
        let n_samples = graphs[0].nrows();

        // Initialize weights uniformly
        let mut weights = vec![1.0 / n_views as f64; n_views];

        for _iter in 0..self.max_iter {
            let old_weights = weights.clone();

            // Compute current combined graph
            let mut combined = Array2::<f64>::zeros((n_samples, n_samples));
            for (i, graph) in graphs.iter().enumerate() {
                combined += &(graph * weights[i]);
            }

            // Update weights based on agreement with combined graph
            for i in 0..n_views {
                let agreement = self.compute_graph_agreement(&graphs[i], &combined);
                weights[i] = agreement;
            }

            // Normalize weights
            let weight_sum: f64 = weights.iter().sum();
            if weight_sum > 0.0 {
                for w in weights.iter_mut() {
                    *w /= weight_sum;
                }
            }

            // Check convergence
            let weight_change: f64 = weights
                .iter()
                .zip(old_weights.iter())
                .map(|(w1, w2)| (w1 - w2).abs())
                .sum();

            if weight_change < self.tolerance {
                break;
            }
        }

        // Compute final combined graph
        let mut combined = Array2::<f64>::zeros((n_samples, n_samples));
        for (i, graph) in graphs.iter().enumerate() {
            combined += &(graph * weights[i]);
        }

        Ok(combined)
    }

    /// Compute agreement between two graphs
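    ///
    /// Returns the mean over all entries of `1 / (1 + |a_ij - b_ij|)`, so
    /// identical graphs score 1.0 and the score falls towards 0 as the entries
    /// diverge.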
    fn compute_graph_agreement(&self, graph1: &Array2<f64>, graph2: &Array2<f64>) -> f64 {
        let mut agreement = 0.0;
        let mut total = 0.0;

        for i in 0..graph1.nrows() {
            for j in 0..graph1.ncols() {
                let diff = (graph1[[i, j]] - graph2[[i, j]]).abs();
                agreement += 1.0 / (1.0 + diff);
                total += 1.0;
            }
        }

        if total > 0.0 {
            agreement / total
        } else {
            0.0
        }
    }

    /// Compute Euclidean distance between two vectors
    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        x1.iter()
            .zip(x2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl Default for MultiViewGraphLearning {
    fn default() -> Self {
        Self::new()
    }
}

/// Heterogeneous graph learning for mixed data types
#[derive(Clone)]
pub struct HeterogeneousGraphLearning {
    /// Node types in the heterogeneous graph
    pub node_types: Vec<String>,
    /// Edge types connecting different node types
    pub edge_types: Vec<(String, String)>,
    /// Weights for different edge types
    pub edge_weights: HashMap<(String, String), f64>,
    /// Embedding dimensions for each node type
    pub embedding_dims: HashMap<String, usize>,
    /// Number of neighbors for each edge type
    pub k_neighbors: HashMap<(String, String), usize>,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl HeterogeneousGraphLearning {
    /// Create a new heterogeneous graph learning instance
    pub fn new() -> Self {
        Self {
            node_types: vec![],
            edge_types: vec![],
            edge_weights: HashMap::new(),
            embedding_dims: HashMap::new(),
            k_neighbors: HashMap::new(),
            random_state: None,
        }
    }

    /// Set node types
    pub fn node_types(mut self, types: Vec<String>) -> Self {
        self.node_types = types;
        self
    }

    /// Set edge types
    pub fn edge_types(mut self, types: Vec<(String, String)>) -> Self {
        self.edge_types = types;
        self
    }

    /// Set weights for edge types
    pub fn edge_weights(mut self, weights: HashMap<(String, String), f64>) -> Self {
        self.edge_weights = weights;
        self
    }

    /// Set embedding dimensions for node types
    pub fn embedding_dims(mut self, dims: HashMap<String, usize>) -> Self {
        self.embedding_dims = dims;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Learn embeddings for heterogeneous graph
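    ///
    /// In this simplified implementation the returned embedding for each node
    /// type is a copy of that type's input feature matrix; the randomly
    /// initialized embeddings below only serve as placeholders.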
    pub fn fit(
        &self,
        data: &HashMap<String, ArrayView2<f64>>,
    ) -> Result<HashMap<String, Array2<f64>>, SklearsError> {
        if data.is_empty() {
            return Err(SklearsError::InvalidInput("No data provided".to_string()));
        }

        let mut embeddings = HashMap::new();
        let mut rng = match self.random_state {
            Some(seed) => Random::seed(seed),
            None => Random::seed(42), // default seed for reproducible behavior
        };

        // Initialize random placeholder embeddings for each node type
        for (node_type, node_data) in data.iter() {
            let embed_dim = self.embedding_dims.get(node_type).unwrap_or(&64);
            let n_nodes = node_data.nrows();

            let mut embedding = Array2::<f64>::zeros((n_nodes, *embed_dim));
            for i in 0..n_nodes {
                for j in 0..*embed_dim {
                    embedding[[i, j]] = rng.random_range(-1.0..1.0);
                }
            }

            embeddings.insert(node_type.clone(), embedding);
        }

        // Simplified learning step: use the input features as the final
        // embeddings, replacing the random placeholders above. A full
        // implementation would propagate information along the typed edges.
        for (node_type, node_data) in data.iter() {
            let features = node_data.to_owned();
            embeddings.insert(node_type.clone(), features);
        }

        Ok(embeddings)
    }
}

impl Default for HeterogeneousGraphLearning {
    fn default() -> Self {
        Self::new()
    }
}

/// Temporal graph learning for time-evolving graphs
#[derive(Clone)]
pub struct TemporalGraphLearning {
    /// Window size for temporal analysis
    pub window_size: usize,
    /// Decay factor for temporal weighting
    pub temporal_decay: f64,
    /// Method for temporal aggregation: "mean", "weighted", "attention"
    pub aggregation_method: String,
    /// Number of neighbors for graph construction
    pub k_neighbors: usize,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}

impl TemporalGraphLearning {
    /// Create a new temporal graph learning instance
    pub fn new() -> Self {
        Self {
            window_size: 5,
            temporal_decay: 0.9,
            aggregation_method: "weighted".to_string(),
            k_neighbors: 5,
            random_state: None,
        }
    }

    /// Set window size
    pub fn window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    /// Set temporal decay factor
    pub fn temporal_decay(mut self, decay: f64) -> Self {
        self.temporal_decay = decay;
        self
    }

    /// Set aggregation method
    pub fn aggregation_method(mut self, method: String) -> Self {
        self.aggregation_method = method;
        self
    }

    /// Set number of neighbors
    pub fn k_neighbors(mut self, k: usize) -> Self {
        self.k_neighbors = k;
        self
    }

    /// Set random state
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Learn from temporal graph snapshots
    pub fn fit(&self, snapshots: &[ArrayView2<f64>]) -> Result<Array2<f64>, SklearsError> {
        if snapshots.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No snapshots provided".to_string(),
            ));
        }

        let n_samples = snapshots[0].nrows();

        // Validate that all snapshots have the same number of samples
        for snapshot in snapshots.iter() {
            if snapshot.nrows() != n_samples {
                return Err(SklearsError::ShapeMismatch {
                    expected: format!("All snapshots should have {} samples", n_samples),
                    actual: format!("Snapshot has {} samples", snapshot.nrows()),
                });
            }
        }

        // Construct graphs for each snapshot
        let graphs = self.construct_temporal_graphs(snapshots)?;

        // Aggregate temporal graphs
        let aggregated_graph = self.aggregate_temporal_graphs(&graphs)?;

        Ok(aggregated_graph)
    }

    /// Construct graphs for temporal snapshots
    fn construct_temporal_graphs(
        &self,
        snapshots: &[ArrayView2<f64>],
    ) -> Result<Vec<Array2<f64>>, SklearsError> {
        let mut graphs = Vec::new();

        for snapshot in snapshots.iter() {
            let graph = self.construct_knn_graph(snapshot)?;
            graphs.push(graph);
        }

        Ok(graphs)
    }

    /// Construct k-NN graph from snapshot data
    #[allow(non_snake_case)]
    fn construct_knn_graph(&self, X: &ArrayView2<f64>) -> Result<Array2<f64>, SklearsError> {
        let n_samples = X.nrows();
        let mut graph = Array2::<f64>::zeros((n_samples, n_samples));

        for i in 0..n_samples {
            let mut distances: Vec<(f64, usize)> = Vec::new();

            for j in 0..n_samples {
                if i != j {
                    let dist = self.euclidean_distance(&X.row(i), &X.row(j));
                    distances.push((dist, j));
                }
            }

            // Sort by distance and take the k nearest neighbors
            distances.sort_by(|a, b| a.0.total_cmp(&b.0));

            for (dist, j) in distances.iter().take(self.k_neighbors.min(distances.len())) {
                let weight = (-dist.powi(2) / 2.0).exp(); // RBF kernel with unit bandwidth
                graph[[i, *j]] = weight;
            }
        }

        // Make the graph symmetric by averaging w_ij and w_ji
        for i in 0..n_samples {
            for j in i + 1..n_samples {
                let avg_weight = (graph[[i, j]] + graph[[j, i]]) / 2.0;
                graph[[i, j]] = avg_weight;
                graph[[j, i]] = avg_weight;
            }
        }

        Ok(graph)
    }

    /// Aggregate temporal graphs based on the aggregation method
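    ///
    /// "mean" simply averages the snapshot graphs; "weighted" combines them
    /// with exponentially scaled weights derived from `temporal_decay`,
    /// normalized to sum to one; "attention" uses the density-based weights
    /// from `compute_attention_weights`.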
    fn aggregate_temporal_graphs(
        &self,
        graphs: &[Array2<f64>],
    ) -> Result<Array2<f64>, SklearsError> {
        if graphs.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No graphs to aggregate".to_string(),
            ));
        }

        let n_samples = graphs[0].nrows();
        let mut aggregated = Array2::<f64>::zeros((n_samples, n_samples));

        match self.aggregation_method.as_str() {
            "mean" => {
                for graph in graphs.iter() {
                    aggregated += graph;
                }
                aggregated /= graphs.len() as f64;
            }
            "weighted" => {
                let total_weight: f64 = (0..graphs.len())
                    .map(|i| self.temporal_decay.powi(i as i32))
                    .sum();

                // Later (more recent) snapshots receive larger weights:
                // snapshot t gets decay^(T-1-t) / sum_s decay^s.
                let n_snapshots = graphs.len();
                for (i, graph) in graphs.iter().enumerate() {
                    let weight =
                        self.temporal_decay.powi((n_snapshots - 1 - i) as i32) / total_weight;
                    aggregated += &(graph * weight);
                }
            }
            "attention" => {
                // Simple attention mechanism - in practice this would be more sophisticated
                let weights = self.compute_attention_weights(graphs)?;
                for (i, graph) in graphs.iter().enumerate() {
                    aggregated += &(graph * weights[i]);
                }
            }
            _ => {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown aggregation method: {}",
                    self.aggregation_method
                )));
            }
        }

        Ok(aggregated)
    }

    /// Compute attention weights for temporal graphs
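    ///
    /// This is a lightweight stand-in for attention: each snapshot graph is
    /// weighted by its edge density (the fraction of non-zero entries), with
    /// the densities normalized to sum to one.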
    fn compute_attention_weights(&self, graphs: &[Array2<f64>]) -> Result<Vec<f64>, SklearsError> {
        let n_graphs = graphs.len();
        let mut weights = vec![1.0 / n_graphs as f64; n_graphs];

        // Simple implementation: weight by graph density
        let mut densities = Vec::new();
        for graph in graphs.iter() {
            let density = graph.iter().filter(|&&x| x > 0.0).count() as f64 / (graph.len() as f64);
            densities.push(density);
        }

        let total_density: f64 = densities.iter().sum();
        if total_density > 0.0 {
            for (i, density) in densities.iter().enumerate() {
                weights[i] = density / total_density;
            }
        }

        Ok(weights)
    }

    /// Compute Euclidean distance between two vectors
    fn euclidean_distance(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
        x1.iter()
            .zip(x2.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl Default for TemporalGraphLearning {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::array;

    #[test]
    fn test_multi_view_graph_learning() {
        let view1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let view2 = array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]];
        let views = vec![view1.view(), view2.view()];

        let mvgl = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("weighted".to_string());

        let result = mvgl.fit(&views);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));

        // Check that diagonal is zero (no self-loops)
        assert_eq!(graph[[0, 0]], 0.0);
        assert_eq!(graph[[1, 1]], 0.0);
        assert_eq!(graph[[2, 2]], 0.0);

        // Check symmetry
        assert_abs_diff_eq!(graph[[0, 1]], graph[[1, 0]], epsilon = 1e-10);
        assert_abs_diff_eq!(graph[[0, 2]], graph[[2, 0]], epsilon = 1e-10);
        assert_abs_diff_eq!(graph[[1, 2]], graph[[2, 1]], epsilon = 1e-10);
    }

    #[test]
    fn test_multi_view_graph_union() {
        let view1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let view2 = array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]];
        let views = vec![view1.view(), view2.view()];

        let mvgl = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("union".to_string());

        let result = mvgl.fit(&views);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));
    }
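
    // Additional illustrative test (a sketch mirroring the union test above):
    // exercises the "intersection" combination path on the same toy views.
    #[test]
    fn test_multi_view_graph_intersection() {
        let view1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let view2 = array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]];
        let views = vec![view1.view(), view2.view()];

        let mvgl = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("intersection".to_string());

        let result = mvgl.fit(&views);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().dim(), (3, 3));
    }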

    #[test]
    fn test_multi_view_graph_adaptive() {
        let view1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let view2 = array![[2.0, 1.0], [3.0, 2.0], [4.0, 3.0]];
        let views = vec![view1.view(), view2.view()];

        let mvgl = MultiViewGraphLearning::new()
            .k_neighbors(2)
            .combination_method("adaptive".to_string())
            .max_iter(10)
            .tolerance(1e-4);

        let result = mvgl.fit(&views);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));
    }

    #[test]
    fn test_heterogeneous_graph_learning() {
        let type1_data = array![[1.0, 2.0], [2.0, 3.0]];
        let type2_data = array![[3.0, 4.0], [4.0, 5.0]];
        let mut data = HashMap::new();
        data.insert("type1".to_string(), type1_data.view());
        data.insert("type2".to_string(), type2_data.view());

        let hgl = HeterogeneousGraphLearning::new()
            .node_types(vec!["type1".to_string(), "type2".to_string()]);

        let result = hgl.fit(&data);
        assert!(result.is_ok());

        let embeddings = result.unwrap();
        assert!(embeddings.contains_key("type1"));
        assert!(embeddings.contains_key("type2"));
        assert_eq!(embeddings["type1"].dim(), (2, 2));
        assert_eq!(embeddings["type2"].dim(), (2, 2));
    }

    #[test]
    fn test_temporal_graph_learning() {
        let snapshot1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let snapshot2 = array![[1.1, 2.1], [2.1, 3.1], [3.1, 4.1]];
        let snapshot3 = array![[1.2, 2.2], [2.2, 3.2], [3.2, 4.2]];
        let snapshots = vec![snapshot1.view(), snapshot2.view(), snapshot3.view()];

        let tgl = TemporalGraphLearning::new()
            .window_size(3)
            .temporal_decay(0.9)
            .aggregation_method("weighted".to_string())
            .k_neighbors(2);

        let result = tgl.fit(&snapshots);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));

        // Check that diagonal is zero (no self-loops)
        assert_eq!(graph[[0, 0]], 0.0);
        assert_eq!(graph[[1, 1]], 0.0);
        assert_eq!(graph[[2, 2]], 0.0);
    }
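
    // Additional illustrative test (a sketch mirroring the test above):
    // exercises the "mean" aggregation path over the same snapshots.
    #[test]
    fn test_temporal_graph_mean() {
        let snapshot1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let snapshot2 = array![[1.1, 2.1], [2.1, 3.1], [3.1, 4.1]];
        let snapshots = vec![snapshot1.view(), snapshot2.view()];

        let tgl = TemporalGraphLearning::new()
            .aggregation_method("mean".to_string())
            .k_neighbors(2);

        let result = tgl.fit(&snapshots);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().dim(), (3, 3));
    }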

    #[test]
    fn test_temporal_graph_attention() {
        let snapshot1 = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let snapshot2 = array![[1.1, 2.1], [2.1, 3.1], [3.1, 4.1]];
        let snapshots = vec![snapshot1.view(), snapshot2.view()];

        let tgl = TemporalGraphLearning::new()
            .aggregation_method("attention".to_string())
            .k_neighbors(2);

        let result = tgl.fit(&snapshots);
        assert!(result.is_ok());

        let graph = result.unwrap();
        assert_eq!(graph.dim(), (3, 3));
    }

    #[test]
    fn test_multi_view_graph_error_cases() {
        let mvgl = MultiViewGraphLearning::new();

        // Test with empty views
        let result = mvgl.fit(&[]);
        assert!(result.is_err());

        // Test with mismatched dimensions
        let view1 = array![[1.0, 2.0], [2.0, 3.0]];
        let view2 = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let views = vec![view1.view(), view2.view()];

        let result = mvgl.fit(&views);
        assert!(result.is_err());
    }

    #[test]
    fn test_temporal_graph_error_cases() {
        let tgl = TemporalGraphLearning::new();

        // Test with empty snapshots
        let result = tgl.fit(&[]);
        assert!(result.is_err());

        // Test with mismatched dimensions
        let snapshot1 = array![[1.0, 2.0], [2.0, 3.0]];
        let snapshot2 = array![[3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let snapshots = vec![snapshot1.view(), snapshot2.view()];

        let result = tgl.fit(&snapshots);
        assert!(result.is_err());
    }
}