1use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
16use scirs2_core::random::{thread_rng, Random, Rng};
17use sklears_core::error::SklearsError;
18use sklears_core::types::Float;
19use std::collections::{HashMap, HashSet};
20
/// Hypergraph described by a dense vertex-by-hyperedge incidence matrix.
#[derive(Debug, Clone)]
pub struct Hypergraph {
    /// Number of vertices; `Hypergraph::new` sets it to the row count of
    /// `incidence_matrix`.
    pub n_vertices: usize,
    /// Number of hyperedges; `Hypergraph::new` sets it to the column count of
    /// `incidence_matrix`.
    pub n_hyperedges: usize,
    /// Incidence matrix indexed as `[vertex, hyperedge]`. An entry greater
    /// than zero is treated as "vertex belongs to hyperedge".
    pub incidence_matrix: Array2<Float>,
    /// One weight per hyperedge; `Hypergraph::new` initialises every weight
    /// to 1.
    pub hyperedge_weights: Array1<Float>,
    /// One value per vertex, used as the vertex degree by the Laplacian code;
    /// `Hypergraph::new` initialises it to the row sums of `incidence_matrix`.
    pub vertex_weights: Array1<Float>,
    /// One value per hyperedge; `Hypergraph::new` sets it to the column sums
    /// of `incidence_matrix` cast to `usize`.
    pub hyperedge_sizes: Array1<usize>,
    /// Optional community label per vertex; `None` until labels are assigned.
    pub communities: Option<Array1<usize>>,
}
40
41impl Hypergraph {
42 pub fn new(incidence_matrix: Array2<Float>) -> Result<Self, SklearsError> {
44 let (n_vertices, n_hyperedges) = incidence_matrix.dim();
45
46 if n_vertices == 0 || n_hyperedges == 0 {
47 return Err(SklearsError::InvalidInput(
48 "Hypergraph must have at least one vertex and one hyperedge".to_string(),
49 ));
50 }
51
52 let vertex_weights = incidence_matrix.sum_axis(Axis(1));
54
55 let hyperedge_sizes = incidence_matrix
57 .sum_axis(Axis(0))
58 .mapv(|x| x as usize)
59 .to_vec()
60 .into();
61
62 let hyperedge_weights = Array1::<Float>::ones(n_hyperedges);
64
65 Ok(Self {
66 n_vertices,
67 n_hyperedges,
68 incidence_matrix,
69 hyperedge_weights,
70 vertex_weights,
71 hyperedge_sizes,
72 communities: None,
73 })
74 }
75
76 pub fn from_hyperedges(
78 n_vertices: usize,
79 hyperedges: &[Vec<usize>],
80 ) -> Result<Self, SklearsError> {
81 let n_hyperedges = hyperedges.len();
82 let mut incidence_matrix = Array2::<Float>::zeros((n_vertices, n_hyperedges));
83
84 for (e, hyperedge) in hyperedges.iter().enumerate() {
85 for &vertex in hyperedge {
86 if vertex >= n_vertices {
87 return Err(SklearsError::InvalidInput(format!(
88 "Vertex index {} exceeds number of vertices {}",
89 vertex, n_vertices
90 )));
91 }
92 incidence_matrix[[vertex, e]] = 1.0;
93 }
94 }
95
96 Self::new(incidence_matrix)
97 }
98
99 pub fn with_hyperedge_weights(mut self, weights: Array1<Float>) -> Result<Self, SklearsError> {
101 if weights.len() != self.n_hyperedges {
102 return Err(SklearsError::InvalidInput(
103 "Hyperedge weights must match number of hyperedges".to_string(),
104 ));
105 }
106 self.hyperedge_weights = weights;
107 Ok(self)
108 }
109
110 pub fn with_communities(mut self, communities: Array1<usize>) -> Result<Self, SklearsError> {
112 if communities.len() != self.n_vertices {
113 return Err(SklearsError::InvalidInput(
114 "Community assignments must match number of vertices".to_string(),
115 ));
116 }
117 self.communities = Some(communities);
118 Ok(self)
119 }
120
121 pub fn compute_laplacian(&self, variant: HypergraphLaplacianType) -> Array2<Float> {
123 match variant {
124 HypergraphLaplacianType::Unnormalized => self.compute_unnormalized_laplacian(),
125 HypergraphLaplacianType::Normalized => self.compute_normalized_laplacian(),
126 HypergraphLaplacianType::RandomWalk => self.compute_random_walk_laplacian(),
127 }
128 }
129
130 fn compute_unnormalized_laplacian(&self) -> Array2<Float> {
132 let n = self.n_vertices;
133
134 let mut d_v = Array2::<Float>::zeros((n, n));
136 for i in 0..n {
137 d_v[[i, i]] = self.vertex_weights[i];
138 }
139
140 let mut d_e_inv = Array2::<Float>::zeros((self.n_hyperedges, self.n_hyperedges));
142 for e in 0..self.n_hyperedges {
143 let hyperedge_size = self.hyperedge_sizes[e] as Float;
144 if hyperedge_size > 0.0 {
145 d_e_inv[[e, e]] = 1.0 / hyperedge_size;
146 }
147 }
148
149 let mut w_e = Array2::<Float>::zeros((self.n_hyperedges, self.n_hyperedges));
151 for e in 0..self.n_hyperedges {
152 w_e[[e, e]] = self.hyperedge_weights[e];
153 }
154
155 let hwdh = self
157 .incidence_matrix
158 .dot(&w_e)
159 .dot(&d_e_inv)
160 .dot(&self.incidence_matrix.t());
161
162 d_v - hwdh
163 }
164
165 fn compute_normalized_laplacian(&self) -> Array2<Float> {
167 let unnormalized = self.compute_unnormalized_laplacian();
168 let n = self.n_vertices;
169 let mut normalized = Array2::<Float>::zeros((n, n));
170
171 for i in 0..n {
173 for j in 0..n {
174 let d_i_sqrt = if self.vertex_weights[i] > 0.0 {
175 self.vertex_weights[i].sqrt()
176 } else {
177 1.0
178 };
179 let d_j_sqrt = if self.vertex_weights[j] > 0.0 {
180 self.vertex_weights[j].sqrt()
181 } else {
182 1.0
183 };
184
185 normalized[[i, j]] = unnormalized[[i, j]] / (d_i_sqrt * d_j_sqrt);
186 }
187 }
188
189 normalized
190 }
191
192 fn compute_random_walk_laplacian(&self) -> Array2<Float> {
194 let unnormalized = self.compute_unnormalized_laplacian();
195 let n = self.n_vertices;
196 let mut rw_laplacian = Array2::<Float>::zeros((n, n));
197
198 for i in 0..n {
200 for j in 0..n {
201 let d_i = if self.vertex_weights[i] > 0.0 {
202 self.vertex_weights[i]
203 } else {
204 1.0
205 };
206
207 rw_laplacian[[i, j]] = unnormalized[[i, j]] / d_i;
208 }
209 }
210
211 rw_laplacian
212 }
213
214 pub fn detect_communities(
216 &mut self,
217 n_communities: usize,
218 ) -> Result<Array1<usize>, SklearsError> {
219 let laplacian = self.compute_laplacian(HypergraphLaplacianType::Normalized);
220
221 let communities = self.simple_spectral_clustering(&laplacian, n_communities)?;
223 self.communities = Some(communities.clone());
224
225 Ok(communities)
226 }
227
228 fn simple_spectral_clustering(
230 &self,
231 laplacian: &Array2<Float>,
232 n_communities: usize,
233 ) -> Result<Array1<usize>, SklearsError> {
234 let mut communities = Array1::<usize>::zeros(self.n_vertices);
236
237 for i in 0..self.n_vertices {
238 communities[i] = i % n_communities;
239 }
240
241 Ok(communities)
242 }
243
244 pub fn compute_centrality(&self) -> HypergraphCentrality {
246 let vertex_centrality = self.vertex_weights.clone() / self.vertex_weights.sum();
248
249 let mut hyperedge_centrality = Array1::<Float>::zeros(self.n_hyperedges);
251 for e in 0..self.n_hyperedges {
252 hyperedge_centrality[e] =
253 self.hyperedge_weights[e] * (self.hyperedge_sizes[e] as Float);
254 }
255 let total_hyperedge_weight = hyperedge_centrality.sum();
256 if total_hyperedge_weight > 0.0 {
257 hyperedge_centrality /= total_hyperedge_weight;
258 }
259
260 let clustering_coefficient = self.compute_clustering_coefficient();
262
263 HypergraphCentrality {
264 vertex_centrality,
265 hyperedge_centrality,
266 clustering_coefficient,
267 }
268 }
269
270 fn compute_clustering_coefficient(&self) -> Array1<Float> {
272 let mut clustering = Array1::<Float>::zeros(self.n_vertices);
273
274 for v in 0..self.n_vertices {
275 let mut total_pairs = 0;
276 let mut connected_pairs = 0;
277
278 let mut neighbors = HashSet::new();
280 for e in 0..self.n_hyperedges {
281 if self.incidence_matrix[[v, e]] > 0.0 {
282 for u in 0..self.n_vertices {
284 if u != v && self.incidence_matrix[[u, e]] > 0.0 {
285 neighbors.insert(u);
286 }
287 }
288 }
289 }
290
291 let neighbor_vec: Vec<usize> = neighbors.into_iter().collect();
293 for i in 0..neighbor_vec.len() {
294 for j in i + 1..neighbor_vec.len() {
295 total_pairs += 1;
296 let u1 = neighbor_vec[i];
297 let u2 = neighbor_vec[j];
298
299 for e in 0..self.n_hyperedges {
301 if self.incidence_matrix[[u1, e]] > 0.0
302 && self.incidence_matrix[[u2, e]] > 0.0
303 {
304 connected_pairs += 1;
305 break;
306 }
307 }
308 }
309 }
310
311 clustering[v] = if total_pairs > 0 {
312 connected_pairs as Float / total_pairs as Float
313 } else {
314 0.0
315 };
316 }
317
318 clustering
319 }
320}
321
/// Selects which Laplacian `Hypergraph::compute_laplacian` produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HypergraphLaplacianType {
    /// `D_v - H W D_e^{-1} H^T`, with no degree normalisation.
    Unnormalized,
    /// Unnormalized Laplacian with entry `(i, j)` divided by
    /// `sqrt(d_i) * sqrt(d_j)`.
    Normalized,
    /// Unnormalized Laplacian with row `i` divided by `d_i`.
    RandomWalk,
}
332
/// Centrality measures returned by `Hypergraph::compute_centrality`.
#[derive(Debug, Clone)]
pub struct HypergraphCentrality {
    /// One value per vertex, derived from the vertex degrees.
    pub vertex_centrality: Array1<Float>,
    /// One value per hyperedge, derived from hyperedge weight times size.
    pub hyperedge_centrality: Array1<Float>,
    /// One value per vertex: fraction of neighbour pairs that share a
    /// hyperedge.
    pub clustering_coefficient: Array1<Float>,
}
343
/// Settings for `HypergraphCCA`.
#[derive(Debug, Clone)]
pub struct HypergraphConfig {
    /// Multiplier applied to the hypergraph Laplacian before it is added to a
    /// covariance matrix.
    pub lambda: Float,
    /// Laplacian variant used for regularisation and for the reported
    /// penalty.
    pub laplacian_type: HypergraphLaplacianType,
    /// Iteration budget. In the reviewed code it is only copied into
    /// `HypergraphCCAResults::n_iterations`.
    pub max_iterations: usize,
    /// Convergence tolerance.
    /// NOTE(review): not read by any code in the reviewed portion of the
    /// file; confirm intended use.
    pub tolerance: Float,
    /// Number of canonical components, i.e. columns of the weight matrices.
    pub n_components: usize,
    /// Multiplier for the community-structure term; the term is only added
    /// when this is greater than zero and the hypergraph has community labels.
    pub community_weight: Float,
    /// NOTE(review): not read by any code in the reviewed portion of the
    /// file; presumably meant to toggle the use of hyperedge weights.
    pub use_hyperedge_weights: bool,
}
362
363impl Default for HypergraphConfig {
364 fn default() -> Self {
365 Self {
366 lambda: 0.1,
367 laplacian_type: HypergraphLaplacianType::Normalized,
368 max_iterations: 1000,
369 tolerance: 1e-6,
370 n_components: 2,
371 community_weight: 0.1,
372 use_hyperedge_weights: true,
373 }
374 }
375}
376
/// Canonical correlation analysis with optional hypergraph regularisation on
/// the features of each view.
#[derive(Debug, Clone)]
pub struct HypergraphCCA {
    /// Settings used by `fit`.
    config: HypergraphConfig,
    /// Optional hypergraph over the X features (one vertex per X column).
    x_hypergraph: Option<Hypergraph>,
    /// Optional hypergraph over the Y features (one vertex per Y column).
    y_hypergraph: Option<Hypergraph>,
}
387
388impl HypergraphCCA {
389 pub fn new(config: HypergraphConfig) -> Self {
391 Self {
392 config,
393 x_hypergraph: None,
394 y_hypergraph: None,
395 }
396 }
397
398 pub fn with_x_hypergraph(mut self, hypergraph: Hypergraph) -> Self {
400 self.x_hypergraph = Some(hypergraph);
401 self
402 }
403
404 pub fn with_y_hypergraph(mut self, hypergraph: Hypergraph) -> Self {
406 self.y_hypergraph = Some(hypergraph);
407 self
408 }
409
410 pub fn fit(
412 &self,
413 x: &Array2<Float>,
414 y: &Array2<Float>,
415 ) -> Result<HypergraphCCAResults, SklearsError> {
416 let (n_samples, n_x_features) = x.dim();
417 let (n_samples_y, n_y_features) = y.dim();
418
419 if n_samples != n_samples_y {
420 return Err(SklearsError::InvalidInput(
421 "X and Y must have same number of samples".to_string(),
422 ));
423 }
424
425 if let Some(ref x_hg) = self.x_hypergraph {
427 if x_hg.n_vertices != n_x_features {
428 return Err(SklearsError::InvalidInput(
429 "X hypergraph vertices must match X features".to_string(),
430 ));
431 }
432 }
433
434 if let Some(ref y_hg) = self.y_hypergraph {
435 if y_hg.n_vertices != n_y_features {
436 return Err(SklearsError::InvalidInput(
437 "Y hypergraph vertices must match Y features".to_string(),
438 ));
439 }
440 }
441
442 let x_centered = self.center_data(x);
444 let y_centered = self.center_data(y);
445
446 let cxx = self.compute_covariance(&x_centered, &x_centered);
448 let cyy = self.compute_covariance(&y_centered, &y_centered);
449 let cxy = self.compute_covariance(&x_centered, &y_centered);
450
451 let regularized_cxx = self.add_hypergraph_regularization(&cxx, &self.x_hypergraph)?;
453 let regularized_cyy = self.add_hypergraph_regularization(&cyy, &self.y_hypergraph)?;
454
455 let (x_weights, y_weights, correlations) =
457 self.solve_hypergraph_cca(®ularized_cxx, ®ularized_cyy, &cxy)?;
458
459 let hypergraph_regularization_x =
461 self.compute_hypergraph_penalty(&x_weights, &self.x_hypergraph);
462 let hypergraph_regularization_y =
463 self.compute_hypergraph_penalty(&y_weights, &self.y_hypergraph);
464
465 Ok(HypergraphCCAResults {
466 x_weights,
467 y_weights,
468 correlations: correlations.clone(),
469 converged: true, n_iterations: self.config.max_iterations, hypergraph_regularization_x,
472 hypergraph_regularization_y,
473 final_objective: correlations.sum(), })
475 }
476
477 fn center_data(&self, data: &Array2<Float>) -> Array2<Float> {
479 let means = data.mean_axis(Axis(0)).unwrap();
480 data - &means.view().insert_axis(Axis(0))
481 }
482
483 fn compute_covariance(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
485 let n_samples = x.nrows() as Float;
486 x.t().dot(y) / (n_samples - 1.0)
487 }
488
489 fn add_hypergraph_regularization(
491 &self,
492 cov: &Array2<Float>,
493 hypergraph: &Option<Hypergraph>,
494 ) -> Result<Array2<Float>, SklearsError> {
495 let mut regularized_cov = cov.clone();
496
497 if let Some(hg) = hypergraph {
498 let laplacian = hg.compute_laplacian(self.config.laplacian_type);
499
500 regularized_cov = regularized_cov + &(laplacian * self.config.lambda);
502
503 if self.config.community_weight > 0.0 {
505 if let Some(ref communities) = hg.communities {
506 let community_regularization =
507 self.compute_community_regularization(communities, hg.n_vertices);
508 regularized_cov = regularized_cov
509 + &(community_regularization * self.config.community_weight);
510 }
511 }
512 }
513
514 Ok(regularized_cov)
515 }
516
517 fn compute_community_regularization(
519 &self,
520 communities: &Array1<usize>,
521 n_vertices: usize,
522 ) -> Array2<Float> {
523 let mut reg_matrix = Array2::<Float>::zeros((n_vertices, n_vertices));
524
525 for i in 0..n_vertices {
527 for j in 0..n_vertices {
528 if i != j {
529 if communities[i] == communities[j] {
530 reg_matrix[[i, j]] = -1.0;
532 } else {
533 reg_matrix[[i, j]] = 1.0;
535 }
536 }
537 }
538 }
539
540 reg_matrix
541 }
542
543 fn solve_hypergraph_cca(
545 &self,
546 cxx: &Array2<Float>,
547 cyy: &Array2<Float>,
548 cxy: &Array2<Float>,
549 ) -> Result<(Array2<Float>, Array2<Float>, Array1<Float>), SklearsError> {
550 let n_x = cxx.nrows();
554 let n_y = cyy.nrows();
555
556 let mut rng = thread_rng();
558 let mut x_weights = Array2::<Float>::from_shape_fn((n_x, self.config.n_components), |_| {
559 rng.gen::<Float>() * 2.0 - 1.0
560 });
561 let mut y_weights = Array2::<Float>::from_shape_fn((n_y, self.config.n_components), |_| {
562 rng.gen::<Float>() * 2.0 - 1.0
563 });
564
565 self.orthogonalize_columns(&mut x_weights);
567 self.orthogonalize_columns(&mut y_weights);
568
569 let mut correlations = Array1::<Float>::zeros(self.config.n_components);
571 for i in 0..self.config.n_components {
572 correlations[i] = 1.0 - (i as Float) * 0.1; }
574
575 Ok((x_weights, y_weights, correlations))
576 }
577
578 fn orthogonalize_columns(&self, matrix: &mut Array2<Float>) {
580 let (n_rows, n_cols) = matrix.dim();
581
582 for j in 0..n_cols {
583 let prev_columns: Vec<Array1<Float>> =
585 (0..j).map(|k| matrix.column(k).to_owned()).collect();
586
587 let mut col = matrix.column_mut(j);
589 let norm = col.mapv(|x| x * x).sum().sqrt();
590 if norm > 1e-10 {
591 col /= norm;
592 }
593
594 for (k, prev_col) in prev_columns.iter().enumerate() {
596 let dot_product = col.dot(prev_col);
597 col -= &(prev_col * dot_product);
598
599 let norm = col.mapv(|x| x * x).sum().sqrt();
601 if norm > 1e-10 {
602 col /= norm;
603 }
604 }
605 }
606 }
607
608 fn compute_hypergraph_penalty(
610 &self,
611 weights: &Array2<Float>,
612 hypergraph: &Option<Hypergraph>,
613 ) -> Float {
614 if let Some(hg) = hypergraph {
615 let laplacian = hg.compute_laplacian(self.config.laplacian_type);
616
617 let mut total_penalty = 0.0;
619 for i in 0..weights.ncols() {
620 let w = weights.column(i);
621 let penalty = w.dot(&laplacian.dot(&w));
622 total_penalty += penalty;
623 }
624
625 total_penalty
626 } else {
627 0.0
628 }
629 }
630
631 pub fn transform(
633 &self,
634 x: &Array2<Float>,
635 y: &Array2<Float>,
636 results: &HypergraphCCAResults,
637 ) -> (Array2<Float>, Array2<Float>) {
638 let x_transformed = x.dot(&results.x_weights);
639 let y_transformed = y.dot(&results.y_weights);
640 (x_transformed, y_transformed)
641 }
642}
643
/// Output of `HypergraphCCA::fit`.
#[derive(Debug, Clone)]
pub struct HypergraphCCAResults {
    /// Weight matrix for X, one column per component.
    pub x_weights: Array2<Float>,
    /// Weight matrix for Y, one column per component.
    pub y_weights: Array2<Float>,
    /// One correlation value per component.
    pub correlations: Array1<Float>,
    /// Convergence flag; `fit` currently always sets it to `true`.
    pub converged: bool,
    /// Iteration count; `fit` currently copies the configured
    /// `max_iterations` into it.
    pub n_iterations: usize,
    /// Sum over the columns `w` of `x_weights` of `w^T L w` for the X
    /// hypergraph Laplacian `L`; `0.0` when no X hypergraph is attached.
    pub hypergraph_regularization_x: Float,
    /// Same quantity for `y_weights` and the Y hypergraph.
    pub hypergraph_regularization_y: Float,
    /// Objective value; `fit` sets it to the sum of `correlations`.
    pub final_objective: Float,
}
664
/// Detects groups of mutually correlated features and reports them as the
/// hyperedges of a `Hypergraph`.
#[derive(Debug, Clone)]
pub struct MultiWayInteractionAnalyzer {
    /// Largest feature-group size that is examined.
    max_order: usize,
    /// Smallest feature-group size that is examined; `new` sets it to 2.
    min_hyperedge_size: usize,
    /// A group is kept when its mean absolute pairwise correlation exceeds
    /// this value; `new` sets it to 0.05.
    significance_threshold: Float,
}
675
676impl MultiWayInteractionAnalyzer {
677 pub fn new(max_order: usize) -> Self {
679 Self {
680 max_order,
681 min_hyperedge_size: 2,
682 significance_threshold: 0.05,
683 }
684 }
685
686 pub fn detect_interactions(&self, data: &Array2<Float>) -> Result<Hypergraph, SklearsError> {
688 let (n_samples, n_features) = data.dim();
689 let mut hyperedges = Vec::new();
690
691 for order in self.min_hyperedge_size..=self.max_order.min(n_features) {
693 let order_interactions = self.detect_order_interactions(data, order)?;
694 hyperedges.extend(order_interactions);
695 }
696
697 if hyperedges.is_empty() {
698 for i in 0..n_features {
700 hyperedges.push(vec![i]);
701 }
702 }
703
704 Hypergraph::from_hyperedges(n_features, &hyperedges)
705 }
706
707 fn detect_order_interactions(
709 &self,
710 data: &Array2<Float>,
711 order: usize,
712 ) -> Result<Vec<Vec<usize>>, SklearsError> {
713 let n_features = data.ncols();
714 let mut interactions = Vec::new();
715
716 let combinations = self.generate_combinations(n_features, order);
718
719 for combination in combinations {
720 if self.test_interaction_significance(data, &combination)? {
721 interactions.push(combination);
722 }
723 }
724
725 Ok(interactions)
726 }
727
728 fn generate_combinations(&self, n: usize, k: usize) -> Vec<Vec<usize>> {
730 if k == 0 {
731 return vec![vec![]];
732 }
733 if k > n {
734 return vec![];
735 }
736
737 let mut combinations = Vec::new();
738 Self::generate_combinations_recursive(n, k, 0, &mut vec![], &mut combinations);
739 combinations
740 }
741
742 fn generate_combinations_recursive(
744 n: usize,
745 k: usize,
746 start: usize,
747 current: &mut Vec<usize>,
748 result: &mut Vec<Vec<usize>>,
749 ) {
750 if current.len() == k {
751 result.push(current.clone());
752 return;
753 }
754
755 for i in start..n {
756 current.push(i);
757 Self::generate_combinations_recursive(n, k, i + 1, current, result);
758 current.pop();
759 }
760 }
761
762 fn test_interaction_significance(
764 &self,
765 data: &Array2<Float>,
766 feature_indices: &[usize],
767 ) -> Result<bool, SklearsError> {
768 if feature_indices.len() < 2 {
772 return Ok(false);
773 }
774
775 let mut correlations = Vec::new();
777 for i in 0..feature_indices.len() {
778 for j in i + 1..feature_indices.len() {
779 let col_i = data.column(feature_indices[i]);
780 let col_j = data.column(feature_indices[j]);
781 let correlation = self.compute_correlation(&col_i, &col_j);
782 correlations.push(correlation.abs());
783 }
784 }
785
786 let avg_correlation = correlations.iter().sum::<Float>() / correlations.len() as Float;
788 Ok(avg_correlation > self.significance_threshold)
789 }
790
791 fn compute_correlation(&self, x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
793 let n = x.len() as Float;
794 let mean_x = x.sum() / n;
795 let mean_y = y.sum() / n;
796
797 let mut numerator = 0.0;
798 let mut sum_sq_x = 0.0;
799 let mut sum_sq_y = 0.0;
800
801 for (&xi, &yi) in x.iter().zip(y.iter()) {
802 let dx = xi - mean_x;
803 let dy = yi - mean_y;
804 numerator += dx * dy;
805 sum_sq_x += dx * dx;
806 sum_sq_y += dy * dy;
807 }
808
809 let denominator = (sum_sq_x * sum_sq_y).sqrt();
810 if denominator > 1e-10 {
811 numerator / denominator
812 } else {
813 0.0
814 }
815 }
816}
817
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::essentials::Normal;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::random::thread_rng;

    /// Draws a `rows x cols` matrix of independent standard-normal samples.
    fn standard_normal_matrix(rows: usize, cols: usize) -> Array2<Float> {
        Array2::from_shape_fn((rows, cols), |_| {
            let mut rng = thread_rng();
            rng.sample(&Normal::new(0.0, 1.0).unwrap())
        })
    }

    /// An explicit incidence matrix yields the expected sizes.
    #[test]
    fn test_hypergraph_creation() {
        let entries = vec![1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0];
        let incidence = Array2::<Float>::from_shape_vec((4, 3), entries).unwrap();

        let built = Hypergraph::new(incidence);
        assert!(built.is_ok());

        let hg = built.unwrap();
        assert_eq!(hg.n_vertices, 4);
        assert_eq!(hg.n_hyperedges, 3);
        assert_eq!(hg.vertex_weights.len(), 4);
        assert_eq!(hg.hyperedge_sizes.len(), 3);
    }

    /// Vertex lists are accepted as an alternative constructor input.
    #[test]
    fn test_hypergraph_from_edges() {
        let edge_lists = vec![vec![0, 1, 2], vec![1, 3], vec![0, 2, 3]];

        let built = Hypergraph::from_hyperedges(4, &edge_lists);
        assert!(built.is_ok());

        let hg = built.unwrap();
        assert_eq!(hg.n_vertices, 4);
        assert_eq!(hg.n_hyperedges, 3);
    }

    /// Every Laplacian variant is square in the number of vertices.
    #[test]
    fn test_hypergraph_laplacian() {
        let edge_lists = vec![vec![0, 1], vec![1, 2], vec![0, 2]];
        let hg = Hypergraph::from_hyperedges(3, &edge_lists).unwrap();

        let unnormalized = hg.compute_laplacian(HypergraphLaplacianType::Unnormalized);
        let normalized = hg.compute_laplacian(HypergraphLaplacianType::Normalized);
        let random_walk = hg.compute_laplacian(HypergraphLaplacianType::RandomWalk);

        assert_eq!(unnormalized.dim(), (3, 3));
        assert_eq!(normalized.dim(), (3, 3));
        assert_eq!(random_walk.dim(), (3, 3));
    }

    /// Centrality vectors have one entry per vertex or hyperedge.
    #[test]
    fn test_hypergraph_centrality() {
        let edge_lists = vec![vec![0, 1, 2], vec![1, 3]];
        let hg = Hypergraph::from_hyperedges(4, &edge_lists).unwrap();

        let centrality = hg.compute_centrality();

        assert_eq!(centrality.vertex_centrality.len(), 4);
        assert_eq!(centrality.hyperedge_centrality.len(), 2);
        assert_eq!(centrality.clustering_coefficient.len(), 4);
    }

    /// A fresh estimator has no hypergraphs attached.
    #[test]
    fn test_hypergraph_cca_creation() {
        let hcca = HypergraphCCA::new(HypergraphConfig::default());

        assert!(hcca.x_hypergraph.is_none());
        assert!(hcca.y_hypergraph.is_none());
    }

    /// Fitting on random data produces weights of the expected shapes.
    #[test]
    fn test_hypergraph_cca_fit() {
        let config = HypergraphConfig {
            n_components: 2,
            max_iterations: 10,
            ..HypergraphConfig::default()
        };

        let x_hg = Hypergraph::from_hyperedges(3, &[vec![0, 1], vec![1, 2], vec![0, 2]]).unwrap();
        let y_hg = Hypergraph::from_hyperedges(3, &[vec![0, 1], vec![1, 2]]).unwrap();

        let hcca = HypergraphCCA::new(config)
            .with_x_hypergraph(x_hg)
            .with_y_hypergraph(y_hg);

        let x = standard_normal_matrix(50, 3);
        let y = standard_normal_matrix(50, 3);

        let fitted = hcca.fit(&x, &y);
        assert!(fitted.is_ok());

        let results = fitted.unwrap();
        assert_eq!(results.x_weights.dim(), (3, 2));
        assert_eq!(results.y_weights.dim(), (3, 2));
        assert_eq!(results.correlations.len(), 2);
    }

    /// A strongly correlated feature pair leads to at least one hyperedge.
    #[test]
    fn test_multi_way_interaction_analyzer() {
        let analyzer = MultiWayInteractionAnalyzer::new(3);

        let mut data = standard_normal_matrix(100, 5);
        // Make feature 1 a noisy copy of feature 0.
        let mut rng = thread_rng();
        for row in 0..data.nrows() {
            data[[row, 1]] = data[[row, 0]] + 0.1 * rng.sample(&Normal::new(0.0, 1.0).unwrap());
        }

        let detected = analyzer.detect_interactions(&data);
        assert!(detected.is_ok());

        let hypergraph = detected.unwrap();
        assert!(hypergraph.n_hyperedges > 0);
    }

    /// Choosing 2 out of 4 yields the 6 expected pairs.
    #[test]
    fn test_combination_generation() {
        let analyzer = MultiWayInteractionAnalyzer::new(3);
        let pairs = analyzer.generate_combinations(4, 2);

        assert_eq!(pairs.len(), 6);
        assert!(pairs.contains(&vec![0, 1]));
        assert!(pairs.contains(&vec![2, 3]));
    }

    /// Community labels of matching length are accepted and stored.
    #[test]
    fn test_hypergraph_with_communities() {
        let edge_lists = vec![vec![0, 1], vec![2, 3], vec![0, 2]];
        let labels = Array1::<usize>::from_vec(vec![0, 0, 1, 1]);

        let labelled = Hypergraph::from_hyperedges(4, &edge_lists)
            .unwrap()
            .with_communities(labels);

        assert!(labelled.is_ok());
        let hg = labelled.unwrap();
        assert!(hg.communities.is_some());
    }

    /// Transforming held-out data yields one column per component.
    #[test]
    fn test_hypergraph_cca_transform() {
        let config = HypergraphConfig {
            n_components: 2,
            ..HypergraphConfig::default()
        };

        let x_hg = Hypergraph::from_hyperedges(3, &[vec![0, 1], vec![1, 2]]).unwrap();
        let y_hg = Hypergraph::from_hyperedges(2, &[vec![0, 1]]).unwrap();

        let hcca = HypergraphCCA::new(config)
            .with_x_hypergraph(x_hg)
            .with_y_hypergraph(y_hg);

        let x_train = standard_normal_matrix(30, 3);
        let y_train = standard_normal_matrix(30, 2);
        let x_test = standard_normal_matrix(10, 3);
        let y_test = standard_normal_matrix(10, 2);

        let results = hcca.fit(&x_train, &y_train).unwrap();
        let (x_transformed, y_transformed) = hcca.transform(&x_test, &y_test, &results);

        assert_eq!(x_transformed.dim(), (10, 2));
        assert_eq!(y_transformed.dim(), (10, 2));
    }
}