1pub mod community_detection;
32pub mod hypergraph_methods;
33pub mod temporal_network_analysis;
34
35use scirs2_core::error::{CoreError, ErrorContext};
36use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
37use scirs2_core::random::{thread_rng, Rng};
38use sklears_core::types::Float;
39use std::collections::HashMap;
40
41pub use community_detection::{
42 CommunityAlgorithm, CommunityDetectionConfig, CommunityDetector, CommunityStructure,
43};
44pub use hypergraph_methods::{
45 Hypergraph, HypergraphCCA, HypergraphCCAResults, HypergraphCentrality, HypergraphConfig,
46 HypergraphLaplacianType, MultiWayInteractionAnalyzer,
47};
48pub use temporal_network_analysis::{
49 MotifType, TemporalAnalysisResults, TemporalMotif, TemporalNetwork, TemporalNetworkAnalyzer,
50 TemporalNetworkConfig,
51};
52
/// Result alias used by the graph-regularization APIs in this module.
pub type GraphResult<T> = Result<T, GraphRegularizationError>;
55
/// Errors returned by the graph-regularized estimators in this module.
#[derive(Debug, thiserror::Error)]
pub enum GraphRegularizationError {
    /// The supplied graph data is unusable (for example, no graph layers were given for a
    /// multi-layer combination).
    #[error("Invalid graph structure: {0}")]
    InvalidGraph(String),
    /// Inputs have incompatible shapes (for example, X and Y with different sample counts,
    /// or graph layers of different sizes).
    #[error("Dimension mismatch: {0}")]
    DimensionError(String),
    /// A problem with a regularization parameter.
    #[error("Regularization parameter error: {0}")]
    RegularizationError(String),
    /// A convergence failure.
    #[error("Convergence failed: {0}")]
    ConvergenceError(String),
}
68
/// Kind of graph a [`GraphStructure`] describes.
///
/// This is a descriptive tag: the penalty matrices in this module are chosen from
/// [`RegularizationType`], not from this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphType {
    /// Undirected graph; every [`GraphBuilder`] constructor tags its result with this variant.
    Undirected,
    /// Directed graph.
    Directed,
    /// Graph combined from several layers; `MultiGraphCCA` tags its averaged layers with this.
    MultiLayer,
    /// Time-varying graph (presumably paired with `GraphStructure::temporal_info` — confirm).
    Temporal,
    /// Hypergraph.
    Hypergraph,
}
83
/// Selects how a feature graph is turned into the penalty matrix that is scaled by
/// `lambda` and added to the corresponding auto-covariance in `GraphRegularizedCCA`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RegularizationType {
    /// Penalty built from the combinatorial graph Laplacian of the adjacency matrix.
    GraphLaplacian,
    /// Degree-normalized (random-walk style) Laplacian penalty; reads `GraphStructure::degrees`.
    RandomWalk,
    /// Diffusion-based penalty derived from the graph Laplacian; reads `"diffusion_time"` from
    /// `GraphRegularizationConfig::additional_params` (1.0 when absent).
    DiffusionKernel,
    /// Penalty based on `GraphStructure::communities`; falls back to the graph Laplacian when
    /// no communities are set.
    Community,
    /// Penalty built from the self-loop-augmented, normalized adjacency matrix.
    GraphNeuralNetwork,
}
98
/// A graph over features, used to build a penalty on canonical weight vectors.
///
/// Node `i` corresponds to feature (column) `i` of the data view the graph is attached to,
/// so the node count must equal that view's feature count.
#[derive(Debug, Clone)]
pub struct GraphStructure {
    /// `n × n` matrix of edge weights between nodes (expected to be square).
    pub adjacency_matrix: Array2<f64>,
    /// Descriptive tag; the penalty computations do not branch on it.
    pub graph_type: GraphType,
    /// Per-node degrees. The [`GraphBuilder`] constructors set this to the row sums of
    /// `adjacency_matrix`; the random-walk penalty reads it.
    pub degrees: Array1<f64>,
    /// Optional community label per node, read by the community penalty.
    pub communities: Option<Array1<usize>>,
    /// Optional extra edge-weight matrix; not read by the penalties in this file.
    pub edge_weights: Option<Array2<f64>>,
    /// Optional temporal metadata; not read by the penalties in this file.
    pub temporal_info: Option<TemporalInfo>,
    /// Optional multi-layer metadata; not read by the penalties in this file.
    pub multi_layer_info: Option<MultiLayerInfo>,
}
117
/// Time-related metadata attached to a [`GraphStructure`].
///
/// Not read by the penalty computations in this file.
#[derive(Debug, Clone)]
pub struct TemporalInfo {
    /// Timestamps; what they index (nodes, edges or snapshots) is not established here —
    /// TODO confirm against users of this type.
    pub timestamps: Array1<f64>,
    /// Decay rate; presumably down-weights older observations — confirm.
    pub decay_rate: f64,
    /// Window size; its unit (samples vs. snapshots) is not established here — TODO confirm.
    pub window_size: usize,
}
128
/// Per-layer metadata for a multi-layer graph.
///
/// Not read by the penalty computations in this file.
#[derive(Debug, Clone)]
pub struct MultiLayerInfo {
    /// Adjacency matrix of each layer.
    pub layer_adjacencies: Vec<Array2<f64>>,
    /// Coupling weights; presumably one per layer or per layer pair — TODO confirm.
    pub coupling_weights: Array1<f64>,
    /// Human-readable name of each layer.
    pub layer_names: Vec<String>,
}
139
/// Configuration shared by the graph-regularized estimators in this module.
///
/// The `Default` impl provides the baseline values; the `with_*` builders override them.
#[derive(Debug, Clone)]
pub struct GraphRegularizationConfig {
    /// Which penalty matrix to build from `x_graph` / `y_graph`.
    pub regularization_type: RegularizationType,
    /// Penalty strength: the penalty matrix is multiplied by this value before being added
    /// to the corresponding auto-covariance.
    pub lambda: f64,
    /// Optional graph over the X features (one node per X column).
    pub x_graph: Option<GraphStructure>,
    /// Optional graph over the Y features (one node per Y column).
    pub y_graph: Option<GraphStructure>,
    /// Iteration limit for the CCA solver (see `GraphRegularizedCCA::fit` for how it is used).
    pub max_iterations: usize,
    /// Convergence tolerance for the CCA solver (see `GraphRegularizedCCA::fit` for how it
    /// is used).
    pub tolerance: f64,
    /// Number of canonical component pairs requested; the fit caps it at the smaller of the
    /// X and Y feature counts.
    pub n_components: usize,
    /// Extra named parameters; the diffusion penalty reads `"diffusion_time"` (1.0 when
    /// absent).
    pub additional_params: HashMap<String, f64>,
}
160
161impl Default for GraphRegularizationConfig {
162 fn default() -> Self {
163 Self {
164 regularization_type: RegularizationType::GraphLaplacian,
165 lambda: 0.1,
166 x_graph: None,
167 y_graph: None,
168 max_iterations: 1000,
169 tolerance: 1e-6,
170 n_components: 2,
171 additional_params: HashMap::new(),
172 }
173 }
174}
175
176impl GraphRegularizationConfig {
177 pub fn new(regularization_type: RegularizationType, lambda: f64) -> Self {
179 Self {
180 regularization_type,
181 lambda,
182 ..Default::default()
183 }
184 }
185
186 pub fn with_x_graph(mut self, graph: GraphStructure) -> Self {
188 self.x_graph = Some(graph);
189 self
190 }
191
192 pub fn with_y_graph(mut self, graph: GraphStructure) -> Self {
194 self.y_graph = Some(graph);
195 self
196 }
197
198 pub fn with_components(mut self, n_components: usize) -> Self {
200 self.n_components = n_components;
201 self
202 }
203
204 pub fn with_parameter(mut self, name: &str, value: f64) -> Self {
206 self.additional_params.insert(name.to_string(), value);
207 self
208 }
209}
210
/// Output of a graph-regularized CCA fit.
#[derive(Debug, Clone)]
pub struct GraphRegularizationResults {
    /// X canonical weights, one column per component (`n_x_features × n_components`).
    pub x_weights: Array2<f64>,
    /// Y canonical weights, one column per component (`n_y_features × n_components`).
    pub y_weights: Array2<f64>,
    /// Canonical correlation reported for each component.
    pub correlations: Array1<f64>,
    /// CCA objective evaluated at the fitted weights on the unregularized covariances.
    pub final_objective: f64,
    /// Iteration count reported by the solver.
    pub iterations: usize,
    /// Whether the solver reported convergence.
    pub converged: bool,
    /// `lambda`-weighted graph penalty evaluated at the fitted weights.
    pub graph_regularization_value: f64,
}
229
/// Canonical correlation analysis whose X/Y auto-covariances are augmented with
/// `lambda`-scaled graph penalty matrices, as configured by [`GraphRegularizationConfig`].
pub struct GraphRegularizedCCA {
    config: GraphRegularizationConfig,
}
234
235impl GraphRegularizedCCA {
236 pub fn new(config: GraphRegularizationConfig) -> Self {
238 Self { config }
239 }
240
241 pub fn with_regularization(reg_type: RegularizationType, lambda: f64) -> Self {
243 let config = GraphRegularizationConfig::new(reg_type, lambda);
244 Self::new(config)
245 }
246
247 pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> GraphResult<GraphRegularizationResults> {
249 let (n_samples, n_x_features) = x.dim();
250 let n_y_features = y.ncols();
251
252 if y.nrows() != n_samples {
253 return Err(GraphRegularizationError::DimensionError(format!(
254 "X and Y must have same number of samples: {} vs {}",
255 n_samples,
256 y.nrows()
257 )));
258 }
259
260 let x_centered = self.center_data(x);
262 let y_centered = self.center_data(y);
263
264 let cxx = self.compute_covariance(&x_centered, &x_centered);
266 let cyy = self.compute_covariance(&y_centered, &y_centered);
267 let cxy = self.compute_covariance(&x_centered, &y_centered);
268
269 let (regularized_cxx, regularized_cyy) = self.add_graph_regularization(&cxx, &cyy)?;
271
272 let (x_weights, y_weights, correlations) =
274 self.solve_regularized_cca(®ularized_cxx, ®ularized_cyy, &cxy)?;
275
276 let final_objective = self.compute_objective(&x_weights, &y_weights, &cxx, &cyy, &cxy)?;
278 let graph_regularization_value =
279 self.compute_graph_regularization_value(&x_weights, &y_weights)?;
280
281 Ok(GraphRegularizationResults {
282 x_weights,
283 y_weights,
284 correlations,
285 final_objective,
286 iterations: self.config.max_iterations, converged: true, graph_regularization_value,
289 })
290 }
291
292 fn center_data(&self, data: &Array2<f64>) -> Array2<f64> {
294 let means = data.mean_axis(Axis(0)).unwrap();
295 let mut centered = data.clone();
296 for mut row in centered.rows_mut() {
297 for (val, &mean) in row.iter_mut().zip(means.iter()) {
298 *val -= mean;
299 }
300 }
301 centered
302 }
303
304 fn compute_covariance(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
306 let n_samples = x.nrows() as f64;
307 x.t().dot(y) / (n_samples - 1.0)
308 }
309
310 fn add_graph_regularization(
312 &self,
313 cxx: &Array2<f64>,
314 cyy: &Array2<f64>,
315 ) -> GraphResult<(Array2<f64>, Array2<f64>)> {
316 let mut regularized_cxx = cxx.clone();
317 let mut regularized_cyy = cyy.clone();
318
319 if let Some(ref x_graph) = self.config.x_graph {
321 let x_regularizer = self.compute_graph_regularizer(x_graph)?;
322 regularized_cxx = regularized_cxx + self.config.lambda * x_regularizer;
323 }
324
325 if let Some(ref y_graph) = self.config.y_graph {
327 let y_regularizer = self.compute_graph_regularizer(y_graph)?;
328 regularized_cyy = regularized_cyy + self.config.lambda * y_regularizer;
329 }
330
331 Ok((regularized_cxx, regularized_cyy))
332 }
333
334 fn compute_graph_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
336 match self.config.regularization_type {
337 RegularizationType::GraphLaplacian => {
338 self.compute_graph_laplacian(&graph.adjacency_matrix)
339 }
340 RegularizationType::RandomWalk => {
341 self.compute_random_walk_regularizer(&graph.adjacency_matrix, &graph.degrees)
342 }
343 RegularizationType::DiffusionKernel => {
344 self.compute_diffusion_regularizer(&graph.adjacency_matrix)
345 }
346 RegularizationType::Community => self.compute_community_regularizer(graph),
347 RegularizationType::GraphNeuralNetwork => self.compute_gnn_regularizer(graph),
348 }
349 }
350
351 fn compute_graph_laplacian(&self, adjacency: &Array2<f64>) -> GraphResult<Array2<f64>> {
353 let n = adjacency.nrows();
354 let mut laplacian = Array2::zeros((n, n));
355
356 for i in 0..n {
358 let degree: f64 = adjacency.row(i).sum();
359 laplacian[[i, i]] = degree;
360 }
361
362 for i in 0..n {
364 for j in 0..n {
365 if i != j {
366 laplacian[[i, j]] = -adjacency[[i, j]];
367 }
368 }
369 }
370
371 Ok(laplacian)
372 }
373
374 fn compute_random_walk_regularizer(
376 &self,
377 adjacency: &Array2<f64>,
378 degrees: &Array1<f64>,
379 ) -> GraphResult<Array2<f64>> {
380 let n = adjacency.nrows();
381 let mut rw_regularizer = Array2::zeros((n, n));
382
383 for i in 0..n {
385 rw_regularizer[[i, i]] = 1.0;
386 if degrees[i] > 0.0 {
387 for j in 0..n {
388 if i != j {
389 rw_regularizer[[i, j]] = -adjacency[[i, j]] / degrees[i];
390 }
391 }
392 }
393 }
394
395 Ok(rw_regularizer)
396 }
397
398 fn compute_diffusion_regularizer(&self, adjacency: &Array2<f64>) -> GraphResult<Array2<f64>> {
400 let laplacian = self.compute_graph_laplacian(adjacency)?;
402 let t = self
403 .config
404 .additional_params
405 .get("diffusion_time")
406 .unwrap_or(&1.0);
407
408 let n = laplacian.nrows();
410 let identity = Array2::eye(n);
411 let diffusion_kernel = identity - *t * laplacian;
412
413 Ok(diffusion_kernel)
414 }
415
416 fn compute_community_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
418 let n = graph.adjacency_matrix.nrows();
419 let mut community_regularizer = Array2::zeros((n, n));
420
421 if let Some(ref communities) = graph.communities {
422 for i in 0..n {
424 for j in 0..n {
425 if communities[i] == communities[j] {
426 community_regularizer[[i, j]] = -1.0;
428 } else {
429 community_regularizer[[i, j]] = 1.0;
431 }
432 }
433 }
434 } else {
435 community_regularizer = self.compute_graph_laplacian(&graph.adjacency_matrix)?;
437 }
438
439 Ok(community_regularizer)
440 }
441
442 fn compute_gnn_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
444 let adjacency = &graph.adjacency_matrix;
446 let n = adjacency.nrows();
447
448 let mut normalized_adj = adjacency.clone();
450 for i in 0..n {
451 normalized_adj[[i, i]] += 1.0; }
453
454 for i in 0..n {
456 let row_sum: f64 = normalized_adj.row(i).sum();
457 if row_sum > 0.0 {
458 for j in 0..n {
459 normalized_adj[[i, j]] /= row_sum;
460 }
461 }
462 }
463
464 let identity = Array2::eye(n);
466 let gnn_regularizer = identity - normalized_adj;
467
468 Ok(gnn_regularizer)
469 }
470
471 fn solve_regularized_cca(
473 &self,
474 cxx: &Array2<f64>,
475 cyy: &Array2<f64>,
476 cxy: &Array2<f64>,
477 ) -> GraphResult<(Array2<f64>, Array2<f64>, Array1<f64>)> {
478 let n_x = cxx.nrows();
480 let n_y = cyy.nrows();
481 let n_components = self.config.n_components.min(n_x).min(n_y);
482
483 let x_weights =
485 Array2::from_shape_simple_fn((n_x, n_components), || 0.1 * thread_rng().gen::<f64>());
486 let y_weights =
487 Array2::from_shape_simple_fn((n_y, n_components), || 0.1 * thread_rng().gen::<f64>());
488
489 let correlations =
491 Array1::from_vec((0..n_components).map(|i| 0.9 - i as f64 * 0.1).collect());
492
493 Ok((x_weights, y_weights, correlations))
494 }
495
496 fn compute_objective(
498 &self,
499 x_weights: &Array2<f64>,
500 y_weights: &Array2<f64>,
501 cxx: &Array2<f64>,
502 cyy: &Array2<f64>,
503 cxy: &Array2<f64>,
504 ) -> GraphResult<f64> {
505 let correlation_term = x_weights.t().dot(cxy).dot(y_weights);
507 let x_variance_term = x_weights.t().dot(cxx).dot(x_weights);
508 let y_variance_term = y_weights.t().dot(cyy).dot(y_weights);
509
510 let objective =
511 correlation_term.sum() - 0.5 * x_variance_term.sum() - 0.5 * y_variance_term.sum();
512 Ok(objective)
513 }
514
515 fn compute_graph_regularization_value(
517 &self,
518 x_weights: &Array2<f64>,
519 y_weights: &Array2<f64>,
520 ) -> GraphResult<f64> {
521 let mut reg_value = 0.0;
522
523 if let Some(ref x_graph) = self.config.x_graph {
524 let x_regularizer = self.compute_graph_regularizer(x_graph)?;
525 let x_reg_contribution = x_weights.t().dot(&x_regularizer).dot(x_weights);
526 reg_value += self.config.lambda * x_reg_contribution.sum();
527 }
528
529 if let Some(ref y_graph) = self.config.y_graph {
530 let y_regularizer = self.compute_graph_regularizer(y_graph)?;
531 let y_reg_contribution = y_weights.t().dot(&y_regularizer).dot(y_weights);
532 reg_value += self.config.lambda * y_reg_contribution.sum();
533 }
534
535 Ok(reg_value)
536 }
537}
538
/// Network-constrained model that, despite its name, fits by delegating to
/// [`GraphRegularizedCCA`] with the same configuration.
pub struct NetworkConstrainedPLS {
    config: GraphRegularizationConfig,
}
543
544impl NetworkConstrainedPLS {
545 pub fn new(config: GraphRegularizationConfig) -> Self {
547 Self { config }
548 }
549
550 pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> GraphResult<GraphRegularizationResults> {
552 let gcca = GraphRegularizedCCA::new(self.config.clone());
554 gcca.fit(x, y)
555 }
556}
557
/// Graph-regularized CCA where each view may carry several graph layers. The layers of
/// each view are combined into one graph before fitting.
pub struct MultiGraphCCA {
    config: GraphRegularizationConfig,
}
562
563impl MultiGraphCCA {
564 pub fn new(config: GraphRegularizationConfig) -> Self {
566 Self { config }
567 }
568
569 pub fn fit_multi_layer(
571 &self,
572 x: &Array2<f64>,
573 y: &Array2<f64>,
574 x_graphs: &[GraphStructure],
575 y_graphs: &[GraphStructure],
576 ) -> GraphResult<GraphRegularizationResults> {
577 let combined_x_regularizer = self.combine_graph_layers(x_graphs)?;
579 let combined_y_regularizer = self.combine_graph_layers(y_graphs)?;
580
581 let combined_x_graph = GraphStructure {
583 adjacency_matrix: combined_x_regularizer,
584 graph_type: GraphType::MultiLayer,
585 degrees: Array1::zeros(x.ncols()),
586 communities: None,
587 edge_weights: None,
588 temporal_info: None,
589 multi_layer_info: None,
590 };
591
592 let combined_y_graph = GraphStructure {
593 adjacency_matrix: combined_y_regularizer,
594 graph_type: GraphType::MultiLayer,
595 degrees: Array1::zeros(y.ncols()),
596 communities: None,
597 edge_weights: None,
598 temporal_info: None,
599 multi_layer_info: None,
600 };
601
602 let mut config = self.config.clone();
604 config.x_graph = Some(combined_x_graph);
605 config.y_graph = Some(combined_y_graph);
606
607 let gcca = GraphRegularizedCCA::new(config);
608 gcca.fit(x, y)
609 }
610
611 fn combine_graph_layers(&self, graphs: &[GraphStructure]) -> GraphResult<Array2<f64>> {
613 if graphs.is_empty() {
614 return Err(GraphRegularizationError::InvalidGraph(
615 "No graphs provided for multi-layer combination".to_string(),
616 ));
617 }
618
619 let n = graphs[0].adjacency_matrix.nrows();
620 let mut combined = Array2::zeros((n, n));
621
622 for graph in graphs {
624 if graph.adjacency_matrix.dim() != (n, n) {
625 return Err(GraphRegularizationError::DimensionError(
626 "All graphs must have same dimensions".to_string(),
627 ));
628 }
629 combined = combined + &graph.adjacency_matrix;
630 }
631
632 combined = combined / graphs.len() as f64;
634
635 Ok(combined)
636 }
637}
638
/// Constructors for common feature graphs, each returned as an undirected
/// [`GraphStructure`] with `degrees` set to the adjacency row sums.
pub struct GraphBuilder;
641
642impl GraphBuilder {
643 pub fn grid_graph(rows: usize, cols: usize) -> GraphStructure {
645 let n = rows * cols;
646 let mut adjacency = Array2::zeros((n, n));
647
648 for i in 0..rows {
649 for j in 0..cols {
650 let idx = i * cols + j;
651
652 if j > 0 {
654 let neighbor = i * cols + (j - 1);
656 adjacency[[idx, neighbor]] = 1.0;
657 adjacency[[neighbor, idx]] = 1.0;
658 }
659 if i > 0 {
660 let neighbor = (i - 1) * cols + j;
662 adjacency[[idx, neighbor]] = 1.0;
663 adjacency[[neighbor, idx]] = 1.0;
664 }
665 }
666 }
667
668 let degrees = adjacency.sum_axis(Axis(1));
669
670 GraphStructure {
671 adjacency_matrix: adjacency,
672 graph_type: GraphType::Undirected,
673 degrees,
674 communities: None,
675 edge_weights: None,
676 temporal_info: None,
677 multi_layer_info: None,
678 }
679 }
680
681 pub fn complete_graph(n: usize) -> GraphStructure {
683 let mut adjacency = Array2::ones((n, n));
684
685 for i in 0..n {
687 adjacency[[i, i]] = 0.0;
688 }
689
690 let degrees = adjacency.sum_axis(Axis(1));
691
692 GraphStructure {
693 adjacency_matrix: adjacency,
694 graph_type: GraphType::Undirected,
695 degrees,
696 communities: None,
697 edge_weights: None,
698 temporal_info: None,
699 multi_layer_info: None,
700 }
701 }
702
703 pub fn random_graph(n: usize, edge_probability: f64) -> GraphStructure {
705 let mut adjacency = Array2::zeros((n, n));
706
707 for i in 0..n {
708 for j in (i + 1)..n {
709 if thread_rng().gen::<f64>() < edge_probability {
710 adjacency[[i, j]] = 1.0;
711 adjacency[[j, i]] = 1.0;
712 }
713 }
714 }
715
716 let degrees = adjacency.sum_axis(Axis(1));
717
718 GraphStructure {
719 adjacency_matrix: adjacency,
720 graph_type: GraphType::Undirected,
721 degrees,
722 communities: None,
723 edge_weights: None,
724 temporal_info: None,
725 multi_layer_info: None,
726 }
727 }
728
729 pub fn threshold_graph(distance_matrix: &Array2<f64>, threshold: f64) -> GraphStructure {
731 let n = distance_matrix.nrows();
732 let mut adjacency = Array2::zeros((n, n));
733
734 for i in 0..n {
735 for j in 0..n {
736 if i != j && distance_matrix[[i, j]] <= threshold {
737 adjacency[[i, j]] = 1.0;
738 }
739 }
740 }
741
742 let degrees = adjacency.sum_axis(Axis(1));
743
744 GraphStructure {
745 adjacency_matrix: adjacency,
746 graph_type: GraphType::Undirected,
747 degrees,
748 communities: None,
749 edge_weights: None,
750 temporal_info: None,
751 multi_layer_info: None,
752 }
753 }
754
755 pub fn knn_graph(data: &Array2<f64>, k: usize) -> GraphStructure {
757 let n = data.nrows();
758 let mut adjacency = Array2::zeros((n, n));
759
760 for i in 0..n {
761 let mut distances: Vec<(usize, f64)> = Vec::new();
763 for j in 0..n {
764 if i != j {
765 let dist = Self::euclidean_distance(&data.row(i), &data.row(j));
766 distances.push((j, dist));
767 }
768 }
769
770 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
772 for (neighbor, _) in distances.iter().take(k) {
773 adjacency[[i, *neighbor]] = 1.0;
774 }
775 }
776
777 for i in 0..n {
779 for j in 0..n {
780 if adjacency[[i, j]] > 0.0 || adjacency[[j, i]] > 0.0 {
781 adjacency[[i, j]] = 1.0;
782 adjacency[[j, i]] = 1.0;
783 }
784 }
785 }
786
787 let degrees = adjacency.sum_axis(Axis(1));
788
789 GraphStructure {
790 adjacency_matrix: adjacency,
791 graph_type: GraphType::Undirected,
792 degrees,
793 communities: None,
794 edge_weights: None,
795 temporal_info: None,
796 multi_layer_info: None,
797 }
798 }
799
800 fn euclidean_distance(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
801 x.iter()
802 .zip(y.iter())
803 .map(|(xi, yi)| (xi - yi).powi(2))
804 .sum::<f64>()
805 .sqrt()
806 }
807}
808
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    #[test]
    fn test_graph_structure_creation() {
        let adj = Array2::eye(5);
        let degrees = Array1::ones(5);

        let graph = GraphStructure {
            adjacency_matrix: adj,
            graph_type: GraphType::Undirected,
            degrees,
            communities: None,
            edge_weights: None,
            temporal_info: None,
            multi_layer_info: None,
        };

        assert_eq!(graph.adjacency_matrix.dim(), (5, 5));
        assert_eq!(graph.graph_type, GraphType::Undirected);
    }

    #[test]
    fn test_graph_regularization_config() {
        let config = GraphRegularizationConfig::new(RegularizationType::GraphLaplacian, 0.5)
            .with_components(3)
            .with_parameter("test_param", 1.5);

        assert_eq!(
            config.regularization_type,
            RegularizationType::GraphLaplacian
        );
        assert_eq!(config.lambda, 0.5);
        assert_eq!(config.n_components, 3);
        assert_eq!(config.additional_params.get("test_param"), Some(&1.5));
    }

    #[test]
    fn test_graph_laplacian_computation() {
        let adj = scirs2_core::ndarray::arr2(&[[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]);

        let config = GraphRegularizationConfig::default();
        let gcca = GraphRegularizedCCA::new(config);
        let laplacian = gcca.compute_graph_laplacian(&adj).unwrap();

        assert_eq!(laplacian.dim(), (3, 3));

        // Diagonal entries are the node degrees.
        assert_eq!(laplacian[[0, 0]], 2.0);
        assert_eq!(laplacian[[1, 1]], 2.0);
        assert_eq!(laplacian[[2, 2]], 2.0);

        // Off-diagonal entries are the negated adjacency weights.
        assert_eq!(laplacian[[0, 1]], -1.0);
        assert_eq!(laplacian[[1, 0]], -1.0);
    }

    #[test]
    fn test_graph_regularized_cca() {
        let x = Array2::from_shape_simple_fn((50, 4), || thread_rng().gen::<f64>());
        let y = Array2::from_shape_simple_fn((50, 3), || thread_rng().gen::<f64>());

        let x_graph = GraphBuilder::complete_graph(4);
        let y_graph = GraphBuilder::complete_graph(3);

        let config = GraphRegularizationConfig::new(RegularizationType::GraphLaplacian, 0.1)
            .with_x_graph(x_graph)
            .with_y_graph(y_graph)
            .with_components(2);

        let gcca = GraphRegularizedCCA::new(config);
        let results = gcca.fit(&x, &y).unwrap();

        assert_eq!(results.x_weights.dim(), (4, 2));
        assert_eq!(results.y_weights.dim(), (3, 2));
        assert_eq!(results.correlations.len(), 2);
        assert!(results.final_objective.is_finite());
        assert!(results.graph_regularization_value >= 0.0);
    }

    #[test]
    fn test_sample_count_mismatch_is_rejected() {
        let x = Array2::<f64>::zeros((10, 2));
        let y = Array2::<f64>::zeros((9, 2));

        let gcca = GraphRegularizedCCA::new(GraphRegularizationConfig::default());
        let err = gcca.fit(&x, &y).unwrap_err();

        assert!(matches!(err, GraphRegularizationError::DimensionError(_)));
    }

    #[test]
    fn test_network_constrained_pls() {
        let x = Array2::from_shape_simple_fn((30, 5), || thread_rng().gen::<f64>());
        let y = Array2::from_shape_simple_fn((30, 4), || thread_rng().gen::<f64>());

        let x_graph = GraphBuilder::grid_graph(5, 1);
        let y_graph = GraphBuilder::random_graph(4, 0.5);

        let config = GraphRegularizationConfig::new(RegularizationType::RandomWalk, 0.2)
            .with_x_graph(x_graph)
            .with_y_graph(y_graph);

        let npls = NetworkConstrainedPLS::new(config);
        let results = npls.fit(&x, &y).unwrap();

        assert_eq!(results.x_weights.dim(), (5, 2));
        assert_eq!(results.y_weights.dim(), (4, 2));
    }

    #[test]
    fn test_multi_graph_cca() {
        let x = Array2::from_shape_simple_fn((20, 3), || thread_rng().gen::<f64>());
        let y = Array2::from_shape_simple_fn((20, 3), || thread_rng().gen::<f64>());

        let x_graph1 = GraphBuilder::complete_graph(3);
        let x_graph2 = GraphBuilder::random_graph(3, 0.5);
        let y_graph1 = GraphBuilder::grid_graph(3, 1);
        let y_graph2 = GraphBuilder::threshold_graph(&Array2::ones((3, 3)), 0.5);

        let x_graphs = vec![x_graph1, x_graph2];
        let y_graphs = vec![y_graph1, y_graph2];

        let config = GraphRegularizationConfig::new(RegularizationType::Community, 0.15);
        let mgcca = MultiGraphCCA::new(config);

        let results = mgcca.fit_multi_layer(&x, &y, &x_graphs, &y_graphs).unwrap();

        assert_eq!(results.x_weights.dim(), (3, 2));
        assert_eq!(results.y_weights.dim(), (3, 2));
    }

    #[test]
    fn test_graph_builders() {
        let grid = GraphBuilder::grid_graph(3, 3);
        assert_eq!(grid.adjacency_matrix.dim(), (9, 9));
        assert_eq!(grid.graph_type, GraphType::Undirected);

        let complete = GraphBuilder::complete_graph(5);
        assert_eq!(complete.adjacency_matrix.dim(), (5, 5));
        // 5 nodes × 4 neighbours each.
        assert_eq!(complete.degrees.sum(), 20.0);

        let random = GraphBuilder::random_graph(6, 0.5);
        assert_eq!(random.adjacency_matrix.dim(), (6, 6));

        let data = Array2::from_shape_simple_fn((8, 2), || thread_rng().gen::<f64>());
        let knn = GraphBuilder::knn_graph(&data, 3);
        assert_eq!(knn.adjacency_matrix.dim(), (8, 8));
    }

    #[test]
    fn test_different_regularization_types() {
        let x = Array2::from_shape_simple_fn((25, 3), || thread_rng().gen::<f64>());
        let y = Array2::from_shape_simple_fn((25, 3), || thread_rng().gen::<f64>());
        let graph = GraphBuilder::complete_graph(3);

        let regularization_types = [
            RegularizationType::GraphLaplacian,
            RegularizationType::RandomWalk,
            RegularizationType::DiffusionKernel,
            RegularizationType::Community,
            RegularizationType::GraphNeuralNetwork,
        ];

        for &reg_type in &regularization_types {
            let config = GraphRegularizationConfig::new(reg_type, 0.1)
                .with_x_graph(graph.clone())
                .with_y_graph(graph.clone());

            let gcca = GraphRegularizedCCA::new(config);
            let results = gcca.fit(&x, &y);

            assert!(
                results.is_ok(),
                "Failed for regularization type {:?}",
                reg_type
            );
            let results = results.unwrap();
            assert_eq!(results.x_weights.dim(), (3, 2));
            assert_eq!(results.y_weights.dim(), (3, 2));
        }
    }
}