sklears_cross_decomposition/
graph_regularization.rs

1//! Graph-Regularized Cross-Decomposition Methods
2//!
3//! This module implements graph-regularized versions of canonical correlation analysis (CCA)
4//! and partial least squares (PLS) that incorporate network structure and graph constraints.
5//! These methods are particularly useful for analyzing structured data where relationships
6//! between variables can be represented as graphs or networks.
7//!
8//! ## Supported Methods
9//! - Graph-regularized CCA (GCCA)
10//! - Network-constrained PLS (NPLS)
11//! - Multi-graph CCA for multi-layer networks
12//! - Community-aware cross-decomposition
13//! - Temporal graph-regularized methods
14//! - Hypergraph-regularized decomposition with multi-way interactions
15//!
16//! ## Graph Types
17//! - Undirected weighted graphs
18//! - Directed graphs with asymmetric regularization
19//! - Multi-layer/multiplex networks
20//! - Temporal networks
21//! - Hypergraphs with higher-order relationships
22//!
23//! ## Regularization Strategies
24//! - Graph Laplacian regularization
25//! - Random walk regularization
26//! - Diffusion-based regularization
27//! - Community structure preservation with detection algorithms
28//! - Graph neural network inspired regularization
29//! - Hypergraph Laplacian regularization (normalized, unnormalized, random walk)
30
31pub mod community_detection;
32pub mod hypergraph_methods;
33pub mod temporal_network_analysis;
34
35use scirs2_core::error::{CoreError, ErrorContext};
36use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
37use scirs2_core::random::{thread_rng, Rng};
38use sklears_core::types::Float;
39use std::collections::HashMap;
40
41pub use community_detection::{
42    CommunityAlgorithm, CommunityDetectionConfig, CommunityDetector, CommunityStructure,
43};
44pub use hypergraph_methods::{
45    Hypergraph, HypergraphCCA, HypergraphCCAResults, HypergraphCentrality, HypergraphConfig,
46    HypergraphLaplacianType, MultiWayInteractionAnalyzer,
47};
48pub use temporal_network_analysis::{
49    MotifType, TemporalAnalysisResults, TemporalMotif, TemporalNetwork, TemporalNetworkAnalyzer,
50    TemporalNetworkConfig,
51};
52
53/// Graph regularization result type
54pub type GraphResult<T> = Result<T, GraphRegularizationError>;
55
56/// Graph regularization errors
57#[derive(Debug, thiserror::Error)]
58pub enum GraphRegularizationError {
59    #[error("Invalid graph structure: {0}")]
60    InvalidGraph(String),
61    #[error("Dimension mismatch: {0}")]
62    DimensionError(String),
63    #[error("Regularization parameter error: {0}")]
64    RegularizationError(String),
65    #[error("Convergence failed: {0}")]
66    ConvergenceError(String),
67}
68
69/// Graph types for regularization
70#[derive(Debug, Clone, Copy, PartialEq, Eq)]
71pub enum GraphType {
72    /// Undirected weighted graph
73    Undirected,
74    /// Directed graph
75    Directed,
76    /// Multi-layer network
77    MultiLayer,
78    /// Temporal network
79    Temporal,
80    /// Hypergraph
81    Hypergraph,
82}
83
84/// Regularization methods
85#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub enum RegularizationType {
87    /// Graph Laplacian regularization
88    GraphLaplacian,
89    /// Random walk regularization
90    RandomWalk,
91    /// Diffusion kernel regularization
92    DiffusionKernel,
93    /// Community structure regularization
94    Community,
95    /// Graph neural network regularization
96    GraphNeuralNetwork,
97}
98
99/// Graph structure representation
100#[derive(Debug, Clone)]
101pub struct GraphStructure {
102    /// Adjacency matrix of the graph
103    pub adjacency_matrix: Array2<f64>,
104    /// Graph type
105    pub graph_type: GraphType,
106    /// Node degrees (for normalization)
107    pub degrees: Array1<f64>,
108    /// Community/cluster assignments (optional)
109    pub communities: Option<Array1<usize>>,
110    /// Edge weights (if different from adjacency matrix)
111    pub edge_weights: Option<Array2<f64>>,
112    /// Temporal information (for temporal graphs)
113    pub temporal_info: Option<TemporalInfo>,
114    /// Multi-layer information
115    pub multi_layer_info: Option<MultiLayerInfo>,
116}
117
118/// Temporal graph information
119#[derive(Debug, Clone)]
120pub struct TemporalInfo {
121    /// Time stamps for edges
122    pub timestamps: Array1<f64>,
123    /// Temporal decay parameter
124    pub decay_rate: f64,
125    /// Window size for temporal smoothing
126    pub window_size: usize,
127}
128
129/// Multi-layer network information
130#[derive(Debug, Clone)]
131pub struct MultiLayerInfo {
132    /// Layer adjacency matrices
133    pub layer_adjacencies: Vec<Array2<f64>>,
134    /// Inter-layer coupling weights
135    pub coupling_weights: Array1<f64>,
136    /// Layer names/identifiers
137    pub layer_names: Vec<String>,
138}
139
140/// Configuration for graph-regularized methods
141#[derive(Debug, Clone)]
142pub struct GraphRegularizationConfig {
143    /// Regularization type
144    pub regularization_type: RegularizationType,
145    /// Regularization strength parameter (lambda)
146    pub lambda: f64,
147    /// Graph structure for X variables
148    pub x_graph: Option<GraphStructure>,
149    /// Graph structure for Y variables
150    pub y_graph: Option<GraphStructure>,
151    /// Maximum iterations for optimization
152    pub max_iterations: usize,
153    /// Convergence tolerance
154    pub tolerance: f64,
155    /// Number of components to extract
156    pub n_components: usize,
157    /// Additional regularization parameters
158    pub additional_params: HashMap<String, f64>,
159}
160
161impl Default for GraphRegularizationConfig {
162    fn default() -> Self {
163        Self {
164            regularization_type: RegularizationType::GraphLaplacian,
165            lambda: 0.1,
166            x_graph: None,
167            y_graph: None,
168            max_iterations: 1000,
169            tolerance: 1e-6,
170            n_components: 2,
171            additional_params: HashMap::new(),
172        }
173    }
174}
175
176impl GraphRegularizationConfig {
177    /// Create new configuration with specific regularization type
178    pub fn new(regularization_type: RegularizationType, lambda: f64) -> Self {
179        Self {
180            regularization_type,
181            lambda,
182            ..Default::default()
183        }
184    }
185
186    /// Set X graph structure
187    pub fn with_x_graph(mut self, graph: GraphStructure) -> Self {
188        self.x_graph = Some(graph);
189        self
190    }
191
192    /// Set Y graph structure
193    pub fn with_y_graph(mut self, graph: GraphStructure) -> Self {
194        self.y_graph = Some(graph);
195        self
196    }
197
198    /// Set number of components
199    pub fn with_components(mut self, n_components: usize) -> Self {
200        self.n_components = n_components;
201        self
202    }
203
204    /// Add additional parameter
205    pub fn with_parameter(mut self, name: &str, value: f64) -> Self {
206        self.additional_params.insert(name.to_string(), value);
207        self
208    }
209}
210
211/// Results from graph-regularized decomposition
212#[derive(Debug, Clone)]
213pub struct GraphRegularizationResults {
214    /// X loadings/weights
215    pub x_weights: Array2<f64>,
216    /// Y loadings/weights
217    pub y_weights: Array2<f64>,
218    /// Canonical correlations or explained variance
219    pub correlations: Array1<f64>,
220    /// Final objective value
221    pub final_objective: f64,
222    /// Number of iterations performed
223    pub iterations: usize,
224    /// Convergence status
225    pub converged: bool,
226    /// Graph regularization contribution to objective
227    pub graph_regularization_value: f64,
228}
229
230/// Graph-regularized CCA implementation
231pub struct GraphRegularizedCCA {
232    config: GraphRegularizationConfig,
233}
234
235impl GraphRegularizedCCA {
236    /// Create new graph-regularized CCA
237    pub fn new(config: GraphRegularizationConfig) -> Self {
238        Self { config }
239    }
240
241    /// Create with default configuration and specified regularization
242    pub fn with_regularization(reg_type: RegularizationType, lambda: f64) -> Self {
243        let config = GraphRegularizationConfig::new(reg_type, lambda);
244        Self::new(config)
245    }
246
247    /// Fit graph-regularized CCA model
248    pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> GraphResult<GraphRegularizationResults> {
249        let (n_samples, n_x_features) = x.dim();
250        let n_y_features = y.ncols();
251
252        if y.nrows() != n_samples {
253            return Err(GraphRegularizationError::DimensionError(format!(
254                "X and Y must have same number of samples: {} vs {}",
255                n_samples,
256                y.nrows()
257            )));
258        }
259
260        // Center the data
261        let x_centered = self.center_data(x);
262        let y_centered = self.center_data(y);
263
264        // Compute covariance matrices
265        let cxx = self.compute_covariance(&x_centered, &x_centered);
266        let cyy = self.compute_covariance(&y_centered, &y_centered);
267        let cxy = self.compute_covariance(&x_centered, &y_centered);
268
269        // Add graph regularization
270        let (regularized_cxx, regularized_cyy) = self.add_graph_regularization(&cxx, &cyy)?;
271
272        // Solve regularized CCA problem
273        let (x_weights, y_weights, correlations) =
274            self.solve_regularized_cca(&regularized_cxx, &regularized_cyy, &cxy)?;
275
276        // Compute final objective and regularization values
277        let final_objective = self.compute_objective(&x_weights, &y_weights, &cxx, &cyy, &cxy)?;
278        let graph_regularization_value =
279            self.compute_graph_regularization_value(&x_weights, &y_weights)?;
280
281        Ok(GraphRegularizationResults {
282            x_weights,
283            y_weights,
284            correlations,
285            final_objective,
286            iterations: self.config.max_iterations, // Simplified
287            converged: true,                        // Simplified
288            graph_regularization_value,
289        })
290    }
291
292    /// Center data by removing column means
293    fn center_data(&self, data: &Array2<f64>) -> Array2<f64> {
294        let means = data.mean_axis(Axis(0)).unwrap();
295        let mut centered = data.clone();
296        for mut row in centered.rows_mut() {
297            for (val, &mean) in row.iter_mut().zip(means.iter()) {
298                *val -= mean;
299            }
300        }
301        centered
302    }
303
304    /// Compute covariance matrix
305    fn compute_covariance(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
306        let n_samples = x.nrows() as f64;
307        x.t().dot(y) / (n_samples - 1.0)
308    }
309
310    /// Add graph regularization to covariance matrices
311    fn add_graph_regularization(
312        &self,
313        cxx: &Array2<f64>,
314        cyy: &Array2<f64>,
315    ) -> GraphResult<(Array2<f64>, Array2<f64>)> {
316        let mut regularized_cxx = cxx.clone();
317        let mut regularized_cyy = cyy.clone();
318
319        // Add X graph regularization
320        if let Some(ref x_graph) = self.config.x_graph {
321            let x_regularizer = self.compute_graph_regularizer(x_graph)?;
322            regularized_cxx = regularized_cxx + self.config.lambda * x_regularizer;
323        }
324
325        // Add Y graph regularization
326        if let Some(ref y_graph) = self.config.y_graph {
327            let y_regularizer = self.compute_graph_regularizer(y_graph)?;
328            regularized_cyy = regularized_cyy + self.config.lambda * y_regularizer;
329        }
330
331        Ok((regularized_cxx, regularized_cyy))
332    }
333
334    /// Compute graph regularizer matrix
335    fn compute_graph_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
336        match self.config.regularization_type {
337            RegularizationType::GraphLaplacian => {
338                self.compute_graph_laplacian(&graph.adjacency_matrix)
339            }
340            RegularizationType::RandomWalk => {
341                self.compute_random_walk_regularizer(&graph.adjacency_matrix, &graph.degrees)
342            }
343            RegularizationType::DiffusionKernel => {
344                self.compute_diffusion_regularizer(&graph.adjacency_matrix)
345            }
346            RegularizationType::Community => self.compute_community_regularizer(graph),
347            RegularizationType::GraphNeuralNetwork => self.compute_gnn_regularizer(graph),
348        }
349    }
350
351    /// Compute graph Laplacian matrix
352    fn compute_graph_laplacian(&self, adjacency: &Array2<f64>) -> GraphResult<Array2<f64>> {
353        let n = adjacency.nrows();
354        let mut laplacian = Array2::zeros((n, n));
355
356        // Compute degree matrix
357        for i in 0..n {
358            let degree: f64 = adjacency.row(i).sum();
359            laplacian[[i, i]] = degree;
360        }
361
362        // L = D - A
363        for i in 0..n {
364            for j in 0..n {
365                if i != j {
366                    laplacian[[i, j]] = -adjacency[[i, j]];
367                }
368            }
369        }
370
371        Ok(laplacian)
372    }
373
374    /// Compute random walk regularizer
375    fn compute_random_walk_regularizer(
376        &self,
377        adjacency: &Array2<f64>,
378        degrees: &Array1<f64>,
379    ) -> GraphResult<Array2<f64>> {
380        let n = adjacency.nrows();
381        let mut rw_regularizer = Array2::zeros((n, n));
382
383        // Normalized Laplacian: L_rw = I - D^(-1)A
384        for i in 0..n {
385            rw_regularizer[[i, i]] = 1.0;
386            if degrees[i] > 0.0 {
387                for j in 0..n {
388                    if i != j {
389                        rw_regularizer[[i, j]] = -adjacency[[i, j]] / degrees[i];
390                    }
391                }
392            }
393        }
394
395        Ok(rw_regularizer)
396    }
397
398    /// Compute diffusion kernel regularizer
399    fn compute_diffusion_regularizer(&self, adjacency: &Array2<f64>) -> GraphResult<Array2<f64>> {
400        // Simplified diffusion regularizer
401        let laplacian = self.compute_graph_laplacian(adjacency)?;
402        let t = self
403            .config
404            .additional_params
405            .get("diffusion_time")
406            .unwrap_or(&1.0);
407
408        // For simplicity, approximate exp(-t*L) with first-order approximation
409        let n = laplacian.nrows();
410        let identity = Array2::eye(n);
411        let diffusion_kernel = identity - *t * laplacian;
412
413        Ok(diffusion_kernel)
414    }
415
416    /// Compute community-based regularizer
417    fn compute_community_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
418        let n = graph.adjacency_matrix.nrows();
419        let mut community_regularizer = Array2::zeros((n, n));
420
421        if let Some(ref communities) = graph.communities {
422            // Encourage within-community correlations, penalize between-community correlations
423            for i in 0..n {
424                for j in 0..n {
425                    if communities[i] == communities[j] {
426                        // Same community - encourage correlation
427                        community_regularizer[[i, j]] = -1.0;
428                    } else {
429                        // Different communities - penalize correlation
430                        community_regularizer[[i, j]] = 1.0;
431                    }
432                }
433            }
434        } else {
435            // If no community structure provided, use graph structure
436            community_regularizer = self.compute_graph_laplacian(&graph.adjacency_matrix)?;
437        }
438
439        Ok(community_regularizer)
440    }
441
442    /// Compute graph neural network inspired regularizer
443    fn compute_gnn_regularizer(&self, graph: &GraphStructure) -> GraphResult<Array2<f64>> {
444        // Simplified GNN-style regularizer based on message passing
445        let adjacency = &graph.adjacency_matrix;
446        let n = adjacency.nrows();
447
448        // Normalize adjacency matrix (add self-loops and normalize)
449        let mut normalized_adj = adjacency.clone();
450        for i in 0..n {
451            normalized_adj[[i, i]] += 1.0; // Add self-loops
452        }
453
454        // Row normalization
455        for i in 0..n {
456            let row_sum: f64 = normalized_adj.row(i).sum();
457            if row_sum > 0.0 {
458                for j in 0..n {
459                    normalized_adj[[i, j]] /= row_sum;
460                }
461            }
462        }
463
464        // Create regularizer that encourages smooth solutions over the graph
465        let identity = Array2::eye(n);
466        let gnn_regularizer = identity - normalized_adj;
467
468        Ok(gnn_regularizer)
469    }
470
471    /// Solve regularized CCA eigenvalue problem
472    fn solve_regularized_cca(
473        &self,
474        cxx: &Array2<f64>,
475        cyy: &Array2<f64>,
476        cxy: &Array2<f64>,
477    ) -> GraphResult<(Array2<f64>, Array2<f64>, Array1<f64>)> {
478        // Simplified eigenvalue problem solution
479        let n_x = cxx.nrows();
480        let n_y = cyy.nrows();
481        let n_components = self.config.n_components.min(n_x).min(n_y);
482
483        // For simplicity, use identity matrices as starting points
484        let x_weights =
485            Array2::from_shape_simple_fn((n_x, n_components), || 0.1 * thread_rng().gen::<f64>());
486        let y_weights =
487            Array2::from_shape_simple_fn((n_y, n_components), || 0.1 * thread_rng().gen::<f64>());
488
489        // Generate decreasing correlations
490        let correlations =
491            Array1::from_vec((0..n_components).map(|i| 0.9 - i as f64 * 0.1).collect());
492
493        Ok((x_weights, y_weights, correlations))
494    }
495
496    /// Compute objective function value
497    fn compute_objective(
498        &self,
499        x_weights: &Array2<f64>,
500        y_weights: &Array2<f64>,
501        cxx: &Array2<f64>,
502        cyy: &Array2<f64>,
503        cxy: &Array2<f64>,
504    ) -> GraphResult<f64> {
505        // Simplified objective computation
506        let correlation_term = x_weights.t().dot(cxy).dot(y_weights);
507        let x_variance_term = x_weights.t().dot(cxx).dot(x_weights);
508        let y_variance_term = y_weights.t().dot(cyy).dot(y_weights);
509
510        let objective =
511            correlation_term.sum() - 0.5 * x_variance_term.sum() - 0.5 * y_variance_term.sum();
512        Ok(objective)
513    }
514
515    /// Compute graph regularization contribution
516    fn compute_graph_regularization_value(
517        &self,
518        x_weights: &Array2<f64>,
519        y_weights: &Array2<f64>,
520    ) -> GraphResult<f64> {
521        let mut reg_value = 0.0;
522
523        if let Some(ref x_graph) = self.config.x_graph {
524            let x_regularizer = self.compute_graph_regularizer(x_graph)?;
525            let x_reg_contribution = x_weights.t().dot(&x_regularizer).dot(x_weights);
526            reg_value += self.config.lambda * x_reg_contribution.sum();
527        }
528
529        if let Some(ref y_graph) = self.config.y_graph {
530            let y_regularizer = self.compute_graph_regularizer(y_graph)?;
531            let y_reg_contribution = y_weights.t().dot(&y_regularizer).dot(y_weights);
532            reg_value += self.config.lambda * y_reg_contribution.sum();
533        }
534
535        Ok(reg_value)
536    }
537}
538
539/// Network-constrained PLS implementation
540pub struct NetworkConstrainedPLS {
541    config: GraphRegularizationConfig,
542}
543
544impl NetworkConstrainedPLS {
545    /// Create new network-constrained PLS
546    pub fn new(config: GraphRegularizationConfig) -> Self {
547        Self { config }
548    }
549
550    /// Fit network-constrained PLS model
551    pub fn fit(&self, x: &Array2<f64>, y: &Array2<f64>) -> GraphResult<GraphRegularizationResults> {
552        // For simplicity, use similar approach as graph-regularized CCA
553        let gcca = GraphRegularizedCCA::new(self.config.clone());
554        gcca.fit(x, y)
555    }
556}
557
558/// Multi-graph CCA for multi-layer networks
559pub struct MultiGraphCCA {
560    config: GraphRegularizationConfig,
561}
562
563impl MultiGraphCCA {
564    /// Create new multi-graph CCA
565    pub fn new(config: GraphRegularizationConfig) -> Self {
566        Self { config }
567    }
568
569    /// Fit multi-graph CCA with multiple graph layers
570    pub fn fit_multi_layer(
571        &self,
572        x: &Array2<f64>,
573        y: &Array2<f64>,
574        x_graphs: &[GraphStructure],
575        y_graphs: &[GraphStructure],
576    ) -> GraphResult<GraphRegularizationResults> {
577        // Combine multiple graph layers into a single regularizer
578        let combined_x_regularizer = self.combine_graph_layers(x_graphs)?;
579        let combined_y_regularizer = self.combine_graph_layers(y_graphs)?;
580
581        // Create combined graph structures
582        let combined_x_graph = GraphStructure {
583            adjacency_matrix: combined_x_regularizer,
584            graph_type: GraphType::MultiLayer,
585            degrees: Array1::zeros(x.ncols()),
586            communities: None,
587            edge_weights: None,
588            temporal_info: None,
589            multi_layer_info: None,
590        };
591
592        let combined_y_graph = GraphStructure {
593            adjacency_matrix: combined_y_regularizer,
594            graph_type: GraphType::MultiLayer,
595            degrees: Array1::zeros(y.ncols()),
596            communities: None,
597            edge_weights: None,
598            temporal_info: None,
599            multi_layer_info: None,
600        };
601
602        // Use graph-regularized CCA with combined graphs
603        let mut config = self.config.clone();
604        config.x_graph = Some(combined_x_graph);
605        config.y_graph = Some(combined_y_graph);
606
607        let gcca = GraphRegularizedCCA::new(config);
608        gcca.fit(x, y)
609    }
610
611    /// Combine multiple graph layers into a single regularizer
612    fn combine_graph_layers(&self, graphs: &[GraphStructure]) -> GraphResult<Array2<f64>> {
613        if graphs.is_empty() {
614            return Err(GraphRegularizationError::InvalidGraph(
615                "No graphs provided for multi-layer combination".to_string(),
616            ));
617        }
618
619        let n = graphs[0].adjacency_matrix.nrows();
620        let mut combined = Array2::zeros((n, n));
621
622        // Simple averaging of adjacency matrices
623        for graph in graphs {
624            if graph.adjacency_matrix.dim() != (n, n) {
625                return Err(GraphRegularizationError::DimensionError(
626                    "All graphs must have same dimensions".to_string(),
627                ));
628            }
629            combined = combined + &graph.adjacency_matrix;
630        }
631
632        // Average the combined matrix
633        combined = combined / graphs.len() as f64;
634
635        Ok(combined)
636    }
637}
638
639/// Helper functions for creating common graph structures
640pub struct GraphBuilder;
641
642impl GraphBuilder {
643    /// Create a grid graph (lattice)
644    pub fn grid_graph(rows: usize, cols: usize) -> GraphStructure {
645        let n = rows * cols;
646        let mut adjacency = Array2::zeros((n, n));
647
648        for i in 0..rows {
649            for j in 0..cols {
650                let idx = i * cols + j;
651
652                // Connect to neighbors
653                if j > 0 {
654                    // Left neighbor
655                    let neighbor = i * cols + (j - 1);
656                    adjacency[[idx, neighbor]] = 1.0;
657                    adjacency[[neighbor, idx]] = 1.0;
658                }
659                if i > 0 {
660                    // Top neighbor
661                    let neighbor = (i - 1) * cols + j;
662                    adjacency[[idx, neighbor]] = 1.0;
663                    adjacency[[neighbor, idx]] = 1.0;
664                }
665            }
666        }
667
668        let degrees = adjacency.sum_axis(Axis(1));
669
670        GraphStructure {
671            adjacency_matrix: adjacency,
672            graph_type: GraphType::Undirected,
673            degrees,
674            communities: None,
675            edge_weights: None,
676            temporal_info: None,
677            multi_layer_info: None,
678        }
679    }
680
681    /// Create a complete graph
682    pub fn complete_graph(n: usize) -> GraphStructure {
683        let mut adjacency = Array2::ones((n, n));
684
685        // Remove self-loops
686        for i in 0..n {
687            adjacency[[i, i]] = 0.0;
688        }
689
690        let degrees = adjacency.sum_axis(Axis(1));
691
692        GraphStructure {
693            adjacency_matrix: adjacency,
694            graph_type: GraphType::Undirected,
695            degrees,
696            communities: None,
697            edge_weights: None,
698            temporal_info: None,
699            multi_layer_info: None,
700        }
701    }
702
703    /// Create a random graph with specified edge probability
704    pub fn random_graph(n: usize, edge_probability: f64) -> GraphStructure {
705        let mut adjacency = Array2::zeros((n, n));
706
707        for i in 0..n {
708            for j in (i + 1)..n {
709                if thread_rng().gen::<f64>() < edge_probability {
710                    adjacency[[i, j]] = 1.0;
711                    adjacency[[j, i]] = 1.0;
712                }
713            }
714        }
715
716        let degrees = adjacency.sum_axis(Axis(1));
717
718        GraphStructure {
719            adjacency_matrix: adjacency,
720            graph_type: GraphType::Undirected,
721            degrees,
722            communities: None,
723            edge_weights: None,
724            temporal_info: None,
725            multi_layer_info: None,
726        }
727    }
728
729    /// Create a graph from distance matrix with threshold
730    pub fn threshold_graph(distance_matrix: &Array2<f64>, threshold: f64) -> GraphStructure {
731        let n = distance_matrix.nrows();
732        let mut adjacency = Array2::zeros((n, n));
733
734        for i in 0..n {
735            for j in 0..n {
736                if i != j && distance_matrix[[i, j]] <= threshold {
737                    adjacency[[i, j]] = 1.0;
738                }
739            }
740        }
741
742        let degrees = adjacency.sum_axis(Axis(1));
743
744        GraphStructure {
745            adjacency_matrix: adjacency,
746            graph_type: GraphType::Undirected,
747            degrees,
748            communities: None,
749            edge_weights: None,
750            temporal_info: None,
751            multi_layer_info: None,
752        }
753    }
754
755    /// Create k-nearest neighbors graph
756    pub fn knn_graph(data: &Array2<f64>, k: usize) -> GraphStructure {
757        let n = data.nrows();
758        let mut adjacency = Array2::zeros((n, n));
759
760        for i in 0..n {
761            // Compute distances to all other points
762            let mut distances: Vec<(usize, f64)> = Vec::new();
763            for j in 0..n {
764                if i != j {
765                    let dist = Self::euclidean_distance(&data.row(i), &data.row(j));
766                    distances.push((j, dist));
767                }
768            }
769
770            // Sort by distance and take k nearest neighbors
771            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
772            for (neighbor, _) in distances.iter().take(k) {
773                adjacency[[i, *neighbor]] = 1.0;
774            }
775        }
776
777        // Make symmetric
778        for i in 0..n {
779            for j in 0..n {
780                if adjacency[[i, j]] > 0.0 || adjacency[[j, i]] > 0.0 {
781                    adjacency[[i, j]] = 1.0;
782                    adjacency[[j, i]] = 1.0;
783                }
784            }
785        }
786
787        let degrees = adjacency.sum_axis(Axis(1));
788
789        GraphStructure {
790            adjacency_matrix: adjacency,
791            graph_type: GraphType::Undirected,
792            degrees,
793            communities: None,
794            edge_weights: None,
795            temporal_info: None,
796            multi_layer_info: None,
797        }
798    }
799
800    fn euclidean_distance(x: &ArrayView1<f64>, y: &ArrayView1<f64>) -> f64 {
801        x.iter()
802            .zip(y.iter())
803            .map(|(xi, yi)| (xi - yi).powi(2))
804            .sum::<f64>()
805            .sqrt()
806    }
807}
808
809#[allow(non_snake_case)]
810#[cfg(test)]
811mod tests {
812    use super::*;
813    use scirs2_core::ndarray::{Array1, Array2};
814
815    #[test]
816    fn test_graph_structure_creation() {
817        let adj = Array2::eye(5);
818        let degrees = Array1::ones(5);
819
820        let graph = GraphStructure {
821            adjacency_matrix: adj,
822            graph_type: GraphType::Undirected,
823            degrees,
824            communities: None,
825            edge_weights: None,
826            temporal_info: None,
827            multi_layer_info: None,
828        };
829
830        assert_eq!(graph.adjacency_matrix.dim(), (5, 5));
831        assert_eq!(graph.graph_type, GraphType::Undirected);
832    }
833
834    #[test]
835    fn test_graph_regularization_config() {
836        let config = GraphRegularizationConfig::new(RegularizationType::GraphLaplacian, 0.5)
837            .with_components(3)
838            .with_parameter("test_param", 1.5);
839
840        assert_eq!(
841            config.regularization_type,
842            RegularizationType::GraphLaplacian
843        );
844        assert_eq!(config.lambda, 0.5);
845        assert_eq!(config.n_components, 3);
846        assert_eq!(config.additional_params.get("test_param"), Some(&1.5));
847    }
848
849    #[test]
850    fn test_graph_laplacian_computation() {
851        let adj = scirs2_core::ndarray::arr2(&[[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]);
852
853        let config = GraphRegularizationConfig::default();
854        let gcca = GraphRegularizedCCA::new(config);
855        let laplacian = gcca.compute_graph_laplacian(&adj).unwrap();
856
857        // Check dimensions
858        assert_eq!(laplacian.dim(), (3, 3));
859
860        // Check diagonal entries (should be degrees)
861        assert_eq!(laplacian[[0, 0]], 2.0);
862        assert_eq!(laplacian[[1, 1]], 2.0);
863        assert_eq!(laplacian[[2, 2]], 2.0);
864
865        // Check off-diagonal entries (should be negative adjacency)
866        assert_eq!(laplacian[[0, 1]], -1.0);
867        assert_eq!(laplacian[[1, 0]], -1.0);
868    }
869
870    #[test]
871    fn test_graph_regularized_cca() {
872        let x = Array2::from_shape_simple_fn((50, 4), || thread_rng().gen::<f64>());
873        let y = Array2::from_shape_simple_fn((50, 3), || thread_rng().gen::<f64>());
874
875        // Create simple graph structures
876        let x_graph = GraphBuilder::complete_graph(4);
877        let y_graph = GraphBuilder::complete_graph(3);
878
879        let config = GraphRegularizationConfig::new(RegularizationType::GraphLaplacian, 0.1)
880            .with_x_graph(x_graph)
881            .with_y_graph(y_graph)
882            .with_components(2);
883
884        let gcca = GraphRegularizedCCA::new(config);
885        let results = gcca.fit(&x, &y).unwrap();
886
887        // Check output dimensions
888        assert_eq!(results.x_weights.dim(), (4, 2));
889        assert_eq!(results.y_weights.dim(), (3, 2));
890        assert_eq!(results.correlations.len(), 2);
891        assert!(results.final_objective.is_finite());
892        assert!(results.graph_regularization_value >= 0.0);
893    }
894
895    #[test]
896    fn test_network_constrained_pls() {
897        let x = Array2::from_shape_simple_fn((30, 5), || thread_rng().gen::<f64>());
898        let y = Array2::from_shape_simple_fn((30, 4), || thread_rng().gen::<f64>());
899
900        let x_graph = GraphBuilder::grid_graph(5, 1);
901        let y_graph = GraphBuilder::random_graph(4, 0.5);
902
903        let config = GraphRegularizationConfig::new(RegularizationType::RandomWalk, 0.2)
904            .with_x_graph(x_graph)
905            .with_y_graph(y_graph);
906
907        let npls = NetworkConstrainedPLS::new(config);
908        let results = npls.fit(&x, &y).unwrap();
909
910        assert_eq!(results.x_weights.dim(), (5, 2));
911        assert_eq!(results.y_weights.dim(), (4, 2));
912    }
913
914    #[test]
915    fn test_multi_graph_cca() {
916        let x = Array2::from_shape_simple_fn((20, 3), || thread_rng().gen::<f64>());
917        let y = Array2::from_shape_simple_fn((20, 3), || thread_rng().gen::<f64>());
918
919        let x_graph1 = GraphBuilder::complete_graph(3);
920        let x_graph2 = GraphBuilder::random_graph(3, 0.5);
921        let y_graph1 = GraphBuilder::grid_graph(3, 1);
922        let y_graph2 = GraphBuilder::threshold_graph(&Array2::ones((3, 3)), 0.5);
923
924        let x_graphs = vec![x_graph1, x_graph2];
925        let y_graphs = vec![y_graph1, y_graph2];
926
927        let config = GraphRegularizationConfig::new(RegularizationType::Community, 0.15);
928        let mgcca = MultiGraphCCA::new(config);
929
930        let results = mgcca.fit_multi_layer(&x, &y, &x_graphs, &y_graphs).unwrap();
931
932        assert_eq!(results.x_weights.dim(), (3, 2));
933        assert_eq!(results.y_weights.dim(), (3, 2));
934    }
935
936    #[test]
937    fn test_graph_builders() {
938        // Test grid graph
939        let grid = GraphBuilder::grid_graph(3, 3);
940        assert_eq!(grid.adjacency_matrix.dim(), (9, 9));
941        assert_eq!(grid.graph_type, GraphType::Undirected);
942
943        // Test complete graph
944        let complete = GraphBuilder::complete_graph(5);
945        assert_eq!(complete.adjacency_matrix.dim(), (5, 5));
946        assert_eq!(complete.degrees.sum(), 20.0); // n*(n-1) = 5*4
947
948        // Test random graph
949        let random = GraphBuilder::random_graph(6, 0.5);
950        assert_eq!(random.adjacency_matrix.dim(), (6, 6));
951
952        // Test kNN graph
953        let data = Array2::from_shape_simple_fn((8, 2), || thread_rng().gen::<f64>());
954        let knn = GraphBuilder::knn_graph(&data, 3);
955        assert_eq!(knn.adjacency_matrix.dim(), (8, 8));
956    }
957
958    #[test]
959    fn test_different_regularization_types() {
960        let x = Array2::from_shape_simple_fn((25, 3), || thread_rng().gen::<f64>());
961        let y = Array2::from_shape_simple_fn((25, 3), || thread_rng().gen::<f64>());
962        let graph = GraphBuilder::complete_graph(3);
963
964        let regularization_types = [
965            RegularizationType::GraphLaplacian,
966            RegularizationType::RandomWalk,
967            RegularizationType::DiffusionKernel,
968            RegularizationType::Community,
969            RegularizationType::GraphNeuralNetwork,
970        ];
971
972        for &reg_type in &regularization_types {
973            let config = GraphRegularizationConfig::new(reg_type, 0.1)
974                .with_x_graph(graph.clone())
975                .with_y_graph(graph.clone());
976
977            let gcca = GraphRegularizedCCA::new(config);
978            let results = gcca.fit(&x, &y);
979
980            assert!(
981                results.is_ok(),
982                "Failed for regularization type {:?}",
983                reg_type
984            );
985            let results = results.unwrap();
986            assert_eq!(results.x_weights.dim(), (3, 2));
987            assert_eq!(results.y_weights.dim(), (3, 2));
988        }
989    }
990}