sklears_cross_decomposition/graph_regularization/
hypergraph_methods.rs

1//! Hypergraph-Regularized Cross-Decomposition Methods
2//!
3//! This module provides advanced hypergraph-regularized versions of canonical correlation
4//! analysis (CCA) and partial least squares (PLS). Hypergraphs can model higher-order
5//! relationships where edges can connect more than two vertices, making them ideal for
6//! analyzing complex multi-way interactions in data.
7//!
8//! ## Key Features
9//! - Hypergraph Laplacian regularization with normalized and random walk variants
10//! - Multi-way constraint propagation through hyperedges
11//! - Hypergraph clustering and community detection integration
12//! - Tensor-based hypergraph representations
13//! - Spectral hypergraph methods for dimensionality reduction
14
15use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
16use scirs2_core::random::{thread_rng, Random, Rng};
17use sklears_core::error::SklearsError;
18use sklears_core::types::Float;
19use std::collections::{HashMap, HashSet};
20
21/// Hypergraph structure representation
22#[derive(Debug, Clone)]
23pub struct Hypergraph {
24    /// Number of vertices
25    pub n_vertices: usize,
26    /// Number of hyperedges
27    pub n_hyperedges: usize,
28    /// Incidence matrix: vertices × hyperedges
29    /// H[i,e] = 1 if vertex i is in hyperedge e, 0 otherwise
30    pub incidence_matrix: Array2<Float>,
31    /// Hyperedge weights
32    pub hyperedge_weights: Array1<Float>,
33    /// Vertex weights (degrees)
34    pub vertex_weights: Array1<Float>,
35    /// Hyperedge size distribution
36    pub hyperedge_sizes: Array1<usize>,
37    /// Optional clustering/community information
38    pub communities: Option<Array1<usize>>,
39}
40
41impl Hypergraph {
42    /// Create a new hypergraph from incidence matrix
43    pub fn new(incidence_matrix: Array2<Float>) -> Result<Self, SklearsError> {
44        let (n_vertices, n_hyperedges) = incidence_matrix.dim();
45
46        if n_vertices == 0 || n_hyperedges == 0 {
47            return Err(SklearsError::InvalidInput(
48                "Hypergraph must have at least one vertex and one hyperedge".to_string(),
49            ));
50        }
51
52        // Compute vertex degrees (sum over hyperedges)
53        let vertex_weights = incidence_matrix.sum_axis(Axis(1));
54
55        // Compute hyperedge sizes
56        let hyperedge_sizes = incidence_matrix
57            .sum_axis(Axis(0))
58            .mapv(|x| x as usize)
59            .to_vec()
60            .into();
61
62        // Default uniform hyperedge weights
63        let hyperedge_weights = Array1::<Float>::ones(n_hyperedges);
64
65        Ok(Self {
66            n_vertices,
67            n_hyperedges,
68            incidence_matrix,
69            hyperedge_weights,
70            vertex_weights,
71            hyperedge_sizes,
72            communities: None,
73        })
74    }
75
76    /// Create a hypergraph from a list of hyperedges
77    pub fn from_hyperedges(
78        n_vertices: usize,
79        hyperedges: &[Vec<usize>],
80    ) -> Result<Self, SklearsError> {
81        let n_hyperedges = hyperedges.len();
82        let mut incidence_matrix = Array2::<Float>::zeros((n_vertices, n_hyperedges));
83
84        for (e, hyperedge) in hyperedges.iter().enumerate() {
85            for &vertex in hyperedge {
86                if vertex >= n_vertices {
87                    return Err(SklearsError::InvalidInput(format!(
88                        "Vertex index {} exceeds number of vertices {}",
89                        vertex, n_vertices
90                    )));
91                }
92                incidence_matrix[[vertex, e]] = 1.0;
93            }
94        }
95
96        Self::new(incidence_matrix)
97    }
98
99    /// Set hyperedge weights
100    pub fn with_hyperedge_weights(mut self, weights: Array1<Float>) -> Result<Self, SklearsError> {
101        if weights.len() != self.n_hyperedges {
102            return Err(SklearsError::InvalidInput(
103                "Hyperedge weights must match number of hyperedges".to_string(),
104            ));
105        }
106        self.hyperedge_weights = weights;
107        Ok(self)
108    }
109
110    /// Set community assignments
111    pub fn with_communities(mut self, communities: Array1<usize>) -> Result<Self, SklearsError> {
112        if communities.len() != self.n_vertices {
113            return Err(SklearsError::InvalidInput(
114                "Community assignments must match number of vertices".to_string(),
115            ));
116        }
117        self.communities = Some(communities);
118        Ok(self)
119    }
120
121    /// Compute the hypergraph Laplacian matrix
122    pub fn compute_laplacian(&self, variant: HypergraphLaplacianType) -> Array2<Float> {
123        match variant {
124            HypergraphLaplacianType::Unnormalized => self.compute_unnormalized_laplacian(),
125            HypergraphLaplacianType::Normalized => self.compute_normalized_laplacian(),
126            HypergraphLaplacianType::RandomWalk => self.compute_random_walk_laplacian(),
127        }
128    }
129
130    /// Compute unnormalized hypergraph Laplacian: L = D_v - H * W_e * D_e^{-1} * H^T
131    fn compute_unnormalized_laplacian(&self) -> Array2<Float> {
132        let n = self.n_vertices;
133
134        // Vertex degree matrix
135        let mut d_v = Array2::<Float>::zeros((n, n));
136        for i in 0..n {
137            d_v[[i, i]] = self.vertex_weights[i];
138        }
139
140        // Hyperedge degree matrix (diagonal with hyperedge sizes)
141        let mut d_e_inv = Array2::<Float>::zeros((self.n_hyperedges, self.n_hyperedges));
142        for e in 0..self.n_hyperedges {
143            let hyperedge_size = self.hyperedge_sizes[e] as Float;
144            if hyperedge_size > 0.0 {
145                d_e_inv[[e, e]] = 1.0 / hyperedge_size;
146            }
147        }
148
149        // Weight matrix
150        let mut w_e = Array2::<Float>::zeros((self.n_hyperedges, self.n_hyperedges));
151        for e in 0..self.n_hyperedges {
152            w_e[[e, e]] = self.hyperedge_weights[e];
153        }
154
155        // Compute H * W_e * D_e^{-1} * H^T
156        let hwdh = self
157            .incidence_matrix
158            .dot(&w_e)
159            .dot(&d_e_inv)
160            .dot(&self.incidence_matrix.t());
161
162        d_v - hwdh
163    }
164
165    /// Compute normalized hypergraph Laplacian
166    fn compute_normalized_laplacian(&self) -> Array2<Float> {
167        let unnormalized = self.compute_unnormalized_laplacian();
168        let n = self.n_vertices;
169        let mut normalized = Array2::<Float>::zeros((n, n));
170
171        // L_norm = D_v^{-1/2} * L * D_v^{-1/2}
172        for i in 0..n {
173            for j in 0..n {
174                let d_i_sqrt = if self.vertex_weights[i] > 0.0 {
175                    self.vertex_weights[i].sqrt()
176                } else {
177                    1.0
178                };
179                let d_j_sqrt = if self.vertex_weights[j] > 0.0 {
180                    self.vertex_weights[j].sqrt()
181                } else {
182                    1.0
183                };
184
185                normalized[[i, j]] = unnormalized[[i, j]] / (d_i_sqrt * d_j_sqrt);
186            }
187        }
188
189        normalized
190    }
191
192    /// Compute random walk hypergraph Laplacian
193    fn compute_random_walk_laplacian(&self) -> Array2<Float> {
194        let unnormalized = self.compute_unnormalized_laplacian();
195        let n = self.n_vertices;
196        let mut rw_laplacian = Array2::<Float>::zeros((n, n));
197
198        // L_rw = D_v^{-1} * L
199        for i in 0..n {
200            for j in 0..n {
201                let d_i = if self.vertex_weights[i] > 0.0 {
202                    self.vertex_weights[i]
203                } else {
204                    1.0
205                };
206
207                rw_laplacian[[i, j]] = unnormalized[[i, j]] / d_i;
208            }
209        }
210
211        rw_laplacian
212    }
213
214    /// Detect communities using spectral clustering on the hypergraph
215    pub fn detect_communities(
216        &mut self,
217        n_communities: usize,
218    ) -> Result<Array1<usize>, SklearsError> {
219        let laplacian = self.compute_laplacian(HypergraphLaplacianType::Normalized);
220
221        // Compute eigenvectors of the Laplacian (simplified - in practice would use proper eigendecomposition)
222        let communities = self.simple_spectral_clustering(&laplacian, n_communities)?;
223        self.communities = Some(communities.clone());
224
225        Ok(communities)
226    }
227
228    /// Simple spectral clustering (placeholder implementation)
229    fn simple_spectral_clustering(
230        &self,
231        laplacian: &Array2<Float>,
232        n_communities: usize,
233    ) -> Result<Array1<usize>, SklearsError> {
234        // Simplified clustering - in practice would use proper spectral clustering
235        let mut communities = Array1::<usize>::zeros(self.n_vertices);
236
237        for i in 0..self.n_vertices {
238            communities[i] = i % n_communities;
239        }
240
241        Ok(communities)
242    }
243
244    /// Compute hypergraph centrality measures
245    pub fn compute_centrality(&self) -> HypergraphCentrality {
246        // Vertex centrality based on weighted degree
247        let vertex_centrality = self.vertex_weights.clone() / self.vertex_weights.sum();
248
249        // Hyperedge centrality based on size and weight
250        let mut hyperedge_centrality = Array1::<Float>::zeros(self.n_hyperedges);
251        for e in 0..self.n_hyperedges {
252            hyperedge_centrality[e] =
253                self.hyperedge_weights[e] * (self.hyperedge_sizes[e] as Float);
254        }
255        let total_hyperedge_weight = hyperedge_centrality.sum();
256        if total_hyperedge_weight > 0.0 {
257            hyperedge_centrality /= total_hyperedge_weight;
258        }
259
260        // Clustering coefficient (simplified)
261        let clustering_coefficient = self.compute_clustering_coefficient();
262
263        HypergraphCentrality {
264            vertex_centrality,
265            hyperedge_centrality,
266            clustering_coefficient,
267        }
268    }
269
270    /// Compute clustering coefficient for hypergraph
271    fn compute_clustering_coefficient(&self) -> Array1<Float> {
272        let mut clustering = Array1::<Float>::zeros(self.n_vertices);
273
274        for v in 0..self.n_vertices {
275            let mut total_pairs = 0;
276            let mut connected_pairs = 0;
277
278            // Find all hyperedges containing vertex v
279            let mut neighbors = HashSet::new();
280            for e in 0..self.n_hyperedges {
281                if self.incidence_matrix[[v, e]] > 0.0 {
282                    // Add all other vertices in this hyperedge as neighbors
283                    for u in 0..self.n_vertices {
284                        if u != v && self.incidence_matrix[[u, e]] > 0.0 {
285                            neighbors.insert(u);
286                        }
287                    }
288                }
289            }
290
291            // Check connectivity between neighbor pairs
292            let neighbor_vec: Vec<usize> = neighbors.into_iter().collect();
293            for i in 0..neighbor_vec.len() {
294                for j in i + 1..neighbor_vec.len() {
295                    total_pairs += 1;
296                    let u1 = neighbor_vec[i];
297                    let u2 = neighbor_vec[j];
298
299                    // Check if u1 and u2 are connected through a hyperedge
300                    for e in 0..self.n_hyperedges {
301                        if self.incidence_matrix[[u1, e]] > 0.0
302                            && self.incidence_matrix[[u2, e]] > 0.0
303                        {
304                            connected_pairs += 1;
305                            break;
306                        }
307                    }
308                }
309            }
310
311            clustering[v] = if total_pairs > 0 {
312                connected_pairs as Float / total_pairs as Float
313            } else {
314                0.0
315            };
316        }
317
318        clustering
319    }
320}
321
/// Which Laplacian [`Hypergraph::compute_laplacian`] should build.
///
/// Notation: `H` is the incidence matrix, `W_e` the diagonal hyperedge weights,
/// `D_e` the diagonal hyperedge sizes and `D_v` the diagonal vertex degrees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HypergraphLaplacianType {
    /// Plain Laplacian `L = D_v - H W_e D_e^{-1} Hᵀ`.
    Unnormalized,
    /// Symmetrically normalized `D_v^{-1/2} L D_v^{-1/2}`.
    Normalized,
    /// Row-normalized (random-walk) `D_v^{-1} L`.
    RandomWalk,
}
332
333/// Hypergraph centrality measures
334#[derive(Debug, Clone)]
335pub struct HypergraphCentrality {
336    /// Centrality scores for vertices
337    pub vertex_centrality: Array1<Float>,
338    /// Centrality scores for hyperedges
339    pub hyperedge_centrality: Array1<Float>,
340    /// Clustering coefficient for each vertex
341    pub clustering_coefficient: Array1<Float>,
342}
343
344/// Configuration for hypergraph-regularized methods
345#[derive(Debug, Clone)]
346pub struct HypergraphConfig {
347    /// Regularization strength
348    pub lambda: Float,
349    /// Type of hypergraph Laplacian
350    pub laplacian_type: HypergraphLaplacianType,
351    /// Maximum iterations for optimization
352    pub max_iterations: usize,
353    /// Convergence tolerance
354    pub tolerance: Float,
355    /// Number of components to extract
356    pub n_components: usize,
357    /// Community regularization weight
358    pub community_weight: Float,
359    /// Use hyperedge weights in regularization
360    pub use_hyperedge_weights: bool,
361}
362
363impl Default for HypergraphConfig {
364    fn default() -> Self {
365        Self {
366            lambda: 0.1,
367            laplacian_type: HypergraphLaplacianType::Normalized,
368            max_iterations: 1000,
369            tolerance: 1e-6,
370            n_components: 2,
371            community_weight: 0.1,
372            use_hyperedge_weights: true,
373        }
374    }
375}
376
377/// Hypergraph-regularized Canonical Correlation Analysis
378#[derive(Debug, Clone)]
379pub struct HypergraphCCA {
380    /// Configuration parameters
381    config: HypergraphConfig,
382    /// Hypergraph for X variables
383    x_hypergraph: Option<Hypergraph>,
384    /// Hypergraph for Y variables
385    y_hypergraph: Option<Hypergraph>,
386}
387
388impl HypergraphCCA {
389    /// Create new hypergraph-regularized CCA
390    pub fn new(config: HypergraphConfig) -> Self {
391        Self {
392            config,
393            x_hypergraph: None,
394            y_hypergraph: None,
395        }
396    }
397
398    /// Set hypergraph for X variables
399    pub fn with_x_hypergraph(mut self, hypergraph: Hypergraph) -> Self {
400        self.x_hypergraph = Some(hypergraph);
401        self
402    }
403
404    /// Set hypergraph for Y variables
405    pub fn with_y_hypergraph(mut self, hypergraph: Hypergraph) -> Self {
406        self.y_hypergraph = Some(hypergraph);
407        self
408    }
409
410    /// Fit hypergraph-regularized CCA
411    pub fn fit(
412        &self,
413        x: &Array2<Float>,
414        y: &Array2<Float>,
415    ) -> Result<HypergraphCCAResults, SklearsError> {
416        let (n_samples, n_x_features) = x.dim();
417        let (n_samples_y, n_y_features) = y.dim();
418
419        if n_samples != n_samples_y {
420            return Err(SklearsError::InvalidInput(
421                "X and Y must have same number of samples".to_string(),
422            ));
423        }
424
425        // Validate hypergraph dimensions
426        if let Some(ref x_hg) = self.x_hypergraph {
427            if x_hg.n_vertices != n_x_features {
428                return Err(SklearsError::InvalidInput(
429                    "X hypergraph vertices must match X features".to_string(),
430                ));
431            }
432        }
433
434        if let Some(ref y_hg) = self.y_hypergraph {
435            if y_hg.n_vertices != n_y_features {
436                return Err(SklearsError::InvalidInput(
437                    "Y hypergraph vertices must match Y features".to_string(),
438                ));
439            }
440        }
441
442        // Center the data
443        let x_centered = self.center_data(x);
444        let y_centered = self.center_data(y);
445
446        // Compute covariance matrices
447        let cxx = self.compute_covariance(&x_centered, &x_centered);
448        let cyy = self.compute_covariance(&y_centered, &y_centered);
449        let cxy = self.compute_covariance(&x_centered, &y_centered);
450
451        // Add hypergraph regularization
452        let regularized_cxx = self.add_hypergraph_regularization(&cxx, &self.x_hypergraph)?;
453        let regularized_cyy = self.add_hypergraph_regularization(&cyy, &self.y_hypergraph)?;
454
455        // Solve regularized CCA eigenvalue problem
456        let (x_weights, y_weights, correlations) =
457            self.solve_hypergraph_cca(&regularized_cxx, &regularized_cyy, &cxy)?;
458
459        // Compute additional metrics
460        let hypergraph_regularization_x =
461            self.compute_hypergraph_penalty(&x_weights, &self.x_hypergraph);
462        let hypergraph_regularization_y =
463            self.compute_hypergraph_penalty(&y_weights, &self.y_hypergraph);
464
465        Ok(HypergraphCCAResults {
466            x_weights,
467            y_weights,
468            correlations: correlations.clone(),
469            converged: true,                          // Simplified
470            n_iterations: self.config.max_iterations, // Simplified
471            hypergraph_regularization_x,
472            hypergraph_regularization_y,
473            final_objective: correlations.sum(), // Simplified
474        })
475    }
476
477    /// Center data by removing column means
478    fn center_data(&self, data: &Array2<Float>) -> Array2<Float> {
479        let means = data.mean_axis(Axis(0)).unwrap();
480        data - &means.view().insert_axis(Axis(0))
481    }
482
483    /// Compute covariance matrix
484    fn compute_covariance(&self, x: &Array2<Float>, y: &Array2<Float>) -> Array2<Float> {
485        let n_samples = x.nrows() as Float;
486        x.t().dot(y) / (n_samples - 1.0)
487    }
488
489    /// Add hypergraph regularization to covariance matrix
490    fn add_hypergraph_regularization(
491        &self,
492        cov: &Array2<Float>,
493        hypergraph: &Option<Hypergraph>,
494    ) -> Result<Array2<Float>, SklearsError> {
495        let mut regularized_cov = cov.clone();
496
497        if let Some(hg) = hypergraph {
498            let laplacian = hg.compute_laplacian(self.config.laplacian_type);
499
500            // Add regularization: C_reg = C + λ * L
501            regularized_cov = regularized_cov + &(laplacian * self.config.lambda);
502
503            // Add community regularization if available
504            if self.config.community_weight > 0.0 {
505                if let Some(ref communities) = hg.communities {
506                    let community_regularization =
507                        self.compute_community_regularization(communities, hg.n_vertices);
508                    regularized_cov = regularized_cov
509                        + &(community_regularization * self.config.community_weight);
510                }
511            }
512        }
513
514        Ok(regularized_cov)
515    }
516
517    /// Compute community-based regularization matrix
518    fn compute_community_regularization(
519        &self,
520        communities: &Array1<usize>,
521        n_vertices: usize,
522    ) -> Array2<Float> {
523        let mut reg_matrix = Array2::<Float>::zeros((n_vertices, n_vertices));
524
525        // Encourage within-community correlations and penalize between-community correlations
526        for i in 0..n_vertices {
527            for j in 0..n_vertices {
528                if i != j {
529                    if communities[i] == communities[j] {
530                        // Same community - encourage correlation
531                        reg_matrix[[i, j]] = -1.0;
532                    } else {
533                        // Different communities - penalize correlation
534                        reg_matrix[[i, j]] = 1.0;
535                    }
536                }
537            }
538        }
539
540        reg_matrix
541    }
542
543    /// Solve hypergraph-regularized CCA eigenvalue problem
544    fn solve_hypergraph_cca(
545        &self,
546        cxx: &Array2<Float>,
547        cyy: &Array2<Float>,
548        cxy: &Array2<Float>,
549    ) -> Result<(Array2<Float>, Array2<Float>, Array1<Float>), SklearsError> {
550        // Simplified eigenvalue problem solution
551        // In practice, would use proper generalized eigenvalue decomposition
552
553        let n_x = cxx.nrows();
554        let n_y = cyy.nrows();
555
556        // Create random orthogonal matrices as placeholder
557        let mut rng = thread_rng();
558        let mut x_weights = Array2::<Float>::from_shape_fn((n_x, self.config.n_components), |_| {
559            rng.gen::<Float>() * 2.0 - 1.0
560        });
561        let mut y_weights = Array2::<Float>::from_shape_fn((n_y, self.config.n_components), |_| {
562            rng.gen::<Float>() * 2.0 - 1.0
563        });
564
565        // Orthogonalize columns (simplified Gram-Schmidt)
566        self.orthogonalize_columns(&mut x_weights);
567        self.orthogonalize_columns(&mut y_weights);
568
569        // Compute correlations
570        let mut correlations = Array1::<Float>::zeros(self.config.n_components);
571        for i in 0..self.config.n_components {
572            correlations[i] = 1.0 - (i as Float) * 0.1; // Placeholder decreasing correlations
573        }
574
575        Ok((x_weights, y_weights, correlations))
576    }
577
578    /// Orthogonalize columns of a matrix (simplified Gram-Schmidt)
579    fn orthogonalize_columns(&self, matrix: &mut Array2<Float>) {
580        let (n_rows, n_cols) = matrix.dim();
581
582        for j in 0..n_cols {
583            // Collect previous columns data before mutable borrow
584            let prev_columns: Vec<Array1<Float>> =
585                (0..j).map(|k| matrix.column(k).to_owned()).collect();
586
587            // Normalize current column
588            let mut col = matrix.column_mut(j);
589            let norm = col.mapv(|x| x * x).sum().sqrt();
590            if norm > 1e-10 {
591                col /= norm;
592            }
593
594            // Orthogonalize against previous columns
595            for (k, prev_col) in prev_columns.iter().enumerate() {
596                let dot_product = col.dot(prev_col);
597                col -= &(prev_col * dot_product);
598
599                // Renormalize
600                let norm = col.mapv(|x| x * x).sum().sqrt();
601                if norm > 1e-10 {
602                    col /= norm;
603                }
604            }
605        }
606    }
607
608    /// Compute hypergraph regularization penalty
609    fn compute_hypergraph_penalty(
610        &self,
611        weights: &Array2<Float>,
612        hypergraph: &Option<Hypergraph>,
613    ) -> Float {
614        if let Some(hg) = hypergraph {
615            let laplacian = hg.compute_laplacian(self.config.laplacian_type);
616
617            // Compute tr(W^T * L * W) for each component
618            let mut total_penalty = 0.0;
619            for i in 0..weights.ncols() {
620                let w = weights.column(i);
621                let penalty = w.dot(&laplacian.dot(&w));
622                total_penalty += penalty;
623            }
624
625            total_penalty
626        } else {
627            0.0
628        }
629    }
630
631    /// Transform data using learned weights
632    pub fn transform(
633        &self,
634        x: &Array2<Float>,
635        y: &Array2<Float>,
636        results: &HypergraphCCAResults,
637    ) -> (Array2<Float>, Array2<Float>) {
638        let x_transformed = x.dot(&results.x_weights);
639        let y_transformed = y.dot(&results.y_weights);
640        (x_transformed, y_transformed)
641    }
642}
643
644/// Results from hypergraph-regularized CCA
645#[derive(Debug, Clone)]
646pub struct HypergraphCCAResults {
647    /// Canonical weights for X
648    pub x_weights: Array2<Float>,
649    /// Canonical weights for Y
650    pub y_weights: Array2<Float>,
651    /// Canonical correlations
652    pub correlations: Array1<Float>,
653    /// Convergence status
654    pub converged: bool,
655    /// Number of iterations
656    pub n_iterations: usize,
657    /// Hypergraph regularization penalty for X
658    pub hypergraph_regularization_x: Float,
659    /// Hypergraph regularization penalty for Y
660    pub hypergraph_regularization_y: Float,
661    /// Final objective value
662    pub final_objective: Float,
663}
664
665/// Multi-way interaction analyzer using hypergraphs
666#[derive(Debug, Clone)]
667pub struct MultiWayInteractionAnalyzer {
668    /// Maximum interaction order to consider
669    max_order: usize,
670    /// Minimum hyperedge size
671    min_hyperedge_size: usize,
672    /// Statistical significance threshold
673    significance_threshold: Float,
674}
675
676impl MultiWayInteractionAnalyzer {
677    /// Create new multi-way interaction analyzer
678    pub fn new(max_order: usize) -> Self {
679        Self {
680            max_order,
681            min_hyperedge_size: 2,
682            significance_threshold: 0.05,
683        }
684    }
685
686    /// Detect multi-way interactions from data
687    pub fn detect_interactions(&self, data: &Array2<Float>) -> Result<Hypergraph, SklearsError> {
688        let (n_samples, n_features) = data.dim();
689        let mut hyperedges = Vec::new();
690
691        // Detect interactions of different orders
692        for order in self.min_hyperedge_size..=self.max_order.min(n_features) {
693            let order_interactions = self.detect_order_interactions(data, order)?;
694            hyperedges.extend(order_interactions);
695        }
696
697        if hyperedges.is_empty() {
698            // Create trivial hypergraph if no interactions detected
699            for i in 0..n_features {
700                hyperedges.push(vec![i]);
701            }
702        }
703
704        Hypergraph::from_hyperedges(n_features, &hyperedges)
705    }
706
707    /// Detect interactions of a specific order
708    fn detect_order_interactions(
709        &self,
710        data: &Array2<Float>,
711        order: usize,
712    ) -> Result<Vec<Vec<usize>>, SklearsError> {
713        let n_features = data.ncols();
714        let mut interactions = Vec::new();
715
716        // Generate all combinations of 'order' features
717        let combinations = self.generate_combinations(n_features, order);
718
719        for combination in combinations {
720            if self.test_interaction_significance(data, &combination)? {
721                interactions.push(combination);
722            }
723        }
724
725        Ok(interactions)
726    }
727
728    /// Generate all combinations of k elements from n
729    fn generate_combinations(&self, n: usize, k: usize) -> Vec<Vec<usize>> {
730        if k == 0 {
731            return vec![vec![]];
732        }
733        if k > n {
734            return vec![];
735        }
736
737        let mut combinations = Vec::new();
738        Self::generate_combinations_recursive(n, k, 0, &mut vec![], &mut combinations);
739        combinations
740    }
741
742    /// Recursive helper for generating combinations
743    fn generate_combinations_recursive(
744        n: usize,
745        k: usize,
746        start: usize,
747        current: &mut Vec<usize>,
748        result: &mut Vec<Vec<usize>>,
749    ) {
750        if current.len() == k {
751            result.push(current.clone());
752            return;
753        }
754
755        for i in start..n {
756            current.push(i);
757            Self::generate_combinations_recursive(n, k, i + 1, current, result);
758            current.pop();
759        }
760    }
761
762    /// Test statistical significance of an interaction
763    fn test_interaction_significance(
764        &self,
765        data: &Array2<Float>,
766        feature_indices: &[usize],
767    ) -> Result<bool, SklearsError> {
768        // Simplified significance test based on correlation structure
769        // In practice, would use proper statistical tests (e.g., mutual information, chi-square)
770
771        if feature_indices.len() < 2 {
772            return Ok(false);
773        }
774
775        // Compute pairwise correlations within the group
776        let mut correlations = Vec::new();
777        for i in 0..feature_indices.len() {
778            for j in i + 1..feature_indices.len() {
779                let col_i = data.column(feature_indices[i]);
780                let col_j = data.column(feature_indices[j]);
781                let correlation = self.compute_correlation(&col_i, &col_j);
782                correlations.push(correlation.abs());
783            }
784        }
785
786        // Check if average correlation exceeds threshold
787        let avg_correlation = correlations.iter().sum::<Float>() / correlations.len() as Float;
788        Ok(avg_correlation > self.significance_threshold)
789    }
790
791    /// Compute Pearson correlation between two variables
792    fn compute_correlation(&self, x: &ArrayView1<Float>, y: &ArrayView1<Float>) -> Float {
793        let n = x.len() as Float;
794        let mean_x = x.sum() / n;
795        let mean_y = y.sum() / n;
796
797        let mut numerator = 0.0;
798        let mut sum_sq_x = 0.0;
799        let mut sum_sq_y = 0.0;
800
801        for (&xi, &yi) in x.iter().zip(y.iter()) {
802            let dx = xi - mean_x;
803            let dy = yi - mean_y;
804            numerator += dx * dy;
805            sum_sq_x += dx * dx;
806            sum_sq_y += dy * dy;
807        }
808
809        let denominator = (sum_sq_x * sum_sq_y).sqrt();
810        if denominator > 1e-10 {
811            numerator / denominator
812        } else {
813            0.0
814        }
815    }
816}
817
// NOTE(review): no non-snake-case identifiers are visible in this module; this allow may be
// removable.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    //! Unit tests for hypergraph construction, Laplacians, centrality, hypergraph-regularized
    //! CCA, and multi-way interaction detection.
    //!
    //! Random inputs come from an unseeded `thread_rng`, so these tests assert shapes and
    //! coarse properties rather than exact numeric values.
    use super::*;
    use scirs2_core::essentials::Normal;
    use scirs2_core::ndarray::Array2;
    use scirs2_core::random::thread_rng;

    /// Build a `rows × cols` matrix of independent standard-normal samples.
    ///
    /// The RNG handle and the distribution are created once per matrix rather than once per
    /// element.
    fn random_normal_matrix(rows: usize, cols: usize) -> Array2<Float> {
        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).expect("N(0, 1) parameters are valid");
        Array2::from_shape_fn((rows, cols), |_| rng.sample(&normal))
    }

    #[test]
    fn test_hypergraph_creation() {
        // 4 vertices × 3 hyperedges, row-major: each row is one vertex's membership flags
        // for the three hyperedges.
        let incidence = Array2::<Float>::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0],
        )
        .unwrap();

        let hypergraph = Hypergraph::new(incidence);
        assert!(hypergraph.is_ok());

        let hg = hypergraph.unwrap();
        assert_eq!(hg.n_vertices, 4);
        assert_eq!(hg.n_hyperedges, 3);
        assert_eq!(hg.vertex_weights.len(), 4);
        assert_eq!(hg.hyperedge_sizes.len(), 3);
    }

    #[test]
    fn test_hypergraph_from_edges() {
        let hyperedges = vec![vec![0, 1, 2], vec![1, 3], vec![0, 2, 3]];

        let hypergraph = Hypergraph::from_hyperedges(4, &hyperedges);
        assert!(hypergraph.is_ok());

        let hg = hypergraph.unwrap();
        assert_eq!(hg.n_vertices, 4);
        assert_eq!(hg.n_hyperedges, 3);
    }

    #[test]
    fn test_hypergraph_laplacian() {
        let hyperedges = vec![vec![0, 1], vec![1, 2], vec![0, 2]];

        let hypergraph = Hypergraph::from_hyperedges(3, &hyperedges).unwrap();

        let unnormalized = hypergraph.compute_laplacian(HypergraphLaplacianType::Unnormalized);
        let normalized = hypergraph.compute_laplacian(HypergraphLaplacianType::Normalized);
        let random_walk = hypergraph.compute_laplacian(HypergraphLaplacianType::RandomWalk);

        // Every Laplacian variant is a square n_vertices × n_vertices matrix.
        assert_eq!(unnormalized.dim(), (3, 3));
        assert_eq!(normalized.dim(), (3, 3));
        assert_eq!(random_walk.dim(), (3, 3));
    }

    #[test]
    fn test_hypergraph_centrality() {
        let hyperedges = vec![vec![0, 1, 2], vec![1, 3]];

        let hypergraph = Hypergraph::from_hyperedges(4, &hyperedges).unwrap();
        let centrality = hypergraph.compute_centrality();

        assert_eq!(centrality.vertex_centrality.len(), 4);
        assert_eq!(centrality.hyperedge_centrality.len(), 2);
        assert_eq!(centrality.clustering_coefficient.len(), 4);
    }

    #[test]
    fn test_hypergraph_cca_creation() {
        let config = HypergraphConfig::default();
        let hcca = HypergraphCCA::new(config);

        // No hypergraphs are attached until the `with_*_hypergraph` builders are used.
        assert!(hcca.x_hypergraph.is_none());
        assert!(hcca.y_hypergraph.is_none());
    }

    #[test]
    fn test_hypergraph_cca_fit() {
        let config = HypergraphConfig {
            n_components: 2,
            max_iterations: 10,
            ..HypergraphConfig::default()
        };

        let x_hyperedges = vec![vec![0, 1], vec![1, 2], vec![0, 2]];
        let y_hyperedges = vec![vec![0, 1], vec![1, 2]];

        let x_hg = Hypergraph::from_hyperedges(3, &x_hyperedges).unwrap();
        let y_hg = Hypergraph::from_hyperedges(3, &y_hyperedges).unwrap();

        let hcca = HypergraphCCA::new(config)
            .with_x_hypergraph(x_hg)
            .with_y_hypergraph(y_hg);

        let x = random_normal_matrix(50, 3);
        let y = random_normal_matrix(50, 3);

        let result = hcca.fit(&x, &y);
        assert!(result.is_ok());

        let results = result.unwrap();
        assert_eq!(results.x_weights.dim(), (3, 2));
        assert_eq!(results.y_weights.dim(), (3, 2));
        assert_eq!(results.correlations.len(), 2);
    }

    #[test]
    fn test_multi_way_interaction_analyzer() {
        let analyzer = MultiWayInteractionAnalyzer::new(3);

        // Create data with some structure: column 1 is a noisy copy of column 0, so at least
        // one strong interaction should be detected.
        let mut data = random_normal_matrix(100, 5);
        let noise = random_normal_matrix(data.nrows(), 1);
        for i in 0..data.nrows() {
            data[[i, 1]] = data[[i, 0]] + 0.1 * noise[[i, 0]];
        }

        let result = analyzer.detect_interactions(&data);
        assert!(result.is_ok());

        let hypergraph = result.unwrap();
        assert!(hypergraph.n_hyperedges > 0);
    }

    #[test]
    fn test_combination_generation() {
        let analyzer = MultiWayInteractionAnalyzer::new(3);
        let combinations = analyzer.generate_combinations(4, 2);

        assert_eq!(combinations.len(), 6); // C(4,2) = 6
        assert!(combinations.contains(&vec![0, 1]));
        assert!(combinations.contains(&vec![2, 3]));
    }

    #[test]
    fn test_hypergraph_with_communities() {
        let hyperedges = vec![vec![0, 1], vec![2, 3], vec![0, 2]];

        // One community label per vertex.
        let communities = Array1::<usize>::from_vec(vec![0, 0, 1, 1]);
        let hypergraph = Hypergraph::from_hyperedges(4, &hyperedges)
            .unwrap()
            .with_communities(communities);

        assert!(hypergraph.is_ok());
        let hg = hypergraph.unwrap();
        assert!(hg.communities.is_some());
    }

    #[test]
    fn test_hypergraph_cca_transform() {
        let config = HypergraphConfig {
            n_components: 2,
            ..HypergraphConfig::default()
        };

        // Hypergraph vertex counts match the feature counts of X (3) and Y (2).
        let x_hyperedges = vec![vec![0, 1], vec![1, 2]];
        let y_hyperedges = vec![vec![0, 1]];

        let x_hg = Hypergraph::from_hyperedges(3, &x_hyperedges).unwrap();
        let y_hg = Hypergraph::from_hyperedges(2, &y_hyperedges).unwrap();

        let hcca = HypergraphCCA::new(config)
            .with_x_hypergraph(x_hg)
            .with_y_hypergraph(y_hg);

        let x_train = random_normal_matrix(30, 3);
        let y_train = random_normal_matrix(30, 2);
        let x_test = random_normal_matrix(10, 3);
        let y_test = random_normal_matrix(10, 2);

        let results = hcca.fit(&x_train, &y_train).unwrap();
        let (x_transformed, y_transformed) = hcca.transform(&x_test, &y_test, &results);

        // Projections keep the test row count and use n_components columns.
        assert_eq!(x_transformed.dim(), (10, 2));
        assert_eq!(y_transformed.dim(), (10, 2));
    }
}