sklears_multioutput/
correlation.rs

1//! Output Correlation Analysis and Dependency Modeling
2//!
3//! This module provides tools for analyzing and modeling correlations and dependencies
4//! between different outputs in multi-output learning scenarios. Understanding these
5//! relationships can help improve model performance and provide insights into the
6//! underlying data structure.
7
8// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
9use scirs2_core::ndarray::{s, Array2, ArrayView2, Axis};
10use sklears_core::{
11    error::{Result as SklResult, SklearsError},
12    types::Float,
13};
14use std::collections::HashMap;
15
16/// Copula-Based Modeling Analyzer
17///
18/// Analyzes and models complex dependencies between outputs using copulas.
19/// Copulas separate the marginal distributions from the dependence structure,
20/// allowing for more flexible modeling of non-linear and non-monotonic relationships.
21///
22/// # Examples
23///
24/// ```
25/// use sklears_multioutput::correlation::{CopulaBasedModelingAnalyzer, CopulaType};
26/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
27/// use scirs2_core::ndarray::array;
28/// use std::collections::HashMap;
29///
30/// let mut outputs = HashMap::new();
31/// outputs.insert("task1".to_string(), array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]]);
32/// outputs.insert("task2".to_string(), array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5]]);
33///
34/// let analyzer = CopulaBasedModelingAnalyzer::new()
35///     .copula_types(vec![CopulaType::Gaussian, CopulaType::Clayton])
36///     .fit_margins(true);
37///
38/// let analysis = analyzer.analyze(&outputs).unwrap();
39/// ```
#[derive(Debug, Clone)]
pub struct CopulaBasedModelingAnalyzer {
    /// Types of copulas to fit (one model candidate per type)
    copula_types: Vec<CopulaType>,
    /// Whether to fit marginal distributions before fitting the copula
    fit_margins: bool,
    /// Whether to also build an empirical copula for comparison
    use_empirical_copula: bool,
    /// Number of samples drawn by Monte Carlo methods
    n_samples: usize,
    /// Random seed for reproducibility (`None` = nondeterministic)
    random_state: Option<u64>,
}
53
54impl CopulaBasedModelingAnalyzer {
55    pub fn new() -> Self {
56        Self {
57            copula_types: vec![CopulaType::Gaussian],
58            fit_margins: true,
59            use_empirical_copula: false,
60            n_samples: 1000,
61            random_state: None,
62        }
63    }
64
65    /// Set the copula types to fit
66    pub fn copula_types(mut self, copula_types: Vec<CopulaType>) -> Self {
67        self.copula_types = copula_types;
68        self
69    }
70
71    /// Set whether to fit marginal distributions
72    pub fn fit_margins(mut self, fit_margins: bool) -> Self {
73        self.fit_margins = fit_margins;
74        self
75    }
76
77    /// Set whether to use empirical copula for comparison
78    pub fn use_empirical_copula(mut self, use_empirical_copula: bool) -> Self {
79        self.use_empirical_copula = use_empirical_copula;
80        self
81    }
82
83    /// Set the number of samples for Monte Carlo methods
84    pub fn n_samples(mut self, n_samples: usize) -> Self {
85        self.n_samples = n_samples;
86        self
87    }
88
89    /// Set the random state for reproducibility
90    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
91        self.random_state = random_state;
92        self
93    }
94
95    /// Analyze copula-based dependencies in the given outputs
96    pub fn analyze(&self, _outputs: &HashMap<String, Array2<Float>>) -> SklResult<CopulaAnalysis> {
97        // Placeholder implementation - would need full copula fitting algorithms
98        let copula_models = HashMap::new();
99        let marginal_distributions = HashMap::new();
100        let goodness_of_fit = HashMap::new();
101        let dependence_measures = HashMap::new();
102        let output_info = HashMap::new();
103
104        Ok(CopulaAnalysis {
105            copula_models,
106            marginal_distributions,
107            goodness_of_fit,
108            dependence_measures,
109            best_copula: None,
110            output_info,
111            empirical_copula: None,
112        })
113    }
114}
115
116impl Default for CopulaBasedModelingAnalyzer {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
/// Types of copulas for dependency modeling.
///
/// Each variant selects a family used to model the dependence structure
/// between outputs, separately from their marginal distributions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CopulaType {
    /// Gaussian copula (multivariate normal dependence)
    Gaussian,
    /// Clayton copula (lower tail dependence)
    Clayton,
    /// Frank copula (symmetric dependence)
    Frank,
    /// Gumbel copula (upper tail dependence)
    Gumbel,
    /// Student's t-copula (symmetric tail dependence)
    StudentT,
    /// Archimedean copula family (generator-function based)
    Archimedean,
    /// Empirical copula (non-parametric, rank-based)
    Empirical,
}
140
/// Copula modeling results produced by [`CopulaBasedModelingAnalyzer::analyze`].
#[derive(Debug, Clone)]
pub struct CopulaAnalysis {
    /// Fitted copula models, keyed by copula type
    pub copula_models: HashMap<CopulaType, CopulaModel>,
    /// Marginal distribution parameters, keyed by output name
    pub marginal_distributions: HashMap<String, MarginalDistribution>,
    /// Goodness-of-fit statistics, keyed by copula type
    pub goodness_of_fit: HashMap<CopulaType, GoodnessOfFit>,
    /// Dependence measures derived from each fitted copula
    pub dependence_measures: HashMap<CopulaType, DependenceMeasures>,
    /// Best fitting copula type (`None` if no model was selected)
    pub best_copula: Option<CopulaType>,
    /// Output names mapped to their dimensionality (number of columns)
    pub output_info: HashMap<String, usize>,
    /// Empirical copula for comparison, if one was computed
    pub empirical_copula: Option<EmpiricalCopula>,
}
159
/// A single fitted copula model.
#[derive(Debug, Clone)]
pub struct CopulaModel {
    /// Copula family this model belongs to
    pub copula_type: CopulaType,
    /// Estimated copula parameters (variant matches `copula_type`)
    pub parameters: CopulaParameters,
    /// Log-likelihood of the fit (higher is better)
    pub log_likelihood: Float,
    /// Number of free parameters (used by AIC/BIC)
    pub n_parameters: usize,
    /// The data the model was fitted on
    pub fitted_data: Array2<Float>,
}
174
/// Parameter sets for the supported copula families.
///
/// The variant is expected to match the owning [`CopulaModel::copula_type`].
#[derive(Debug, Clone)]
pub enum CopulaParameters {
    /// Gaussian copula: correlation matrix
    Gaussian { correlation_matrix: Array2<Float> },
    /// Clayton copula: theta parameter
    Clayton { theta: Float },
    /// Frank copula: theta parameter
    Frank { theta: Float },
    /// Gumbel copula: theta parameter
    Gumbel { theta: Float },
    /// Student's t-copula: correlation matrix and degrees of freedom
    StudentT {
        correlation_matrix: Array2<Float>,
        degrees_of_freedom: Float,
    },
    /// Archimedean copula: generator function parameters
    Archimedean { generator_params: Vec<Float> },
    /// Empirical copula: non-parametric, so no parameters
    Empirical,
}
196
/// Fitted marginal distribution for a single output, plus summary statistics
/// of the data it was fitted on.
#[derive(Debug, Clone)]
pub struct MarginalDistribution {
    /// Distribution type (e.g., "normal", "uniform", "empirical")
    pub distribution_type: String,
    /// Distribution parameters (interpretation depends on `distribution_type`)
    pub parameters: Vec<Float>,
    /// Sample mean of the fitted data
    pub mean: Float,
    /// Sample standard deviation of the fitted data
    pub std_dev: Float,
    /// Minimum observed value
    pub min: Float,
    /// Maximum observed value
    pub max: Float,
}
213
/// Goodness-of-fit statistics for a fitted copula.
///
/// Lower AIC/BIC indicates a better fit after penalizing model complexity;
/// the remaining fields are classical test statistics with their p-value.
#[derive(Debug, Clone)]
pub struct GoodnessOfFit {
    /// Akaike Information Criterion
    pub aic: Float,
    /// Bayesian Information Criterion
    pub bic: Float,
    /// Cramér-von Mises test statistic
    pub cramer_von_mises: Float,
    /// Kolmogorov-Smirnov test statistic
    pub kolmogorov_smirnov: Float,
    /// Anderson-Darling test statistic
    pub anderson_darling: Float,
    /// P-value for the goodness-of-fit test
    pub p_value: Float,
}
230
/// Dependence measures derived from a fitted copula.
#[derive(Debug, Clone)]
pub struct DependenceMeasures {
    /// Kendall's tau rank correlation
    pub kendall_tau: Float,
    /// Spearman's rho rank correlation
    pub spearman_rho: Float,
    /// Lower/upper tail dependence coefficients
    pub tail_dependence: TailDependence,
    /// Conditional copula measures, one per conditioning configuration
    pub conditional_measures: Vec<ConditionalMeasure>,
}
243
/// Tail dependence coefficients of a copula.
#[derive(Debug, Clone)]
pub struct TailDependence {
    /// Lower tail dependence coefficient
    pub lower_tail: Float,
    /// Upper tail dependence coefficient
    pub upper_tail: Float,
    /// Asymmetry measure between the two tails
    pub asymmetry: Float,
}
254
/// Dependence between two outputs conditioned on a set of other outputs.
#[derive(Debug, Clone)]
pub struct ConditionalMeasure {
    /// Indices of the variables conditioned on
    pub condition_vars: Vec<usize>,
    /// Conditional dependence strength
    pub conditional_dependence: Float,
    /// Conditional correlation
    pub conditional_correlation: Float,
}
265
/// Non-parametric (empirical) copula representation.
#[derive(Debug, Clone)]
pub struct EmpiricalCopula {
    /// Empirical copula values
    pub copula_values: Array2<Float>,
    /// Rank-transformed data the copula was built from
    pub rank_data: Array2<Float>,
    /// Number of samples used
    pub sample_size: usize,
}
276
277/// Output Correlation Analyzer
278///
279/// Analyzes correlations and dependencies between different outputs in multi-output data.
280/// Provides various correlation measures and dependency analysis tools.
281///
282/// # Examples
283///
284/// ```
285/// use sklears_multioutput::correlation::{OutputCorrelationAnalyzer, CorrelationType};
286/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
287/// use scirs2_core::ndarray::array;
288/// use std::collections::HashMap;
289///
290/// let mut outputs = HashMap::new();
291/// outputs.insert("task1".to_string(), array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]]);
292/// outputs.insert("task2".to_string(), array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5]]);
293///
294/// let analyzer = OutputCorrelationAnalyzer::new()
295///     .correlation_types(vec![CorrelationType::Pearson, CorrelationType::Spearman])
296///     .include_cross_task(true);
297///
298/// let analysis = analyzer.analyze(&outputs).unwrap();
299/// ```
#[derive(Debug, Clone)]
pub struct OutputCorrelationAnalyzer {
    /// Types of correlation to compute (one matrix per type)
    correlation_types: Vec<CorrelationType>,
    /// Whether to compute correlations between different tasks
    include_cross_task: bool,
    /// Whether to compute correlations within each task's own outputs
    include_within_task: bool,
    /// Minimum correlation threshold for reporting
    /// (NOTE(review): stored but not applied anywhere in `analyze` — confirm
    /// whether filtering is intended)
    min_correlation_threshold: Float,
    /// Whether to also compute partial correlation matrices
    compute_partial_correlations: bool,
}
313
/// Types of correlation/dependence measures the analyzer can compute.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CorrelationType {
    /// Pearson correlation coefficient (linear dependence)
    Pearson,
    /// Spearman rank correlation (monotonic dependence)
    Spearman,
    /// Kendall tau correlation (concordance-based)
    Kendall,
    /// Mutual information
    MutualInformation,
    /// Distance correlation
    DistanceCorrelation,
    /// Canonical correlation
    CanonicalCorrelation,
}
330
/// Results of [`OutputCorrelationAnalyzer::analyze`].
#[derive(Debug, Clone)]
pub struct CorrelationAnalysis {
    /// Full correlation matrix over all combined outputs, per correlation type
    pub correlation_matrices: HashMap<CorrelationType, Array2<Float>>,
    /// Cross-correlation matrices keyed by ordered (task1, task2) pairs
    pub cross_task_correlations: HashMap<(String, String), Array2<Float>>,
    /// Within-task correlation matrices, keyed by task name
    pub within_task_correlations: HashMap<String, Array2<Float>>,
    /// Partial correlation matrices (`None` unless requested)
    pub partial_correlations: Option<HashMap<CorrelationType, Array2<Float>>>,
    /// Output names mapped to their dimensionality (number of columns)
    pub output_info: HashMap<String, usize>,
    /// Combined (samples x total_outputs) matrix used for analysis
    pub combined_outputs: Array2<Float>,
    /// Half-open column range `(start, end)` of each task in `combined_outputs`
    pub output_indices: HashMap<String, (usize, usize)>,
}
349
350/// Dependency Graph Builder
351///
352/// Builds dependency graphs between outputs based on various criteria.
353/// Useful for understanding causal relationships and for building chain models.
354///
355/// # Examples
356///
357/// ```
358/// use sklears_multioutput::correlation::{DependencyGraphBuilder, DependencyMethod};
359/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
360/// use scirs2_core::ndarray::array;
361/// use std::collections::HashMap;
362///
363/// let mut outputs = HashMap::new();
364/// outputs.insert("task1".to_string(), array![[1.0], [2.0], [3.0]]);
365/// outputs.insert("task2".to_string(), array![[0.5], [1.0], [1.5]]);
366/// outputs.insert("task3".to_string(), array![[0.8], [1.2], [1.8]]);
367///
368/// let builder = DependencyGraphBuilder::new()
369///     .method(DependencyMethod::CorrelationThreshold(0.5))
370///     .include_self_loops(false);
371///
372/// let graph = builder.build(&outputs).unwrap();
373/// ```
#[derive(Debug, Clone)]
pub struct DependencyGraphBuilder {
    /// Criterion used to decide whether an edge (dependency) exists
    method: DependencyMethod,
    /// Whether to include self-loops (node depending on itself)
    include_self_loops: bool,
    /// Whether to build a directed graph
    directed: bool,
    /// Optional cap on dependencies per node (`None` = unlimited)
    max_dependencies: Option<usize>,
}
385
/// Criteria for deciding whether one output depends on another.
#[derive(Debug, Clone, PartialEq)]
pub enum DependencyMethod {
    /// Edge exists when |correlation| exceeds this threshold
    CorrelationThreshold(Float),
    /// Edge exists when mutual information exceeds this threshold
    MutualInformationThreshold(Float),
    /// Causal discovery (simplified)
    CausalDiscovery,
    /// Statistical significance testing; the payload is the p-value threshold
    StatisticalSignificance(Float),
    /// Keep only the top-k strongest correlations
    TopK(usize),
}
400
/// Dependency graph over outputs, produced by [`DependencyGraphBuilder`].
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// Adjacency matrix (nonzero entry = edge present)
    pub adjacency_matrix: Array2<Float>,
    /// Node names (output names), indexing the matrix rows/columns
    pub node_names: Vec<String>,
    /// Edge weights (correlation/dependency strengths)
    pub edge_weights: Array2<Float>,
    /// Whether the graph is directed
    pub directed: bool,
    /// Summary statistics of the graph
    pub stats: GraphStatistics,
}
415
/// Summary statistics describing a dependency graph.
#[derive(Debug, Clone)]
pub struct GraphStatistics {
    /// Number of nodes
    pub num_nodes: usize,
    /// Number of edges
    pub num_edges: usize,
    /// Average node degree
    pub average_degree: Float,
    /// Density: proportion of possible edges that exist
    pub density: Float,
    /// Clustering coefficient
    pub clustering_coefficient: Float,
}
430
/// Conditional Independence Tester
///
/// Tests for conditional independence between outputs given other outputs.
/// Useful for understanding causal structure and for feature selection.
#[derive(Debug, Clone)]
pub struct ConditionalIndependenceTester {
    /// Significance level for rejecting independence
    alpha: Float,
    /// Statistical test used for each independence check
    test_method: CITestMethod,
    /// Maximum number of variables allowed in a conditioning set
    max_conditioning_set_size: usize,
}
444
/// Statistical methods for conditional independence testing.
#[derive(Debug, Clone, PartialEq)]
pub enum CITestMethod {
    /// Partial correlation test
    PartialCorrelation,
    /// Mutual information based test
    MutualInformation,
    /// Kernel-based test
    KernelBased,
    /// Linear regression based test
    RegressionBased,
}
457
/// Aggregate results of conditional independence testing.
#[derive(Debug, Clone)]
pub struct CITestResults {
    /// Per-test results keyed by (var1, var2, conditioning set)
    pub test_results: HashMap<(String, String, Vec<String>), CITestResult>,
    /// Markov blanket (set of shielding variables) for each output
    pub markov_blankets: HashMap<String, Vec<String>>,
    /// Graph induced by the conditional independence structure
    pub ci_graph: DependencyGraph,
}
468
/// Result of a single conditional independence test.
#[derive(Debug, Clone)]
pub struct CITestResult {
    /// Value of the test statistic
    pub test_statistic: Float,
    /// P-value of the test
    pub p_value: Float,
    /// Whether the pair is judged conditionally independent
    pub independent: bool,
    /// Conditioning set used for this test
    pub conditioning_set: Vec<String>,
}
481
482impl OutputCorrelationAnalyzer {
483    /// Create a new OutputCorrelationAnalyzer
484    pub fn new() -> Self {
485        Self {
486            correlation_types: vec![CorrelationType::Pearson],
487            include_cross_task: true,
488            include_within_task: true,
489            min_correlation_threshold: 0.0,
490            compute_partial_correlations: false,
491        }
492    }
493
494    /// Set correlation types to compute
495    pub fn correlation_types(mut self, types: Vec<CorrelationType>) -> Self {
496        self.correlation_types = types;
497        self
498    }
499
500    /// Set whether to include cross-task correlations
501    pub fn include_cross_task(mut self, include: bool) -> Self {
502        self.include_cross_task = include;
503        self
504    }
505
506    /// Set whether to include within-task correlations
507    pub fn include_within_task(mut self, include: bool) -> Self {
508        self.include_within_task = include;
509        self
510    }
511
512    /// Set minimum correlation threshold for reporting
513    pub fn min_correlation_threshold(mut self, threshold: Float) -> Self {
514        self.min_correlation_threshold = threshold;
515        self
516    }
517
518    /// Set whether to compute partial correlations
519    pub fn compute_partial_correlations(mut self, compute: bool) -> Self {
520        self.compute_partial_correlations = compute;
521        self
522    }
523
524    /// Analyze correlations in multi-output data
525    pub fn analyze(
526        &self,
527        outputs: &HashMap<String, Array2<Float>>,
528    ) -> SklResult<CorrelationAnalysis> {
529        if outputs.is_empty() {
530            return Err(SklearsError::InvalidInput(
531                "No outputs provided".to_string(),
532            ));
533        }
534
535        // Check that all outputs have the same number of samples
536        let n_samples = outputs.values().next().unwrap().nrows();
537        for (task_name, task_outputs) in outputs {
538            if task_outputs.nrows() != n_samples {
539                return Err(SklearsError::ShapeMismatch {
540                    expected: format!("{}", n_samples),
541                    actual: format!("{}", task_outputs.nrows()),
542                });
543            }
544        }
545
546        // Create combined output matrix
547        let total_outputs: usize = outputs.values().map(|arr| arr.ncols()).sum();
548        let mut combined_outputs = Array2::<Float>::zeros((n_samples, total_outputs));
549        let mut output_indices = HashMap::new();
550        let mut output_info = HashMap::new();
551
552        let mut current_idx = 0;
553        for (task_name, task_outputs) in outputs {
554            let n_outputs = task_outputs.ncols();
555            let end_idx = current_idx + n_outputs;
556
557            combined_outputs
558                .slice_mut(s![.., current_idx..end_idx])
559                .assign(task_outputs);
560
561            output_indices.insert(task_name.clone(), (current_idx, end_idx));
562            output_info.insert(task_name.clone(), n_outputs);
563            current_idx = end_idx;
564        }
565
566        // Compute correlation matrices
567        let mut correlation_matrices = HashMap::new();
568        for correlation_type in &self.correlation_types {
569            let corr_matrix = self.compute_correlation(&combined_outputs, correlation_type)?;
570            correlation_matrices.insert(correlation_type.clone(), corr_matrix);
571        }
572
573        // Compute cross-task correlations
574        let mut cross_task_correlations = HashMap::new();
575        if self.include_cross_task {
576            for (task1, &(start1, end1)) in &output_indices {
577                for (task2, &(start2, end2)) in &output_indices {
578                    if task1 != task2 {
579                        let task1_outputs = combined_outputs.slice(s![.., start1..end1]);
580                        let task2_outputs = combined_outputs.slice(s![.., start2..end2]);
581                        let cross_corr =
582                            self.compute_cross_correlation(&task1_outputs, &task2_outputs)?;
583                        cross_task_correlations.insert((task1.clone(), task2.clone()), cross_corr);
584                    }
585                }
586            }
587        }
588
589        // Compute within-task correlations
590        let mut within_task_correlations = HashMap::new();
591        if self.include_within_task {
592            for (task_name, &(start_idx, end_idx)) in &output_indices {
593                if end_idx - start_idx > 1 {
594                    // Only if task has multiple outputs
595                    let task_outputs = combined_outputs.slice(s![.., start_idx..end_idx]);
596                    let within_corr = self
597                        .compute_correlation(&task_outputs.to_owned(), &CorrelationType::Pearson)?;
598                    within_task_correlations.insert(task_name.clone(), within_corr);
599                }
600            }
601        }
602
603        // Compute partial correlations if requested
604        let partial_correlations = if self.compute_partial_correlations {
605            let mut partial_corrs = HashMap::new();
606            for correlation_type in &self.correlation_types {
607                if let Ok(partial_corr) =
608                    self.compute_partial_correlation(&combined_outputs, correlation_type)
609                {
610                    partial_corrs.insert(correlation_type.clone(), partial_corr);
611                }
612            }
613            Some(partial_corrs)
614        } else {
615            None
616        };
617
618        Ok(CorrelationAnalysis {
619            correlation_matrices,
620            cross_task_correlations,
621            within_task_correlations,
622            partial_correlations,
623            output_info,
624            combined_outputs,
625            output_indices,
626        })
627    }
628
629    /// Compute correlation matrix for given correlation type
630    fn compute_correlation(
631        &self,
632        data: &Array2<Float>,
633        correlation_type: &CorrelationType,
634    ) -> SklResult<Array2<Float>> {
635        match correlation_type {
636            CorrelationType::Pearson => self.compute_pearson_correlation(data),
637            CorrelationType::Spearman => self.compute_spearman_correlation(data),
638            CorrelationType::Kendall => self.compute_kendall_correlation(data),
639            CorrelationType::MutualInformation => self.compute_mutual_information_matrix(data),
640            CorrelationType::DistanceCorrelation => self.compute_distance_correlation(data),
641            CorrelationType::CanonicalCorrelation => self.compute_canonical_correlation(data),
642        }
643    }
644
645    /// Compute Pearson correlation matrix
646    fn compute_pearson_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
647        let n_vars = data.ncols();
648        let n_samples = data.nrows();
649        let mut corr_matrix = Array2::eye(n_vars);
650
651        // Compute means
652        let means = data.mean_axis(Axis(0)).unwrap();
653
654        // Compute centered data
655        let mut centered_data = data.clone();
656        for i in 0..n_samples {
657            for j in 0..n_vars {
658                centered_data[[i, j]] -= means[j];
659            }
660        }
661
662        // Compute correlation coefficients
663        for i in 0..n_vars {
664            for j in (i + 1)..n_vars {
665                let col_i = centered_data.column(i);
666                let col_j = centered_data.column(j);
667
668                let covariance = col_i.dot(&col_j) / (n_samples as Float - 1.0);
669                let var_i = col_i.dot(&col_i) / (n_samples as Float - 1.0);
670                let var_j = col_j.dot(&col_j) / (n_samples as Float - 1.0);
671
672                let correlation = if var_i > 0.0 && var_j > 0.0 {
673                    covariance / (var_i.sqrt() * var_j.sqrt())
674                } else {
675                    0.0
676                };
677
678                corr_matrix[[i, j]] = correlation;
679                corr_matrix[[j, i]] = correlation;
680            }
681        }
682
683        Ok(corr_matrix)
684    }
685
686    /// Compute Spearman rank correlation matrix (simplified implementation)
687    fn compute_spearman_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
688        // This is a simplified implementation
689        // In practice, you would compute ranks and then Pearson correlation on ranks
690        let n_vars = data.ncols();
691        let mut ranked_data = Array2::<Float>::zeros(data.dim());
692
693        // Compute ranks for each column (simplified ranking)
694        for j in 0..n_vars {
695            let mut column_data: Vec<(Float, usize)> = data
696                .column(j)
697                .iter()
698                .enumerate()
699                .map(|(i, &val)| (val, i))
700                .collect();
701            column_data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
702
703            for (rank, (_, original_idx)) in column_data.iter().enumerate() {
704                ranked_data[[*original_idx, j]] = rank as Float;
705            }
706        }
707
708        // Compute Pearson correlation on ranked data
709        self.compute_pearson_correlation(&ranked_data)
710    }
711
712    /// Compute Kendall tau correlation matrix (simplified implementation)
713    fn compute_kendall_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
714        // This is a very simplified implementation
715        // In practice, Kendall tau requires counting concordant and discordant pairs
716        let n_vars = data.ncols();
717        let mut corr_matrix = Array2::eye(n_vars);
718
719        for i in 0..n_vars {
720            for j in (i + 1)..n_vars {
721                // Simplified Kendall tau approximation using Spearman
722                let spearman_corr = self.compute_spearman_correlation(data)?;
723                let kendall_approx = (2.0 / std::f64::consts::PI) * spearman_corr[[i, j]].asin();
724
725                corr_matrix[[i, j]] = kendall_approx;
726                corr_matrix[[j, i]] = kendall_approx;
727            }
728        }
729
730        Ok(corr_matrix)
731    }
732
733    /// Compute mutual information matrix (simplified implementation)
734    fn compute_mutual_information_matrix(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
735        let n_vars = data.ncols();
736        let mut mi_matrix = Array2::<Float>::zeros((n_vars, n_vars));
737
738        // This is a simplified implementation
739        // In practice, you would use proper entropy estimation methods
740        for i in 0..n_vars {
741            for j in 0..n_vars {
742                if i == j {
743                    mi_matrix[[i, j]] = 1.0; // Self-information normalized
744                } else {
745                    // Approximate MI using correlation
746                    let pearson_corr = self.compute_pearson_correlation(data)?;
747                    let mi_approx = -0.5 * (1.0 - pearson_corr[[i, j]].powi(2)).ln();
748                    mi_matrix[[i, j]] = mi_approx.max(0.0);
749                }
750            }
751        }
752
753        Ok(mi_matrix)
754    }
755
756    /// Compute distance correlation matrix (simplified implementation)
757    fn compute_distance_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
758        // This is a simplified implementation
759        // Real distance correlation requires computing distance matrices and double centering
760        let n_vars = data.ncols();
761        let mut dcorr_matrix = Array2::eye(n_vars);
762
763        for i in 0..n_vars {
764            for j in (i + 1)..n_vars {
765                // Simplified distance correlation using Pearson as approximation
766                let pearson_corr = self.compute_pearson_correlation(data)?;
767                let dcorr_approx = pearson_corr[[i, j]].abs();
768
769                dcorr_matrix[[i, j]] = dcorr_approx;
770                dcorr_matrix[[j, i]] = dcorr_approx;
771            }
772        }
773
774        Ok(dcorr_matrix)
775    }
776
777    /// Compute canonical correlation matrix (simplified implementation)
778    fn compute_canonical_correlation(&self, data: &Array2<Float>) -> SklResult<Array2<Float>> {
779        // This is a placeholder for canonical correlation analysis
780        // Real CCA requires solving a generalized eigenvalue problem
781        self.compute_pearson_correlation(data)
782    }
783
784    /// Compute cross-correlation between two sets of outputs
785    fn compute_cross_correlation(
786        &self,
787        data1: &ArrayView2<Float>,
788        data2: &ArrayView2<Float>,
789    ) -> SklResult<Array2<Float>> {
790        let n_outputs1 = data1.ncols();
791        let n_outputs2 = data2.ncols();
792        let n_samples = data1.nrows();
793
794        if data2.nrows() != n_samples {
795            return Err(SklearsError::ShapeMismatch {
796                expected: format!("{}", n_samples),
797                actual: format!("{}", data2.nrows()),
798            });
799        }
800
801        let mut cross_corr = Array2::<Float>::zeros((n_outputs1, n_outputs2));
802
803        // Compute means
804        let means1 = data1.mean_axis(Axis(0)).unwrap();
805        let means2 = data2.mean_axis(Axis(0)).unwrap();
806
807        for i in 0..n_outputs1 {
808            for j in 0..n_outputs2 {
809                let col1 = data1.column(i);
810                let col2 = data2.column(j);
811
812                // Compute covariance
813                let mut covariance = 0.0;
814                for k in 0..n_samples {
815                    covariance += (col1[k] - means1[i]) * (col2[k] - means2[j]);
816                }
817                covariance /= n_samples as Float - 1.0;
818
819                // Compute variances
820                let mut var1 = 0.0;
821                let mut var2 = 0.0;
822                for k in 0..n_samples {
823                    var1 += (col1[k] - means1[i]).powi(2);
824                    var2 += (col2[k] - means2[j]).powi(2);
825                }
826                var1 /= n_samples as Float - 1.0;
827                var2 /= n_samples as Float - 1.0;
828
829                // Compute correlation
830                let correlation = if var1 > 0.0 && var2 > 0.0 {
831                    covariance / (var1.sqrt() * var2.sqrt())
832                } else {
833                    0.0
834                };
835
836                cross_corr[[i, j]] = correlation;
837            }
838        }
839
840        Ok(cross_corr)
841    }
842
843    /// Compute partial correlation matrix (simplified implementation)
844    fn compute_partial_correlation(
845        &self,
846        data: &Array2<Float>,
847        _correlation_type: &CorrelationType,
848    ) -> SklResult<Array2<Float>> {
849        // This is a simplified implementation of partial correlation
850        // Real partial correlation requires inverting the correlation matrix
851        let corr_matrix = self.compute_pearson_correlation(data)?;
852        let n_vars = corr_matrix.nrows();
853
854        // Try to invert correlation matrix to get partial correlations
855        // This is a simplified approach - in practice you'd use proper matrix inversion
856        let mut partial_corr = Array2::eye(n_vars);
857
858        for i in 0..n_vars {
859            for j in (i + 1)..n_vars {
860                // Simplified partial correlation calculation
861                // In practice, this would be -cov_inv[i,j] / sqrt(cov_inv[i,i] * cov_inv[j,j])
862                let partial = corr_matrix[[i, j]] * 0.8; // Simplified approximation
863                partial_corr[[i, j]] = partial;
864                partial_corr[[j, i]] = partial;
865            }
866        }
867
868        Ok(partial_corr)
869    }
870}
871
872impl Default for OutputCorrelationAnalyzer {
873    fn default() -> Self {
874        Self::new()
875    }
876}
877
878impl DependencyGraphBuilder {
879    /// Create a new DependencyGraphBuilder
880    pub fn new() -> Self {
881        Self {
882            method: DependencyMethod::CorrelationThreshold(0.5),
883            include_self_loops: false,
884            directed: false,
885            max_dependencies: None,
886        }
887    }
888
889    /// Set the method for determining dependencies
890    pub fn method(mut self, method: DependencyMethod) -> Self {
891        self.method = method;
892        self
893    }
894
895    /// Set whether to include self-loops
896    pub fn include_self_loops(mut self, include: bool) -> Self {
897        self.include_self_loops = include;
898        self
899    }
900
901    /// Set whether to make the graph directed
902    pub fn directed(mut self, directed: bool) -> Self {
903        self.directed = directed;
904        self
905    }
906
907    /// Set maximum number of dependencies per node
908    pub fn max_dependencies(mut self, max_deps: Option<usize>) -> Self {
909        self.max_dependencies = max_deps;
910        self
911    }
912
913    /// Build dependency graph from outputs
914    pub fn build(&self, outputs: &HashMap<String, Array2<Float>>) -> SklResult<DependencyGraph> {
915        // First analyze correlations
916        let analyzer = OutputCorrelationAnalyzer::new()
917            .correlation_types(vec![CorrelationType::Pearson])
918            .include_cross_task(true);
919
920        let analysis = analyzer.analyze(outputs)?;
921
922        // Get correlation matrix
923        let correlation_matrix = analysis
924            .correlation_matrices
925            .get(&CorrelationType::Pearson)
926            .ok_or_else(|| {
927                SklearsError::InvalidInput("Failed to compute correlations".to_string())
928            })?;
929
930        // Build node names
931        let mut node_names = Vec::new();
932        for (task_name, &(start_idx, end_idx)) in &analysis.output_indices {
933            for i in start_idx..end_idx {
934                node_names.push(format!("{}_{}", task_name, i - start_idx));
935            }
936        }
937
938        let n_nodes = node_names.len();
939        let mut adjacency_matrix = Array2::<Float>::zeros((n_nodes, n_nodes));
940        let mut edge_weights = Array2::<Float>::zeros((n_nodes, n_nodes));
941
942        // Apply dependency method to determine edges
943        match &self.method {
944            DependencyMethod::CorrelationThreshold(threshold) => {
945                for i in 0..n_nodes {
946                    for j in 0..n_nodes {
947                        if i != j || self.include_self_loops {
948                            let corr_strength = correlation_matrix[[i, j]].abs();
949                            if corr_strength >= *threshold {
950                                adjacency_matrix[[i, j]] = 1.0;
951                                edge_weights[[i, j]] = corr_strength;
952
953                                if !self.directed {
954                                    adjacency_matrix[[j, i]] = 1.0;
955                                    edge_weights[[j, i]] = corr_strength;
956                                }
957                            }
958                        }
959                    }
960                }
961            }
962            DependencyMethod::TopK(k) => {
963                for i in 0..n_nodes {
964                    let mut correlations: Vec<(usize, Float)> = (0..n_nodes)
965                        .filter(|&j| i != j || self.include_self_loops)
966                        .map(|j| (j, correlation_matrix[[i, j]].abs()))
967                        .collect();
968
969                    correlations.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
970
971                    for (j, corr_strength) in correlations.iter().take(*k) {
972                        adjacency_matrix[[i, *j]] = 1.0;
973                        edge_weights[[i, *j]] = *corr_strength;
974                    }
975                }
976            }
977            _ => {
978                // Other methods would be implemented here
979                return Err(SklearsError::InvalidInput(
980                    "Dependency method not yet implemented".to_string(),
981                ));
982            }
983        }
984
985        // Apply maximum dependencies constraint
986        if let Some(max_deps) = self.max_dependencies {
987            for i in 0..n_nodes {
988                let mut dependencies: Vec<(usize, Float)> = (0..n_nodes)
989                    .filter(|&j| adjacency_matrix[[i, j]] > 0.0)
990                    .map(|j| (j, edge_weights[[i, j]]))
991                    .collect();
992
993                dependencies.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
994
995                // Keep only top max_deps dependencies
996                for (idx, (j, _)) in dependencies.iter().enumerate() {
997                    if idx >= max_deps {
998                        adjacency_matrix[[i, *j]] = 0.0;
999                        edge_weights[[i, *j]] = 0.0;
1000                    }
1001                }
1002            }
1003        }
1004
1005        // Compute graph statistics
1006        let stats = self.compute_graph_statistics(&adjacency_matrix);
1007
1008        Ok(DependencyGraph {
1009            adjacency_matrix,
1010            node_names,
1011            edge_weights,
1012            directed: self.directed,
1013            stats,
1014        })
1015    }
1016
1017    /// Compute graph statistics
1018    fn compute_graph_statistics(&self, adjacency_matrix: &Array2<Float>) -> GraphStatistics {
1019        let n_nodes = adjacency_matrix.nrows();
1020        let num_edges = adjacency_matrix.sum() as usize;
1021
1022        // Compute degrees
1023        let degrees: Vec<Float> = (0..n_nodes)
1024            .map(|i| adjacency_matrix.row(i).sum())
1025            .collect();
1026
1027        let average_degree = degrees.iter().sum::<Float>() / (n_nodes as Float);
1028
1029        let max_possible_edges = if self.directed {
1030            n_nodes * (n_nodes - 1)
1031        } else {
1032            n_nodes * (n_nodes - 1) / 2
1033        };
1034
1035        let density = if max_possible_edges > 0 {
1036            num_edges as Float / max_possible_edges as Float
1037        } else {
1038            0.0
1039        };
1040
1041        // Simplified clustering coefficient calculation
1042        let clustering_coefficient = if !self.directed {
1043            self.compute_clustering_coefficient(adjacency_matrix)
1044        } else {
1045            0.0 // Simplified for directed graphs
1046        };
1047
1048        GraphStatistics {
1049            num_nodes: n_nodes,
1050            num_edges,
1051            average_degree,
1052            density,
1053            clustering_coefficient,
1054        }
1055    }
1056
1057    /// Compute clustering coefficient for undirected graph
1058    fn compute_clustering_coefficient(&self, adjacency_matrix: &Array2<Float>) -> Float {
1059        let n_nodes = adjacency_matrix.nrows();
1060        let mut total_clustering = 0.0;
1061        let mut valid_nodes = 0;
1062
1063        for i in 0..n_nodes {
1064            let neighbors: Vec<usize> = (0..n_nodes)
1065                .filter(|&j| adjacency_matrix[[i, j]] > 0.0)
1066                .collect();
1067
1068            let degree = neighbors.len();
1069            if degree < 2 {
1070                continue; // Cannot compute clustering for degree < 2
1071            }
1072
1073            let mut triangles = 0;
1074            for &j in &neighbors {
1075                for &k in &neighbors {
1076                    if j < k && adjacency_matrix[[j, k]] > 0.0 {
1077                        triangles += 1;
1078                    }
1079                }
1080            }
1081
1082            let possible_triangles = degree * (degree - 1) / 2;
1083            let clustering = if possible_triangles > 0 {
1084                triangles as Float / possible_triangles as Float
1085            } else {
1086                0.0
1087            };
1088
1089            total_clustering += clustering;
1090            valid_nodes += 1;
1091        }
1092
1093        if valid_nodes > 0 {
1094            total_clustering / valid_nodes as Float
1095        } else {
1096            0.0
1097        }
1098    }
1099}
1100
1101impl Default for DependencyGraphBuilder {
1102    fn default() -> Self {
1103        Self::new()
1104    }
1105}
1106
1107impl CorrelationAnalysis {
1108    /// Get correlation between two specific outputs
1109    pub fn get_correlation(
1110        &self,
1111        output1: &str,
1112        output2: &str,
1113        correlation_type: &CorrelationType,
1114    ) -> Option<Float> {
1115        let corr_matrix = self.correlation_matrices.get(correlation_type)?;
1116
1117        // Find indices for the outputs
1118        let mut output1_idx = None;
1119        let mut output2_idx = None;
1120        let mut current_idx = 0;
1121
1122        for (task_name, &(start_idx, end_idx)) in &self.output_indices {
1123            for i in start_idx..end_idx {
1124                let output_name = format!("{}_{}", task_name, i - start_idx);
1125                if output_name == output1 {
1126                    output1_idx = Some(current_idx);
1127                }
1128                if output_name == output2 {
1129                    output2_idx = Some(current_idx);
1130                }
1131                current_idx += 1;
1132            }
1133        }
1134
1135        if let (Some(idx1), Some(idx2)) = (output1_idx, output2_idx) {
1136            Some(corr_matrix[[idx1, idx2]])
1137        } else {
1138            None
1139        }
1140    }
1141
1142    /// Get strongest correlations above threshold
1143    pub fn get_strong_correlations(
1144        &self,
1145        correlation_type: &CorrelationType,
1146        threshold: Float,
1147    ) -> Vec<(String, String, Float)> {
1148        let mut strong_correlations = Vec::new();
1149
1150        if let Some(corr_matrix) = self.correlation_matrices.get(correlation_type) {
1151            let current_idx = 0;
1152            let mut output_names = Vec::new();
1153
1154            // Build output names
1155            for (task_name, &(start_idx, end_idx)) in &self.output_indices {
1156                for i in start_idx..end_idx {
1157                    output_names.push(format!("{}_{}", task_name, i - start_idx));
1158                }
1159            }
1160
1161            // Find strong correlations
1162            for i in 0..output_names.len() {
1163                for j in (i + 1)..output_names.len() {
1164                    let corr_value = corr_matrix[[i, j]];
1165                    if corr_value.abs() >= threshold {
1166                        strong_correlations.push((
1167                            output_names[i].clone(),
1168                            output_names[j].clone(),
1169                            corr_value,
1170                        ));
1171                    }
1172                }
1173            }
1174        }
1175
1176        strong_correlations.sort_by(|a, b| b.2.abs().partial_cmp(&a.2.abs()).unwrap());
1177        strong_correlations
1178    }
1179
1180    /// Get summary statistics for correlations
1181    pub fn correlation_summary(
1182        &self,
1183        correlation_type: &CorrelationType,
1184    ) -> Option<(Float, Float, Float, Float)> {
1185        if let Some(corr_matrix) = self.correlation_matrices.get(correlation_type) {
1186            let n = corr_matrix.nrows();
1187            let mut values = Vec::new();
1188
1189            // Collect upper triangular values (excluding diagonal)
1190            for i in 0..n {
1191                for j in (i + 1)..n {
1192                    values.push(corr_matrix[[i, j]]);
1193                }
1194            }
1195
1196            if values.is_empty() {
1197                return Some((0.0, 0.0, 0.0, 0.0));
1198            }
1199
1200            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1201
1202            let mean = values.iter().sum::<Float>() / values.len() as Float;
1203            let median = if values.len() % 2 == 0 {
1204                (values[values.len() / 2 - 1] + values[values.len() / 2]) / 2.0
1205            } else {
1206                values[values.len() / 2]
1207            };
1208            let min = values[0];
1209            let max = values[values.len() - 1];
1210
1211            Some((mean, median, min, max))
1212        } else {
1213            None
1214        }
1215    }
1216}
1217
1218impl DependencyGraph {
1219    /// Get neighbors of a node
1220    pub fn get_neighbors(&self, node_name: &str) -> Vec<String> {
1221        if let Some(node_idx) = self.node_names.iter().position(|name| name == node_name) {
1222            let mut neighbors = Vec::new();
1223            for j in 0..self.node_names.len() {
1224                if self.adjacency_matrix[[node_idx, j]] > 0.0 {
1225                    neighbors.push(self.node_names[j].clone());
1226                }
1227            }
1228            neighbors
1229        } else {
1230            Vec::new()
1231        }
1232    }
1233
1234    /// Get edge weight between two nodes
1235    pub fn get_edge_weight(&self, node1: &str, node2: &str) -> Option<Float> {
1236        let idx1 = self.node_names.iter().position(|name| name == node1)?;
1237        let idx2 = self.node_names.iter().position(|name| name == node2)?;
1238
1239        if self.adjacency_matrix[[idx1, idx2]] > 0.0 {
1240            Some(self.edge_weights[[idx1, idx2]])
1241        } else {
1242            None
1243        }
1244    }
1245
1246    /// Check if two nodes are connected
1247    pub fn are_connected(&self, node1: &str, node2: &str) -> bool {
1248        self.get_edge_weight(node1, node2).is_some()
1249    }
1250
1251    /// Get node degree
1252    pub fn get_degree(&self, node_name: &str) -> usize {
1253        if let Some(node_idx) = self.node_names.iter().position(|name| name == node_name) {
1254            self.adjacency_matrix.row(node_idx).sum() as usize
1255        } else {
1256            0
1257        }
1258    }
1259}
1260
#[allow(non_snake_case)]
#[cfg(test)]
mod correlation_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    fn test_correlation_analyzer_creation() {
        // Builder methods should all round-trip into the analyzer's fields.
        let analyzer = OutputCorrelationAnalyzer::new()
            .correlation_types(vec![CorrelationType::Pearson, CorrelationType::Spearman])
            .include_cross_task(true)
            .include_within_task(true)
            .min_correlation_threshold(0.1)
            .compute_partial_correlations(true);

        assert_eq!(analyzer.correlation_types.len(), 2);
        assert!(analyzer.include_cross_task);
        assert!(analyzer.include_within_task);
        assert_abs_diff_eq!(analyzer.min_correlation_threshold, 0.1);
        assert!(analyzer.compute_partial_correlations);
    }

    #[test]
    fn test_correlation_analysis() {
        let mut outputs = HashMap::new();
        outputs.insert(
            "task1".to_string(),
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 2.0]],
        );
        outputs.insert(
            "task2".to_string(),
            array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5], [2.0, 1.0]],
        );

        let analyzer = OutputCorrelationAnalyzer::new()
            .correlation_types(vec![CorrelationType::Pearson])
            .include_cross_task(true)
            .include_within_task(true);

        let analysis = analyzer.analyze(&outputs).unwrap();

        // Check that we have correlation matrices
        assert!(analysis
            .correlation_matrices
            .contains_key(&CorrelationType::Pearson));

        // Check combined outputs shape
        assert_eq!(analysis.combined_outputs.shape(), &[4, 4]); // 4 samples, 4 outputs total

        // Check output indices
        assert!(analysis.output_indices.contains_key("task1"));
        assert!(analysis.output_indices.contains_key("task2"));

        // Check cross-task correlations
        assert!(analysis
            .cross_task_correlations
            .contains_key(&("task1".to_string(), "task2".to_string())));

        // Check within-task correlations
        assert!(analysis.within_task_correlations.contains_key("task1"));
        assert!(analysis.within_task_correlations.contains_key("task2"));
    }

    #[test]
    fn test_dependency_graph_builder() {
        let mut outputs = HashMap::new();
        outputs.insert("task1".to_string(), array![[1.0], [2.0], [3.0], [4.0]]);
        outputs.insert("task2".to_string(), array![[0.5], [1.0], [1.5], [2.0]]);
        outputs.insert("task3".to_string(), array![[0.8], [1.2], [1.8], [2.4]]);

        let builder = DependencyGraphBuilder::new()
            .method(DependencyMethod::CorrelationThreshold(0.5))
            .include_self_loops(false)
            .directed(false);

        let graph = builder.build(&outputs).unwrap();

        // Check graph properties
        assert_eq!(graph.node_names.len(), 3); // 3 tasks with 1 output each
        assert!(!graph.directed);
        assert_eq!(graph.stats.num_nodes, 3);
    }

    #[test]
    fn test_correlation_types() {
        let types = vec![
            CorrelationType::Pearson,
            CorrelationType::Spearman,
            CorrelationType::Kendall,
            CorrelationType::MutualInformation,
            CorrelationType::DistanceCorrelation,
            CorrelationType::CanonicalCorrelation,
        ];

        assert_eq!(types.len(), 6);
        assert_eq!(types[0], CorrelationType::Pearson);
    }

    #[test]
    fn test_dependency_methods() {
        let methods = vec![
            DependencyMethod::CorrelationThreshold(0.5),
            DependencyMethod::MutualInformationThreshold(0.3),
            DependencyMethod::CausalDiscovery,
            DependencyMethod::StatisticalSignificance(0.05),
            DependencyMethod::TopK(3),
        ];

        assert_eq!(methods.len(), 5);
        assert_eq!(methods[0], DependencyMethod::CorrelationThreshold(0.5));
    }

    #[test]
    fn test_correlation_analysis_accessors() {
        let mut outputs = HashMap::new();
        outputs.insert(
            "task1".to_string(),
            array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]],
        );
        outputs.insert(
            "task2".to_string(),
            array![[0.5, 1.0], [1.0, 1.5], [1.5, 0.5]],
        );

        let analyzer = OutputCorrelationAnalyzer::new();
        let analysis = analyzer.analyze(&outputs).unwrap();

        // Test getting specific correlation
        let corr = analysis.get_correlation("task1_0", "task1_1", &CorrelationType::Pearson);
        assert!(corr.is_some());

        // Test getting strong correlations
        let strong_corrs = analysis.get_strong_correlations(&CorrelationType::Pearson, 0.1);
        assert!(!strong_corrs.is_empty());

        // Test correlation summary
        let summary = analysis.correlation_summary(&CorrelationType::Pearson);
        assert!(summary.is_some());
        let (mean, median, min, max) = summary.unwrap();
        assert!(min <= median);
        assert!(median <= max);
    }

    #[test]
    fn test_dependency_graph_accessors() {
        let mut outputs = HashMap::new();
        outputs.insert("task1".to_string(), array![[1.0], [2.0], [3.0]]);
        outputs.insert("task2".to_string(), array![[0.5], [1.0], [1.5]]);

        let builder =
            DependencyGraphBuilder::new().method(DependencyMethod::CorrelationThreshold(0.1));

        let graph = builder.build(&outputs).unwrap();

        // Test neighbor retrieval
        let neighbors = graph.get_neighbors("task1_0");
        assert!(neighbors.len() <= 2); // Can have at most task2_0 as neighbor

        // Test degree calculation
        let degree = graph.get_degree("task1_0");
        assert!(degree <= 2);

        // Connection presence depends on the correlation threshold and the
        // actual data correlation, so only exercise the call path here.
        let _connected = graph.are_connected("task1_0", "task2_0");
    }
}