1use scirs2_core::ndarray::{s, Array2, ArrayView2, Axis};
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 types::Float,
13};
14use std::collections::HashMap;
15
/// Configurable analyzer for copula-based dependence modeling across
/// multi-task outputs. Configure via the fluent setters on the impl.
///
/// NOTE(review): `analyze` is currently a stub that returns an empty
/// `CopulaAnalysis`; none of these settings are consumed yet.
#[derive(Debug, Clone)]
pub struct CopulaBasedModelingAnalyzer {
    // Copula families to fit (defaults to just `Gaussian`).
    copula_types: Vec<CopulaType>,
    // Whether marginal distributions should be fitted (default true).
    fit_margins: bool,
    // Whether to use the empirical copula instead of parametric fits.
    use_empirical_copula: bool,
    // Number of samples for copula computations (default 1000).
    n_samples: usize,
    // Optional RNG seed for reproducibility.
    random_state: Option<u64>,
}
53
54impl CopulaBasedModelingAnalyzer {
55 pub fn new() -> Self {
56 Self {
57 copula_types: vec![CopulaType::Gaussian],
58 fit_margins: true,
59 use_empirical_copula: false,
60 n_samples: 1000,
61 random_state: None,
62 }
63 }
64
65 pub fn copula_types(mut self, copula_types: Vec<CopulaType>) -> Self {
67 self.copula_types = copula_types;
68 self
69 }
70
71 pub fn fit_margins(mut self, fit_margins: bool) -> Self {
73 self.fit_margins = fit_margins;
74 self
75 }
76
77 pub fn use_empirical_copula(mut self, use_empirical_copula: bool) -> Self {
79 self.use_empirical_copula = use_empirical_copula;
80 self
81 }
82
83 pub fn n_samples(mut self, n_samples: usize) -> Self {
85 self.n_samples = n_samples;
86 self
87 }
88
89 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
91 self.random_state = random_state;
92 self
93 }
94
95 pub fn analyze(&self, _outputs: &HashMap<String, Array2<Float>>) -> SklResult<CopulaAnalysis> {
97 let copula_models = HashMap::new();
99 let marginal_distributions = HashMap::new();
100 let goodness_of_fit = HashMap::new();
101 let dependence_measures = HashMap::new();
102 let output_info = HashMap::new();
103
104 Ok(CopulaAnalysis {
105 copula_models,
106 marginal_distributions,
107 goodness_of_fit,
108 dependence_measures,
109 best_copula: None,
110 output_info,
111 empirical_copula: None,
112 })
113 }
114}
115
impl Default for CopulaBasedModelingAnalyzer {
    /// Equivalent to [`CopulaBasedModelingAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
121
/// Copula families supported by the copula modeling analyzer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CopulaType {
    /// Gaussian (normal) copula.
    Gaussian,
    /// Clayton copula.
    Clayton,
    /// Frank copula.
    Frank,
    /// Gumbel copula.
    Gumbel,
    /// Student-t copula.
    StudentT,
    /// Generic Archimedean copula.
    Archimedean,
    /// Non-parametric empirical copula.
    Empirical,
}
140
/// Results of a copula-based dependence analysis.
#[derive(Debug, Clone)]
pub struct CopulaAnalysis {
    /// Fitted copula model per family.
    pub copula_models: HashMap<CopulaType, CopulaModel>,
    /// Fitted marginal distribution per output name.
    pub marginal_distributions: HashMap<String, MarginalDistribution>,
    /// Goodness-of-fit statistics per copula family.
    pub goodness_of_fit: HashMap<CopulaType, GoodnessOfFit>,
    /// Dependence measures per copula family.
    pub dependence_measures: HashMap<CopulaType, DependenceMeasures>,
    /// Best-fitting copula family, if one was selected.
    pub best_copula: Option<CopulaType>,
    /// Per-task output counts (task name -> number of outputs);
    /// presumably mirrors `CorrelationAnalysis::output_info` — the stub
    /// `analyze` currently leaves it empty.
    pub output_info: HashMap<String, usize>,
    /// Empirical copula, when computed.
    pub empirical_copula: Option<EmpiricalCopula>,
}
159
/// A single fitted copula.
#[derive(Debug, Clone)]
pub struct CopulaModel {
    /// Family this model belongs to.
    pub copula_type: CopulaType,
    /// Estimated parameters (variant should match `copula_type`).
    pub parameters: CopulaParameters,
    /// Log-likelihood of the fit.
    pub log_likelihood: Float,
    /// Number of free parameters in the fit.
    pub n_parameters: usize,
    /// Data the copula was fitted on.
    pub fitted_data: Array2<Float>,
}
174
/// Parameter sets for each supported copula family.
#[derive(Debug, Clone)]
pub enum CopulaParameters {
    /// Gaussian copula: pairwise correlation matrix.
    Gaussian { correlation_matrix: Array2<Float> },
    /// Clayton copula: dependence parameter theta.
    Clayton { theta: Float },
    /// Frank copula: dependence parameter theta.
    Frank { theta: Float },
    /// Gumbel copula: dependence parameter theta.
    Gumbel { theta: Float },
    /// Student-t copula: correlation matrix plus degrees of freedom.
    StudentT {
        correlation_matrix: Array2<Float>,
        degrees_of_freedom: Float,
    },
    /// Generic Archimedean copula: generator-function parameters.
    Archimedean { generator_params: Vec<Float> },
    /// Empirical copula: no parametric form.
    Empirical,
}
196
/// Fitted univariate marginal distribution of one output.
#[derive(Debug, Clone)]
pub struct MarginalDistribution {
    /// Name of the fitted distribution family.
    pub distribution_type: String,
    /// Estimated distribution parameters.
    pub parameters: Vec<Float>,
    /// Mean of the marginal.
    pub mean: Float,
    /// Standard deviation of the marginal.
    pub std_dev: Float,
    /// Minimum observed value.
    pub min: Float,
    /// Maximum observed value.
    pub max: Float,
}
213
/// Goodness-of-fit statistics for a fitted copula.
#[derive(Debug, Clone)]
pub struct GoodnessOfFit {
    /// Akaike information criterion.
    pub aic: Float,
    /// Bayesian information criterion.
    pub bic: Float,
    /// Cramér–von Mises statistic.
    pub cramer_von_mises: Float,
    /// Kolmogorov–Smirnov statistic.
    pub kolmogorov_smirnov: Float,
    /// Anderson–Darling statistic.
    pub anderson_darling: Float,
    /// P-value of the goodness-of-fit test.
    pub p_value: Float,
}
230
/// Dependence measures derived from a fitted copula.
#[derive(Debug, Clone)]
pub struct DependenceMeasures {
    /// Kendall's tau rank correlation.
    pub kendall_tau: Float,
    /// Spearman's rho rank correlation.
    pub spearman_rho: Float,
    /// Lower/upper tail-dependence coefficients.
    pub tail_dependence: TailDependence,
    /// Conditional dependence measures for selected variable subsets.
    pub conditional_measures: Vec<ConditionalMeasure>,
}
243
/// Tail-dependence coefficients of a copula.
#[derive(Debug, Clone)]
pub struct TailDependence {
    /// Lower-tail dependence coefficient.
    pub lower_tail: Float,
    /// Upper-tail dependence coefficient.
    pub upper_tail: Float,
    /// Asymmetry between the two tails.
    pub asymmetry: Float,
}
254
/// Dependence between variables conditional on a set of others.
#[derive(Debug, Clone)]
pub struct ConditionalMeasure {
    /// Indices of the conditioning variables.
    pub condition_vars: Vec<usize>,
    /// Dependence strength given the conditioning set.
    pub conditional_dependence: Float,
    /// Correlation given the conditioning set.
    pub conditional_correlation: Float,
}
265
/// Non-parametric (empirical) copula representation.
#[derive(Debug, Clone)]
pub struct EmpiricalCopula {
    /// Empirical copula function values.
    pub copula_values: Array2<Float>,
    /// Rank-transformed input data.
    pub rank_data: Array2<Float>,
    /// Number of samples the copula was built from.
    pub sample_size: usize,
}
276
/// Analyzer that computes correlation structure across multi-task
/// outputs. Configure via the fluent setters; run with `analyze`.
#[derive(Debug, Clone)]
pub struct OutputCorrelationAnalyzer {
    // Correlation estimators to compute (default: Pearson only).
    correlation_types: Vec<CorrelationType>,
    // Compute pairwise cross-task correlation blocks (default true).
    include_cross_task: bool,
    // Compute within-task correlation matrices (default true).
    include_within_task: bool,
    // Minimum |correlation| of interest.
    // NOTE(review): currently never read by `analyze` — confirm intent.
    min_correlation_threshold: Float,
    // Also compute (approximate) partial correlations (default false).
    compute_partial_correlations: bool,
}
313
/// Correlation/dependence estimators available to the analyzer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CorrelationType {
    /// Pearson (linear) correlation.
    Pearson,
    /// Spearman rank correlation.
    Spearman,
    /// Kendall's tau (approximated from Spearman in this module).
    Kendall,
    /// Mutual information (Gaussian approximation in this module).
    MutualInformation,
    /// Distance correlation (|Pearson| placeholder in this module).
    DistanceCorrelation,
    /// Canonical correlation (delegates to Pearson in this module).
    CanonicalCorrelation,
}
330
/// Results of `OutputCorrelationAnalyzer::analyze`.
#[derive(Debug, Clone)]
pub struct CorrelationAnalysis {
    /// Full correlation matrix over all combined output columns,
    /// one per requested estimator.
    pub correlation_matrices: HashMap<CorrelationType, Array2<Float>>,
    /// Cross-correlation blocks keyed by ordered task-name pairs.
    pub cross_task_correlations: HashMap<(String, String), Array2<Float>>,
    /// Pearson correlation matrix within each multi-output task.
    pub within_task_correlations: HashMap<String, Array2<Float>>,
    /// Approximate partial correlations, when requested.
    pub partial_correlations: Option<HashMap<CorrelationType, Array2<Float>>>,
    /// Number of output columns contributed by each task.
    pub output_info: HashMap<String, usize>,
    /// All task outputs concatenated column-wise (samples x outputs).
    pub combined_outputs: Array2<Float>,
    /// Half-open column range `(start, end)` of each task inside
    /// `combined_outputs`.
    pub output_indices: HashMap<String, (usize, usize)>,
}
349
/// Builder for a [`DependencyGraph`] over individual task outputs.
#[derive(Debug, Clone)]
pub struct DependencyGraphBuilder {
    // Edge-selection criterion (default: |r| >= 0.5 threshold).
    method: DependencyMethod,
    // Whether self-loop edges are allowed (default false).
    include_self_loops: bool,
    // Whether the resulting graph is directed (default false).
    directed: bool,
    // Optional cap on the number of edges kept per node.
    max_dependencies: Option<usize>,
}
385
/// Criteria used by [`DependencyGraphBuilder`] to select edges.
#[derive(Debug, Clone, PartialEq)]
pub enum DependencyMethod {
    /// Connect nodes whose |Pearson correlation| meets the threshold.
    CorrelationThreshold(Float),
    /// Mutual-information threshold (not yet implemented in `build`).
    MutualInformationThreshold(Float),
    /// Causal-discovery edges (not yet implemented in `build`).
    CausalDiscovery,
    /// Significance-level edges (not yet implemented in `build`).
    StatisticalSignificance(Float),
    /// Keep the k strongest correlations per node.
    TopK(usize),
}
400
/// Dependency graph over individual outputs (nodes named "task_i").
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// Adjacency matrix; entry [i, j] > 0 means an edge from i to j.
    pub adjacency_matrix: Array2<Float>,
    /// Node labels, one per row/column of the matrices.
    pub node_names: Vec<String>,
    /// Edge weights (absolute correlations), aligned with the
    /// adjacency matrix.
    pub edge_weights: Array2<Float>,
    /// Whether edges are directed.
    pub directed: bool,
    /// Summary statistics of the graph.
    pub stats: GraphStatistics,
}
415
/// Summary statistics of a [`DependencyGraph`].
#[derive(Debug, Clone)]
pub struct GraphStatistics {
    /// Number of nodes.
    pub num_nodes: usize,
    /// Number of edges.
    pub num_edges: usize,
    /// Mean node degree.
    pub average_degree: Float,
    /// Fraction of possible edges that are present.
    pub density: Float,
    /// Mean local clustering coefficient (0.0 for directed graphs).
    pub clustering_coefficient: Float,
}
430
/// Configuration for conditional-independence testing.
///
/// NOTE(review): no constructor or test routine for this type is
/// visible in this part of the file.
#[derive(Debug, Clone)]
pub struct ConditionalIndependenceTester {
    // Significance level for the tests.
    alpha: Float,
    // Statistical test used to assess conditional independence.
    test_method: CITestMethod,
    // Largest conditioning-set size to consider.
    max_conditioning_set_size: usize,
}
444
/// Statistical tests for conditional independence.
#[derive(Debug, Clone, PartialEq)]
pub enum CITestMethod {
    /// Partial-correlation-based test.
    PartialCorrelation,
    /// Conditional mutual-information test.
    MutualInformation,
    /// Kernel-based test.
    KernelBased,
    /// Regression-based test.
    RegressionBased,
}
457
/// Aggregate results of conditional-independence testing.
#[derive(Debug, Clone)]
pub struct CITestResults {
    /// Result per tested triple (x, y, conditioning set).
    pub test_results: HashMap<(String, String, Vec<String>), CITestResult>,
    /// Markov blanket per variable name.
    pub markov_blankets: HashMap<String, Vec<String>>,
    /// Graph induced by the detected (in)dependence relations.
    pub ci_graph: DependencyGraph,
}
468
/// Outcome of a single conditional-independence test.
#[derive(Debug, Clone)]
pub struct CITestResult {
    /// Value of the test statistic.
    pub test_statistic: Float,
    /// P-value of the test.
    pub p_value: Float,
    /// Whether the variables were judged conditionally independent.
    pub independent: bool,
    /// Variables that were conditioned on.
    pub conditioning_set: Vec<String>,
}
481
482impl OutputCorrelationAnalyzer {
483 pub fn new() -> Self {
485 Self {
486 correlation_types: vec![CorrelationType::Pearson],
487 include_cross_task: true,
488 include_within_task: true,
489 min_correlation_threshold: 0.0,
490 compute_partial_correlations: false,
491 }
492 }
493
494 pub fn correlation_types(mut self, types: Vec<CorrelationType>) -> Self {
496 self.correlation_types = types;
497 self
498 }
499
500 pub fn include_cross_task(mut self, include: bool) -> Self {
502 self.include_cross_task = include;
503 self
504 }
505
506 pub fn include_within_task(mut self, include: bool) -> Self {
508 self.include_within_task = include;
509 self
510 }
511
512 pub fn min_correlation_threshold(mut self, threshold: Float) -> Self {
514 self.min_correlation_threshold = threshold;
515 self
516 }
517
518 pub fn compute_partial_correlations(mut self, compute: bool) -> Self {
520 self.compute_partial_correlations = compute;
521 self
522 }
523
524 pub fn analyze(
526 &self,
527 outputs: &HashMap<String, Array2<Float>>,
528 ) -> SklResult<CorrelationAnalysis> {
529 if outputs.is_empty() {
530 return Err(SklearsError::InvalidInput(
531 "No outputs provided".to_string(),
532 ));
533 }
534
535 let n_samples = outputs.values().next().unwrap().nrows();
537 for (task_name, task_outputs) in outputs {
538 if task_outputs.nrows() != n_samples {
539 return Err(SklearsError::ShapeMismatch {
540 expected: format!("{}", n_samples),
541 actual: format!("{}", task_outputs.nrows()),
542 });
543 }
544 }
545
546 let total_outputs: usize = outputs.values().map(|arr| arr.ncols()).sum();
548 let mut combined_outputs = Array2::<Float>::zeros((n_samples, total_outputs));
549 let mut output_indices = HashMap::new();
550 let mut output_info = HashMap::new();
551
552 let mut current_idx = 0;
553 for (task_name, task_outputs) in outputs {
554 let n_outputs = task_outputs.ncols();
555 let end_idx = current_idx + n_outputs;
556
557 combined_outputs
558 .slice_mut(s![.., current_idx..end_idx])
559 .assign(task_outputs);
560
561 output_indices.insert(task_name.clone(), (current_idx, end_idx));
562 output_info.insert(task_name.clone(), n_outputs);
563 current_idx = end_idx;
564 }
565
566 let mut correlation_matrices = HashMap::new();
568 for correlation_type in &self.correlation_types {
569 let corr_matrix = self.compute_correlation(&combined_outputs, correlation_type)?;
570 correlation_matrices.insert(correlation_type.clone(), corr_matrix);
571 }
572
573 let mut cross_task_correlations = HashMap::new();
575 if self.include_cross_task {
576 for (task1, &(start1, end1)) in &output_indices {
577 for (task2, &(start2, end2)) in &output_indices {
578 if task1 != task2 {
579 let task1_outputs = combined_outputs.slice(s![.., start1..end1]);
580 let task2_outputs = combined_outputs.slice(s![.., start2..end2]);
581 let cross_corr =
582 self.compute_cross_correlation(&task1_outputs, &task2_outputs)?;
583 cross_task_correlations.insert((task1.clone(), task2.clone()), cross_corr);
584 }
585 }
586 }
587 }
588
589 let mut within_task_correlations = HashMap::new();
591 if self.include_within_task {
592 for (task_name, &(start_idx, end_idx)) in &output_indices {
593 if end_idx - start_idx > 1 {
594 let task_outputs = combined_outputs.slice(s![.., start_idx..end_idx]);
596 let within_corr = self
597 .compute_correlation(&task_outputs.to_owned(), &CorrelationType::Pearson)?;
598 within_task_correlations.insert(task_name.clone(), within_corr);
599 }
600 }
601 }
602
603 let partial_correlations = if self.compute_partial_correlations {
605 let mut partial_corrs = HashMap::new();
606 for correlation_type in &self.correlation_types {
607 if let Ok(partial_corr) =
608 self.compute_partial_correlation(&combined_outputs, correlation_type)
609 {
610 partial_corrs.insert(correlation_type.clone(), partial_corr);
611 }
612 }
613 Some(partial_corrs)
614 } else {
615 None
616 };
617
618 Ok(CorrelationAnalysis {
619 correlation_matrices,
620 cross_task_correlations,
621 within_task_correlations,
622 partial_correlations,
623 output_info,
624 combined_outputs,
625 output_indices,
626 })
627 }
628
    /// Dispatch a correlation computation to the estimator selected by
    /// `correlation_type`.
    ///
    /// NOTE(review): several estimators below are approximations or
    /// placeholders built on Pearson — see the individual methods.
    fn compute_correlation(
        &self,
        data: &Array2<Float>,
        correlation_type: &CorrelationType,
    ) -> SklResult<Array2<Float>> {
        match correlation_type {
            CorrelationType::Pearson => self.compute_pearson_correlation(data),
            CorrelationType::Spearman => self.compute_spearman_correlation(data),
            CorrelationType::Kendall => self.compute_kendall_correlation(data),
            CorrelationType::MutualInformation => self.compute_mutual_information_matrix(data),
            CorrelationType::DistanceCorrelation => self.compute_distance_correlation(data),
            CorrelationType::CanonicalCorrelation => self.compute_canonical_correlation(data),
        }
    }
644
645 fn compute_pearson_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
647 let n_vars = data.ncols();
648 let n_samples = data.nrows();
649 let mut corr_matrix = Array2::eye(n_vars);
650
651 let means = data.mean_axis(Axis(0)).unwrap();
653
654 let mut centered_data = data.clone();
656 for i in 0..n_samples {
657 for j in 0..n_vars {
658 centered_data[[i, j]] -= means[j];
659 }
660 }
661
662 for i in 0..n_vars {
664 for j in (i + 1)..n_vars {
665 let col_i = centered_data.column(i);
666 let col_j = centered_data.column(j);
667
668 let covariance = col_i.dot(&col_j) / (n_samples as Float - 1.0);
669 let var_i = col_i.dot(&col_i) / (n_samples as Float - 1.0);
670 let var_j = col_j.dot(&col_j) / (n_samples as Float - 1.0);
671
672 let correlation = if var_i > 0.0 && var_j > 0.0 {
673 covariance / (var_i.sqrt() * var_j.sqrt())
674 } else {
675 0.0
676 };
677
678 corr_matrix[[i, j]] = correlation;
679 corr_matrix[[j, i]] = correlation;
680 }
681 }
682
683 Ok(corr_matrix)
684 }
685
686 fn compute_spearman_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
688 let n_vars = data.ncols();
691 let mut ranked_data = Array2::<Float>::zeros(data.dim());
692
693 for j in 0..n_vars {
695 let mut column_data: Vec<(Float, usize)> = data
696 .column(j)
697 .iter()
698 .enumerate()
699 .map(|(i, &val)| (val, i))
700 .collect();
701 column_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
702
703 for (rank, (_, original_idx)) in column_data.iter().enumerate() {
704 ranked_data[[*original_idx, j]] = rank as Float;
705 }
706 }
707
708 self.compute_pearson_correlation(&ranked_data)
710 }
711
712 fn compute_kendall_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
714 let n_vars = data.ncols();
717 let mut corr_matrix = Array2::eye(n_vars);
718
719 for i in 0..n_vars {
720 for j in (i + 1)..n_vars {
721 let spearman_corr = self.compute_spearman_correlation(data)?;
723 let kendall_approx = (2.0 / std::f64::consts::PI) * spearman_corr[[i, j]].asin();
724
725 corr_matrix[[i, j]] = kendall_approx;
726 corr_matrix[[j, i]] = kendall_approx;
727 }
728 }
729
730 Ok(corr_matrix)
731 }
732
733 fn compute_mutual_information_matrix(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
735 let n_vars = data.ncols();
736 let mut mi_matrix = Array2::<Float>::zeros((n_vars, n_vars));
737
738 for i in 0..n_vars {
741 for j in 0..n_vars {
742 if i == j {
743 mi_matrix[[i, j]] = 1.0; } else {
745 let pearson_corr = self.compute_pearson_correlation(data)?;
747 let mi_approx = -0.5 * (1.0 - pearson_corr[[i, j]].powi(2)).ln();
748 mi_matrix[[i, j]] = mi_approx.max(0.0);
749 }
750 }
751 }
752
753 Ok(mi_matrix)
754 }
755
756 fn compute_distance_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
758 let n_vars = data.ncols();
761 let mut dcorr_matrix = Array2::eye(n_vars);
762
763 for i in 0..n_vars {
764 for j in (i + 1)..n_vars {
765 let pearson_corr = self.compute_pearson_correlation(data)?;
767 let dcorr_approx = pearson_corr[[i, j]].abs();
768
769 dcorr_matrix[[i, j]] = dcorr_approx;
770 dcorr_matrix[[j, i]] = dcorr_approx;
771 }
772 }
773
774 Ok(dcorr_matrix)
775 }
776
    /// Canonical correlation matrix.
    ///
    /// NOTE(review): placeholder — simply delegates to plain Pearson
    /// correlation; a true CCA over task groupings is not implemented.
    fn compute_canonical_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
        self.compute_pearson_correlation(data)
    }
783
784 fn compute_cross_correlation(
786 &self,
787 data1: &ArrayView2<Float>,
788 data2: &ArrayView2<Float>,
789 ) -> SklResult<Array2<Float>> {
790 let n_outputs1 = data1.ncols();
791 let n_outputs2 = data2.ncols();
792 let n_samples = data1.nrows();
793
794 if data2.nrows() != n_samples {
795 return Err(SklearsError::ShapeMismatch {
796 expected: format!("{}", n_samples),
797 actual: format!("{}", data2.nrows()),
798 });
799 }
800
801 let mut cross_corr = Array2::<Float>::zeros((n_outputs1, n_outputs2));
802
803 let means1 = data1.mean_axis(Axis(0)).unwrap();
805 let means2 = data2.mean_axis(Axis(0)).unwrap();
806
807 for i in 0..n_outputs1 {
808 for j in 0..n_outputs2 {
809 let col1 = data1.column(i);
810 let col2 = data2.column(j);
811
812 let mut covariance = 0.0;
814 for k in 0..n_samples {
815 covariance += (col1[k] - means1[i]) * (col2[k] - means2[j]);
816 }
817 covariance /= n_samples as Float - 1.0;
818
819 let mut var1 = 0.0;
821 let mut var2 = 0.0;
822 for k in 0..n_samples {
823 var1 += (col1[k] - means1[i]).powi(2);
824 var2 += (col2[k] - means2[j]).powi(2);
825 }
826 var1 /= n_samples as Float - 1.0;
827 var2 /= n_samples as Float - 1.0;
828
829 let correlation = if var1 > 0.0 && var2 > 0.0 {
831 covariance / (var1.sqrt() * var2.sqrt())
832 } else {
833 0.0
834 };
835
836 cross_corr[[i, j]] = correlation;
837 }
838 }
839
840 Ok(cross_corr)
841 }
842
    /// "Partial" correlation matrix.
    ///
    /// NOTE(review): placeholder — this does NOT compute true partial
    /// correlations (which require inverting the correlation matrix);
    /// it shrinks each Pearson correlation by a fixed 0.8 factor, and
    /// `_correlation_type` is ignored. Confirm before relying on these
    /// values.
    fn compute_partial_correlation(
        &self,
        data: &Array2<Float>,
        _correlation_type: &CorrelationType,
    ) -> SklResult<Array2<Float>> {
        let corr_matrix = self.compute_pearson_correlation(data)?;
        let n_vars = corr_matrix.nrows();

        let mut partial_corr = Array2::eye(n_vars);

        for i in 0..n_vars {
            for j in (i + 1)..n_vars {
                // Fixed shrinkage stands in for conditioning on the
                // remaining variables.
                let partial = corr_matrix[[i, j]] * 0.8;
                partial_corr[[i, j]] = partial;
                partial_corr[[j, i]] = partial;
            }
        }

        Ok(partial_corr)
    }
870}
871
impl Default for OutputCorrelationAnalyzer {
    /// Equivalent to [`OutputCorrelationAnalyzer::new`].
    fn default() -> Self {
        Self::new()
    }
}
877
878impl DependencyGraphBuilder {
879 pub fn new() -> Self {
881 Self {
882 method: DependencyMethod::CorrelationThreshold(0.5),
883 include_self_loops: false,
884 directed: false,
885 max_dependencies: None,
886 }
887 }
888
889 pub fn method(mut self, method: DependencyMethod) -> Self {
891 self.method = method;
892 self
893 }
894
895 pub fn include_self_loops(mut self, include: bool) -> Self {
897 self.include_self_loops = include;
898 self
899 }
900
901 pub fn directed(mut self, directed: bool) -> Self {
903 self.directed = directed;
904 self
905 }
906
907 pub fn max_dependencies(mut self, max_deps: Option<usize>) -> Self {
909 self.max_dependencies = max_deps;
910 self
911 }
912
913 pub fn build(&self, outputs: &HashMap<String, Array2<Float>>) -> SklResult<DependencyGraph> {
915 let analyzer = OutputCorrelationAnalyzer::new()
917 .correlation_types(vec![CorrelationType::Pearson])
918 .include_cross_task(true);
919
920 let analysis = analyzer.analyze(outputs)?;
921
922 let correlation_matrix = analysis
924 .correlation_matrices
925 .get(&CorrelationType::Pearson)
926 .ok_or_else(|| {
927 SklearsError::InvalidInput("Failed to compute correlations".to_string())
928 })?;
929
930 let mut node_names = Vec::new();
932 for (task_name, &(start_idx, end_idx)) in &analysis.output_indices {
933 for i in start_idx..end_idx {
934 node_names.push(format!("{}_{}", task_name, i - start_idx));
935 }
936 }
937
938 let n_nodes = node_names.len();
939 let mut adjacency_matrix = Array2::<Float>::zeros((n_nodes, n_nodes));
940 let mut edge_weights = Array2::<Float>::zeros((n_nodes, n_nodes));
941
942 match &self.method {
944 DependencyMethod::CorrelationThreshold(threshold) => {
945 for i in 0..n_nodes {
946 for j in 0..n_nodes {
947 if i != j || self.include_self_loops {
948 let corr_strength = correlation_matrix[[i, j]].abs();
949 if corr_strength >= *threshold {
950 adjacency_matrix[[i, j]] = 1.0;
951 edge_weights[[i, j]] = corr_strength;
952
953 if !self.directed {
954 adjacency_matrix[[j, i]] = 1.0;
955 edge_weights[[j, i]] = corr_strength;
956 }
957 }
958 }
959 }
960 }
961 }
962 DependencyMethod::TopK(k) => {
963 for i in 0..n_nodes {
964 let mut correlations: Vec<(usize, Float)> = (0..n_nodes)
965 .filter(|&j| i != j || self.include_self_loops)
966 .map(|j| (j, correlation_matrix[[i, j]].abs()))
967 .collect();
968
969 correlations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
970
971 for (j, corr_strength) in correlations.iter().take(*k) {
972 adjacency_matrix[[i, *j]] = 1.0;
973 edge_weights[[i, *j]] = *corr_strength;
974 }
975 }
976 }
977 _ => {
978 return Err(SklearsError::InvalidInput(
980 "Dependency method not yet implemented".to_string(),
981 ));
982 }
983 }
984
985 if let Some(max_deps) = self.max_dependencies {
987 for i in 0..n_nodes {
988 let mut dependencies: Vec<(usize, Float)> = (0..n_nodes)
989 .filter(|&j| adjacency_matrix[[i, j]] > 0.0)
990 .map(|j| (j, edge_weights[[i, j]]))
991 .collect();
992
993 dependencies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
994
995 for (idx, (j, _)) in dependencies.iter().enumerate() {
997 if idx >= max_deps {
998 adjacency_matrix[[i, *j]] = 0.0;
999 edge_weights[[i, *j]] = 0.0;
1000 }
1001 }
1002 }
1003 }
1004
1005 let stats = self.compute_graph_statistics(&adjacency_matrix);
1007
1008 Ok(DependencyGraph {
1009 adjacency_matrix,
1010 node_names,
1011 edge_weights,
1012 directed: self.directed,
1013 stats,
1014 })
1015 }
1016
1017 fn compute_graph_statistics(&self, adjacency_matrix: &Array2<Float>) -> GraphStatistics {
1019 let n_nodes = adjacency_matrix.nrows();
1020 let num_edges = adjacency_matrix.sum() as usize;
1021
1022 let degrees: Vec<Float> = (0..n_nodes)
1024 .map(|i| adjacency_matrix.row(i).sum())
1025 .collect();
1026
1027 let average_degree = degrees.iter().sum::<Float>() / (n_nodes as Float);
1028
1029 let max_possible_edges = if self.directed {
1030 n_nodes * (n_nodes - 1)
1031 } else {
1032 n_nodes * (n_nodes - 1) / 2
1033 };
1034
1035 let density = if max_possible_edges > 0 {
1036 num_edges as Float / max_possible_edges as Float
1037 } else {
1038 0.0
1039 };
1040
1041 let clustering_coefficient = if !self.directed {
1043 self.compute_clustering_coefficient(adjacency_matrix)
1044 } else {
1045 0.0 };
1047
1048 GraphStatistics {
1049 num_nodes: n_nodes,
1050 num_edges,
1051 average_degree,
1052 density,
1053 clustering_coefficient,
1054 }
1055 }
1056
1057 fn compute_clustering_coefficient(&self, adjacency_matrix: &Array2<Float>) -> Float {
1059 let n_nodes = adjacency_matrix.nrows();
1060 let mut total_clustering = 0.0;
1061 let mut valid_nodes = 0;
1062
1063 for i in 0..n_nodes {
1064 let neighbors: Vec<usize> = (0..n_nodes)
1065 .filter(|&j| adjacency_matrix[[i, j]] > 0.0)
1066 .collect();
1067
1068 let degree = neighbors.len();
1069 if degree < 2 {
1070 continue; }
1072
1073 let mut triangles = 0;
1074 for &j in &neighbors {
1075 for &k in &neighbors {
1076 if j < k && adjacency_matrix[[j, k]] > 0.0 {
1077 triangles += 1;
1078 }
1079 }
1080 }
1081
1082 let possible_triangles = degree * (degree - 1) / 2;
1083 let clustering = if possible_triangles > 0 {
1084 triangles as Float / possible_triangles as Float
1085 } else {
1086 0.0
1087 };
1088
1089 total_clustering += clustering;
1090 valid_nodes += 1;
1091 }
1092
1093 if valid_nodes > 0 {
1094 total_clustering / valid_nodes as Float
1095 } else {
1096 0.0
1097 }
1098 }
1099}
1100
impl Default for DependencyGraphBuilder {
    /// Equivalent to [`DependencyGraphBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
1106
1107impl CorrelationAnalysis {
1108 pub fn get_correlation(
1110 &self,
1111 output1: &str,
1112 output2: &str,
1113 correlation_type: &CorrelationType,
1114 ) -> Option<Float> {
1115 let corr_matrix = self.correlation_matrices.get(correlation_type)?;
1116
1117 let mut output1_idx = None;
1119 let mut output2_idx = None;
1120 let mut current_idx = 0;
1121
1122 for (task_name, &(start_idx, end_idx)) in &self.output_indices {
1123 for i in start_idx..end_idx {
1124 let output_name = format!("{}_{}", task_name, i - start_idx);
1125 if output_name == output1 {
1126 output1_idx = Some(current_idx);
1127 }
1128 if output_name == output2 {
1129 output2_idx = Some(current_idx);
1130 }
1131 current_idx += 1;
1132 }
1133 }
1134
1135 if let (Some(idx1), Some(idx2)) = (output1_idx, output2_idx) {
1136 Some(corr_matrix[[idx1, idx2]])
1137 } else {
1138 None
1139 }
1140 }
1141
1142 pub fn get_strong_correlations(
1144 &self,
1145 correlation_type: &CorrelationType,
1146 threshold: Float,
1147 ) -> Vec<(String, String, Float)> {
1148 let mut strong_correlations = Vec::new();
1149
1150 if let Some(corr_matrix) = self.correlation_matrices.get(correlation_type) {
1151 let current_idx = 0;
1152 let mut output_names = Vec::new();
1153
1154 for (task_name, &(start_idx, end_idx)) in &self.output_indices {
1156 for i in start_idx..end_idx {
1157 output_names.push(format!("{}_{}", task_name, i - start_idx));
1158 }
1159 }
1160
1161 for i in 0..output_names.len() {
1163 for j in (i + 1)..output_names.len() {
1164 let corr_value = corr_matrix[[i, j]];
1165 if corr_value.abs() >= threshold {
1166 strong_correlations.push((
1167 output_names[i].clone(),
1168 output_names[j].clone(),
1169 corr_value,
1170 ));
1171 }
1172 }
1173 }
1174 }
1175
1176 strong_correlations.sort_by(|a, b| b.2.abs().partial_cmp(&a.2.abs()).unwrap());
1177 strong_correlations
1178 }
1179
1180 pub fn correlation_summary(
1182 &self,
1183 correlation_type: &CorrelationType,
1184 ) -> Option<(Float, Float, Float, Float)> {
1185 if let Some(corr_matrix) = self.correlation_matrices.get(correlation_type) {
1186 let n = corr_matrix.nrows();
1187 let mut values = Vec::new();
1188
1189 for i in 0..n {
1191 for j in (i + 1)..n {
1192 values.push(corr_matrix[[i, j]]);
1193 }
1194 }
1195
1196 if values.is_empty() {
1197 return Some((0.0, 0.0, 0.0, 0.0));
1198 }
1199
1200 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1201
1202 let mean = values.iter().sum::<Float>() / values.len() as Float;
1203 let median = if values.len() % 2 == 0 {
1204 (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
1205 } else {
1206 values[values.len() / 2]
1207 };
1208 let min = values[0];
1209 let max = values[values.len() - 1];
1210
1211 Some((mean, median, min, max))
1212 } else {
1213 None
1214 }
1215 }
1216}
1217
1218impl DependencyGraph {
1219 pub fn get_neighbors(&self, node_name: &str) -> Vec<String> {
1221 if let Some(node_idx) = self.node_names.iter().position(|name| name == node_name) {
1222 let mut neighbors = Vec::new();
1223 for j in 0..self.node_names.len() {
1224 if self.adjacency_matrix[[node_idx, j]] > 0.0 {
1225 neighbors.push(self.node_names[j].clone());
1226 }
1227 }
1228 neighbors
1229 } else {
1230 Vec::new()
1231 }
1232 }
1233
1234 pub fn get_edge_weight(&self, node1: &str, node2: &str) -> Option<Float> {
1236 let idx1 = self.node_names.iter().position(|name| name == node1)?;
1237 let idx2 = self.node_names.iter().position(|name| name == node2)?;
1238
1239 if self.adjacency_matrix[[idx1, idx2]] > 0.0 {
1240 Some(self.edge_weights[[idx1, idx2]])
1241 } else {
1242 None
1243 }
1244 }
1245
1246 pub fn are_connected(&self, node1: &str, node2: &str) -> bool {
1248 self.get_edge_weight(node1, node2).is_some()
1249 }
1250
1251 pub fn get_degree(&self, node_name: &str) -> usize {
1253 if let Some(node_idx) = self.node_names.iter().position(|name| name == node_name) {
1254 self.adjacency_matrix.row(node_idx).sum() as usize
1255 } else {
1256 0
1257 }
1258 }
1259}
1260
#[allow(non_snake_case)]
#[cfg(test)]
mod correlation_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_correlation_analyzer_creation() {
        let analyzer = OutputCorrelationAnalyzer::new()
            .correlation_types(vec![CorrelationType::Pearson, CorrelationType::Spearman])
            .include_cross_task(true)
            .include_within_task(true)
            .min_correlation_threshold(0.1)
            .compute_partial_correlations(true);

        assert_eq!(analyzer.correlation_types.len(), 2);
        assert!(analyzer.include_cross_task);
        assert!(analyzer.include_within_task);
        assert_abs_diff_eq!(analyzer.min_correlation_threshold, 0.1);
        assert!(analyzer.compute_partial_correlations);
    }

    #[test]
    fn test_correlation_analysis() {
        let mut outputs = HashMap::new();
        outputs.insert(
            "task1".to_string(),
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]],
        );
        outputs.insert(
            "task2".to_string(),
            array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5], [2.0, 1.0]],
        );

        let analyzer = OutputCorrelationAnalyzer::new()
            .correlation_types(vec![CorrelationType::Pearson])
            .include_cross_task(true)
            .include_within_task(true);

        let analysis = analyzer.analyze(&outputs).unwrap();

        assert!(analysis
            .correlation_matrices
            .contains_key(&CorrelationType::Pearson));

        // Two tasks with two outputs each -> 4 combined columns.
        assert_eq!(analysis.combined_outputs.shape(), &[4, 4]);
        assert!(analysis.output_indices.contains_key("task1"));
        assert!(analysis.output_indices.contains_key("task2"));

        assert!(analysis
            .cross_task_correlations
            .contains_key(&("task1".to_string(), "task2".to_string())));

        assert!(analysis.within_task_correlations.contains_key("task1"));
        assert!(analysis.within_task_correlations.contains_key("task2"));
    }

    #[test]
    fn test_dependency_graph_builder() {
        let mut outputs = HashMap::new();
        outputs.insert("task1".to_string(), array![[1.0], [2.0], [3.0], [4.0]]);
        outputs.insert("task2".to_string(), array![[0.5], [1.0], [1.5], [2.0]]);
        outputs.insert("task3".to_string(), array![[0.8], [1.2], [1.8], [2.4]]);

        let builder = DependencyGraphBuilder::new()
            .method(DependencyMethod::CorrelationThreshold(0.5))
            .include_self_loops(false)
            .directed(false);

        let graph = builder.build(&outputs).unwrap();

        // One node per (single-output) task.
        assert_eq!(graph.node_names.len(), 3);
        assert!(!graph.directed);
        assert_eq!(graph.stats.num_nodes, 3);
    }

    #[test]
    fn test_correlation_types() {
        let types = vec![
            CorrelationType::Pearson,
            CorrelationType::Spearman,
            CorrelationType::Kendall,
            CorrelationType::MutualInformation,
            CorrelationType::DistanceCorrelation,
            CorrelationType::CanonicalCorrelation,
        ];

        assert_eq!(types.len(), 6);
        assert_eq!(types[0], CorrelationType::Pearson);
    }

    #[test]
    fn test_dependency_methods() {
        let methods = vec![
            DependencyMethod::CorrelationThreshold(0.5),
            DependencyMethod::MutualInformationThreshold(0.3),
            DependencyMethod::CausalDiscovery,
            DependencyMethod::StatisticalSignificance(0.05),
            DependencyMethod::TopK(3),
        ];

        assert_eq!(methods.len(), 5);
        assert_eq!(methods[0], DependencyMethod::CorrelationThreshold(0.5));
    }

    #[test]
    fn test_correlation_analysis_accessors() {
        let mut outputs = HashMap::new();
        outputs.insert(
            "task1".to_string(),
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]],
        );
        outputs.insert(
            "task2".to_string(),
            array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5]],
        );

        let analyzer = OutputCorrelationAnalyzer::new();
        let analysis = analyzer.analyze(&outputs).unwrap();

        let corr = analysis.get_correlation("task1_0", "task1_1", &CorrelationType::Pearson);
        assert!(corr.is_some());

        let strong_corrs = analysis.get_strong_correlations(&CorrelationType::Pearson, 0.1);
        assert!(!strong_corrs.is_empty());

        let summary = analysis.correlation_summary(&CorrelationType::Pearson);
        assert!(summary.is_some());
        let (mean, median, min, max) = summary.unwrap();
        assert!(min <= median);
        assert!(median <= max);
        // The mean of the off-diagonal values must lie inside the range.
        assert!(min <= mean && mean <= max);
    }

    #[test]
    fn test_dependency_graph_accessors() {
        let mut outputs = HashMap::new();
        outputs.insert("task1".to_string(), array![[1.0], [2.0], [3.0]]);
        outputs.insert("task2".to_string(), array![[0.5], [1.0], [1.5]]);

        let builder =
            DependencyGraphBuilder::new().method(DependencyMethod::CorrelationThreshold(0.1));

        let graph = builder.build(&outputs).unwrap();

        let neighbors = graph.get_neighbors("task1_0");
        assert!(neighbors.len() <= 2);

        let degree = graph.get_degree("task1_0");
        assert!(degree <= 2);

        // Previously bound but never asserted: `are_connected` must
        // agree with `get_edge_weight`.
        let connected = graph.are_connected("task1_0", "task2_0");
        assert_eq!(
            connected,
            graph.get_edge_weight("task1_0", "task2_0").is_some()
        );
    }
}
1429}