scirs2_stats/advanced_integration.rs

1//! Advanced Statistical Analysis Integration
2//!
3//! This module provides high-level interfaces that integrate multiple advanced
4//! statistical methods for comprehensive data analysis workflows.
5
6use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::validation::*;
9
10use crate::bayesian::{BayesianLinearRegression, BayesianRegressionResult};
11use crate::mcmc::{GibbsSampler, MultivariateNormalGibbs};
12use crate::multivariate::{FactorAnalysis, FactorAnalysisResult, PCAResult, PCA};
13use crate::qmc::{halton, latin_hypercube, sobol};
14use crate::survival::{CoxPHModel, KaplanMeierEstimator};
15
/// Configuration for an end-to-end Bayesian linear-regression analysis.
///
/// Start from [`BayesianAnalysisWorkflow::new`] (or [`Default`]) and adjust it with the
/// `with_*` / `without_*` builder methods.
#[derive(Debug, Clone)]
pub struct BayesianAnalysisWorkflow {
    /// When `true`, posterior coefficient draws and posterior predictive draws are produced.
    pub use_mcmc: bool,
    /// How many posterior draws to keep after burn-in.
    pub n_mcmc_samples: usize,
    /// How many initial sampler steps to discard before draws are kept.
    pub mcmc_burnin: usize,
    /// Seed for reproducible sampling; `None` derives a seed from the system clock.
    pub random_seed: Option<u64>,
}

impl Default for BayesianAnalysisWorkflow {
    /// MCMC enabled, 1000 kept draws, 100 burn-in steps, no fixed seed.
    fn default() -> Self {
        let random_seed = None;
        let mcmc_burnin = 100;
        let n_mcmc_samples = 1000;
        Self {
            use_mcmc: true,
            n_mcmc_samples,
            mcmc_burnin,
            random_seed,
        }
    }
}
39
40/// Results of comprehensive Bayesian analysis
41#[derive(Debug, Clone)]
42pub struct BayesianAnalysisResult {
43    /// Bayesian regression results
44    pub regression: BayesianRegressionResult,
45    /// MCMC samples (if requested)
46    pub mcmc_samples: Option<Array2<f64>>,
47    /// Posterior predictive samples
48    pub predictive_samples: Option<Array2<f64>>,
49    /// Model comparison metrics
50    pub model_metrics: BayesianModelMetrics,
51}
52
/// Criteria used to compare competing Bayesian regression models.
#[derive(Debug, Clone)]
pub struct BayesianModelMetrics {
    /// Log of the marginal likelihood (model evidence) reported by the regression fit.
    pub log_marginal_likelihood: f64,
    /// Deviance Information Criterion (DIC).
    pub dic: f64,
    /// Widely Applicable (Watanabe-Akaike) Information Criterion (WAIC).
    pub waic: f64,
    /// Leave-one-out cross-validation information criterion (LOO-IC).
    pub loo_ic: f64,
}
65
66impl BayesianAnalysisWorkflow {
67    /// Create new Bayesian analysis workflow
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Configure MCMC settings
73    pub fn with_mcmc(mut self, n_samples_: usize, burnin: usize) -> Self {
74        self.use_mcmc = true;
75        self.n_mcmc_samples = n_samples_;
76        self.mcmc_burnin = burnin;
77        self
78    }
79
80    /// Disable MCMC sampling
81    pub fn without_mcmc(mut self) -> Self {
82        self.use_mcmc = false;
83        self
84    }
85
86    /// Set random seed
87    pub fn with_seed(mut self, seed: u64) -> Self {
88        self.random_seed = Some(seed);
89        self
90    }
91
92    /// Perform comprehensive Bayesian analysis
93    pub fn analyze(
94        &self,
95        x: ArrayView2<f64>,
96        y: ArrayView1<f64>,
97    ) -> Result<BayesianAnalysisResult> {
98        checkarray_finite(&x, "x")?;
99        checkarray_finite(&y, "y")?;
100
101        let (n_samples_, n_features) = x.dim();
102        if y.len() != n_samples_ {
103            return Err(StatsError::DimensionMismatch(format!(
104                "y length ({}) must match x rows ({})",
105                y.len(),
106                n_samples_
107            )));
108        }
109
110        // Perform Bayesian linear regression
111        let bayesian_reg = BayesianLinearRegression::new(n_features, true)?;
112        let regression = bayesian_reg.fit(x, y)?;
113
114        // MCMC sampling if requested
115        let mcmc_samples = if self.use_mcmc {
116            Some(self.perform_mcmc_sampling(&regression, n_features)?)
117        } else {
118            None
119        };
120
121        // Generate predictive samples
122        let predictive_samples = if self.use_mcmc {
123            Some(self.generate_predictive_samples(&bayesian_reg, &regression, x)?)
124        } else {
125            None
126        };
127
128        // Compute model metrics
129        let model_metrics = self.compute_model_metrics(&regression, x, y)?;
130
131        Ok(BayesianAnalysisResult {
132            regression,
133            mcmc_samples,
134            predictive_samples,
135            model_metrics,
136        })
137    }
138
139    /// Perform MCMC sampling from posterior
140    fn perform_mcmc_sampling(
141        &self,
142        regression: &BayesianRegressionResult,
143        _n_features: usize,
144    ) -> Result<Array2<f64>> {
145        use scirs2_core::random::{rngs::StdRng, SeedableRng};
146
147        let mut rng = match self.random_seed {
148            Some(seed) => StdRng::seed_from_u64(seed),
149            None => {
150                use std::time::{SystemTime, UNIX_EPOCH};
151                let seed = SystemTime::now()
152                    .duration_since(UNIX_EPOCH)
153                    .unwrap_or_default()
154                    .as_secs();
155                StdRng::seed_from_u64(seed)
156            }
157        };
158
159        // Use Gibbs sampling for multivariate normal posterior
160        let gibbs_sampler = MultivariateNormalGibbs::from_precision(
161            regression.posterior_mean.clone(),
162            regression.posterior_covariance.clone(),
163        )?;
164
165        let mut sampler = GibbsSampler::new(gibbs_sampler, regression.posterior_mean.clone())?;
166
167        // Burn-in
168        for _ in 0..self.mcmc_burnin {
169            sampler.step(&mut rng)?;
170        }
171
172        // Collect samples
173        let samples = sampler.sample(self.n_mcmc_samples, &mut rng)?;
174        Ok(samples)
175    }
176
177    /// Generate posterior predictive samples
178    fn generate_predictive_samples(
179        &self,
180        bayesian_reg: &BayesianLinearRegression,
181        regression: &BayesianRegressionResult,
182        x_test: ArrayView2<f64>,
183    ) -> Result<Array2<f64>> {
184        use scirs2_core::random::{rngs::StdRng, SeedableRng};
185        use scirs2_core::random::{Distribution, Normal};
186
187        let mut rng = match self.random_seed {
188            Some(seed) => StdRng::seed_from_u64(seed),
189            None => {
190                use std::time::{SystemTime, UNIX_EPOCH};
191                let seed = SystemTime::now()
192                    .duration_since(UNIX_EPOCH)
193                    .unwrap_or_default()
194                    .as_secs();
195                StdRng::seed_from_u64(seed)
196            }
197        };
198
199        let n_test = x_test.nrows();
200        let mut predictive_samples = Array2::zeros((self.n_mcmc_samples, n_test));
201
202        // Generate predictive samples
203        for i in 0..self.n_mcmc_samples {
204            // Sample from posterior parameter distribution
205            let mut beta_sample = Array1::zeros(regression.posterior_mean.len());
206            for j in 0..beta_sample.len() {
207                let var = regression.posterior_covariance[[j, j]];
208                let normal =
209                    Normal::new(regression.posterior_mean[j], var.sqrt()).map_err(|e| {
210                        StatsError::ComputationError(format!("Failed to create normal: {}", e))
211                    })?;
212                beta_sample[j] = normal.sample(&mut rng);
213            }
214
215            // Generate predictions with this parameter sample
216            let pred_result = bayesian_reg.predict(x_test, regression)?;
217
218            // Add noise
219            let noise_std = (regression.posterior_beta / regression.posterior_alpha).sqrt();
220            let noise_normal = Normal::new(0.0, noise_std).map_err(|e| {
221                StatsError::ComputationError(format!("Failed to create noise normal: {}", e))
222            })?;
223
224            for j in 0..n_test {
225                let noise = noise_normal.sample(&mut rng);
226                predictive_samples[[i, j]] = pred_result.mean[j] + noise;
227            }
228        }
229
230        Ok(predictive_samples)
231    }
232
233    /// Compute Bayesian model comparison metrics
234    fn compute_model_metrics(
235        &self,
236        regression: &BayesianRegressionResult,
237        x: ArrayView2<f64>,
238        _y: ArrayView1<f64>,
239    ) -> Result<BayesianModelMetrics> {
240        let n_samples_ = x.nrows() as f64;
241        let n_params = regression.posterior_mean.len() as f64;
242
243        // Log marginal likelihood (already computed)
244        let log_marginal_likelihood = regression.log_marginal_likelihood;
245
246        // Simplified DIC calculation
247        let deviance = -2.0 * log_marginal_likelihood;
248        let effective_params = n_params; // Simplified
249        let dic = deviance + 2.0 * effective_params;
250
251        // Simplified WAIC (Watanabe-Akaike Information Criterion)
252        let waic = -2.0 * log_marginal_likelihood + 2.0 * effective_params;
253
254        // Simplified LOO-IC (Leave-One-Out Information Criterion)
255        let loo_ic = -2.0 * log_marginal_likelihood
256            + 2.0 * effective_params * n_samples_ / (n_samples_ - n_params - 1.0);
257
258        Ok(BayesianModelMetrics {
259            log_marginal_likelihood,
260            dic,
261            waic,
262            loo_ic,
263        })
264    }
265}
266
/// Configuration for a combined PCA / factor-analysis study of a data matrix.
#[derive(Debug, Clone)]
pub struct DimensionalityAnalysisWorkflow {
    /// Components kept by PCA; `None` keeps as many as the data shape allows.
    pub n_pca_components: Option<usize>,
    /// Factors to extract; `None` skips factor analysis.
    pub n_factors: Option<usize>,
    /// Request incremental PCA for inputs with more rows than `pca_batchsize`.
    pub use_incremental_pca: bool,
    /// Row threshold / batch size used when incremental PCA is requested.
    pub pca_batchsize: usize,
    /// Seed forwarded to the randomized steps; `None` leaves them unseeded.
    pub random_seed: Option<u64>,
}

impl Default for DimensionalityAnalysisWorkflow {
    /// Automatic component count, no factor analysis, standard PCA (batch size 1000),
    /// no fixed seed.
    fn default() -> Self {
        let pca_batchsize = 1000;
        Self {
            n_pca_components: None,
            n_factors: None,
            use_incremental_pca: false,
            pca_batchsize,
            random_seed: None,
        }
    }
}
293
294/// Results of dimensionality analysis
295#[derive(Debug, Clone)]
296pub struct DimensionalityAnalysisResult {
297    /// PCA results
298    pub pca: Option<PCAResult>,
299    /// Factor analysis results
300    pub factor_analysis: Option<FactorAnalysisResult>,
301    /// Recommended number of components/factors
302    pub recommendations: DimensionalityRecommendations,
303    /// Comparison metrics
304    pub comparison_metrics: DimensionalityMetrics,
305}
306
/// Suggested dimensionality for PCA and factor analysis.
#[derive(Debug, Clone)]
pub struct DimensionalityRecommendations {
    /// Number of PCA components suggested by the Kaiser criterion.
    pub optimal_pca_components: usize,
    /// Number of factors suggested by parallel analysis.
    pub optimal_factors: usize,
    /// Fraction of variance explained by the suggested PCA components.
    pub explained_variance_ratio: f64,
}
317
318/// Comparison metrics for dimensionality reduction methods
319#[derive(Debug, Clone)]
320pub struct DimensionalityMetrics {
321    /// Scree plot data (eigenvalues)
322    pub eigenvalues: Array1<f64>,
323    /// Cumulative explained variance
324    pub cumulative_variance: Array1<f64>,
325    /// Kaiser-Meyer-Olkin measure
326    pub kmo_measure: f64,
327    /// Bartlett's test statistic and p-value
328    pub bartlett_test: (f64, f64),
329}
330
331impl DimensionalityAnalysisWorkflow {
332    /// Create new dimensionality analysis workflow
333    pub fn new() -> Self {
334        Self::default()
335    }
336
337    /// Set PCA configuration
338    pub fn with_pca(
339        mut self,
340        n_components: Option<usize>,
341        incremental: bool,
342        batchsize: usize,
343    ) -> Self {
344        self.n_pca_components = n_components;
345        self.use_incremental_pca = incremental;
346        self.pca_batchsize = batchsize;
347        self
348    }
349
350    /// Set factor analysis configuration
351    pub fn with_factor_analysis(mut self, n_factors: Option<usize>) -> Self {
352        self.n_factors = n_factors;
353        self
354    }
355
356    /// Set random seed
357    pub fn with_seed(mut self, seed: u64) -> Self {
358        self.random_seed = Some(seed);
359        self
360    }
361
362    /// Perform comprehensive dimensionality analysis
363    pub fn analyze(&self, data: ArrayView2<f64>) -> Result<DimensionalityAnalysisResult> {
364        checkarray_finite(&data, "data")?;
365        let (n_samples_, n_features) = data.dim();
366
367        if n_samples_ < 3 {
368            return Err(StatsError::InvalidArgument(
369                "Need at least 3 samples for analysis".to_string(),
370            ));
371        }
372
373        // Perform PCA analysis
374        let pca = if self.use_incremental_pca && n_samples_ > self.pca_batchsize {
375            Some(self.perform_incremental_pca(data)?)
376        } else {
377            Some(self.perform_standard_pca(data)?)
378        };
379
380        // Perform factor analysis if requested
381        let factor_analysis = if self.n_factors.is_some() {
382            Some(self.perform_factor_analysis(data)?)
383        } else {
384            None
385        };
386
387        // Generate recommendations
388        let recommendations = self.generate_recommendations(data, &pca)?;
389
390        // Compute comparison metrics
391        let comparison_metrics = self.compute_metrics(data)?;
392
393        Ok(DimensionalityAnalysisResult {
394            pca,
395            factor_analysis,
396            recommendations,
397            comparison_metrics,
398        })
399    }
400
401    /// Perform standard PCA
402    fn perform_standard_pca(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
403        let n_components = self
404            .n_pca_components
405            .unwrap_or(data.ncols().min(data.nrows()));
406
407        let pca = PCA::new()
408            .with_n_components(n_components)
409            .with_center(true)
410            .with_scale(false);
411
412        if let Some(seed) = self.random_seed {
413            pca.with_random_state(seed).fit(data)
414        } else {
415            pca.fit(data)
416        }
417    }
418
419    /// Perform incremental PCA for large datasets
420    fn perform_incremental_pca(&self, data: ArrayView2<f64>) -> Result<PCAResult> {
421        // For now, fall back to standard PCA since IncrementalPCA fields are private
422        // This would need to be implemented with public accessors in the actual IncrementalPCA
423        self.perform_standard_pca(data)
424    }
425
426    /// Perform factor analysis
427    fn perform_factor_analysis(&self, data: ArrayView2<f64>) -> Result<FactorAnalysisResult> {
428        use crate::multivariate::RotationType;
429
430        let n_factors = self.n_factors.unwrap_or(2);
431
432        let mut fa = FactorAnalysis::new(n_factors)?
433            .with_rotation(RotationType::Varimax)
434            .with_max_iter(1000)
435            .with_tolerance(1e-6);
436
437        if let Some(seed) = self.random_seed {
438            fa = fa.with_random_state(seed);
439        }
440
441        fa.fit(data)
442    }
443
444    /// Generate dimensionality recommendations
445    fn generate_recommendations(
446        &self,
447        data: ArrayView2<f64>,
448        pca: &Option<PCAResult>,
449    ) -> Result<DimensionalityRecommendations> {
450        use crate::multivariate::{efa::parallel_analysis, mle_components};
451
452        // Kaiser criterion for PCA (eigenvalues > 1)
453        let optimal_pca_components = if let Some(ref pca_result) = pca {
454            pca_result
455                .explained_variance
456                .iter()
457                .position(|&ev| ev < 1.0)
458                .unwrap_or(pca_result.explained_variance.len())
459        } else {
460            mle_components(data, None)?
461        };
462
463        // Parallel analysis for factor analysis
464        let optimal_factors = parallel_analysis(data, 100, 95.0, self.random_seed)?;
465
466        // Explained variance ratio
467        let explained_variance_ratio = if let Some(ref pca_result) = pca {
468            pca_result
469                .explained_variance_ratio
470                .slice(scirs2_core::ndarray::s![..optimal_pca_components])
471                .sum()
472        } else {
473            0.0
474        };
475
476        Ok(DimensionalityRecommendations {
477            optimal_pca_components,
478            optimal_factors,
479            explained_variance_ratio,
480        })
481    }
482
483    /// Compute comparison metrics
484    fn compute_metrics(&self, data: ArrayView2<f64>) -> Result<DimensionalityMetrics> {
485        use crate::multivariate::efa::{bartlett_test, kmo_test};
486
487        // Compute covariance matrix for eigenvalues
488        let mean = data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
489        let mut centered = data.to_owned();
490        for mut row in centered.rows_mut() {
491            row -= &mean;
492        }
493
494        let cov = centered.t().dot(&centered) / (data.nrows() - 1) as f64;
495
496        // Compute eigenvalues
497        use scirs2_core::ndarray::ndarray_linalg::Eigh;
498        let eigenvalues = cov
499            .eigh(scirs2_core::ndarray::ndarray_linalg::UPLO::Upper)
500            .map_err(|e| {
501                StatsError::ComputationError(format!("Eigenvalue decomposition failed: {}", e))
502            })?
503            .0;
504
505        // Sort eigenvalues in descending order
506        let mut sorted_eigenvalues = eigenvalues.to_vec();
507        sorted_eigenvalues.sort_by(|a, b| b.partial_cmp(a).unwrap());
508        let eigenvalues = Array1::from_vec(sorted_eigenvalues);
509
510        // Cumulative variance
511        let total_variance = eigenvalues.sum();
512        let mut cumulative_variance = Array1::zeros(eigenvalues.len());
513        let mut cumsum = 0.0;
514        for i in 0..eigenvalues.len() {
515            cumsum += eigenvalues[i];
516            cumulative_variance[i] = cumsum / total_variance;
517        }
518
519        // KMO measure
520        let kmo_measure = kmo_test(data)?;
521
522        // Bartlett's test
523        let bartlett_test = bartlett_test(data)?;
524
525        Ok(DimensionalityMetrics {
526            eigenvalues,
527            cumulative_variance,
528            kmo_measure,
529            bartlett_test,
530        })
531    }
532}
533
/// Configuration for generating a quasi-Monte Carlo point set and scoring its quality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QMCWorkflow {
    /// Low-discrepancy construction to use.
    pub sequence_type: QMCSequenceType,
    /// Scramble (randomize) the sequence; ignored for Latin hypercube sampling.
    pub scrambling: bool,
    /// Dimensionality of each generated point.
    pub dimensions: usize,
    /// Number of points to generate.
    pub n_samples_: usize,
    /// Seed for scrambling / random sampling; `None` leaves it unseeded.
    pub random_seed: Option<u64>,
}

/// Supported quasi-Monte Carlo point constructions.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare, match on equality, or key
/// maps by sequence type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QMCSequenceType {
    /// Sobol low-discrepancy sequence.
    Sobol,
    /// Halton low-discrepancy sequence.
    Halton,
    /// Latin hypercube sampling.
    LatinHypercube,
}
559
560/// QMC analysis results
561#[derive(Debug, Clone)]
562pub struct QMCResult {
563    /// Generated samples
564    pub samples: Array2<f64>,
565    /// Sequence type used
566    pub sequence_type: QMCSequenceType,
567    /// Quality metrics
568    pub quality_metrics: QMCQualityMetrics,
569}
570
/// Scores summarising how evenly a point set fills the unit hypercube.
#[derive(Debug, Clone)]
pub struct QMCQualityMetrics {
    /// Star discrepancy of the point set.
    pub star_discrepancy: f64,
    /// Inverse coefficient of variation of nearest-neighbour distances.
    pub uniformity: f64,
    /// Fraction of grid cells that contain at least one point.
    pub coverage_efficiency: f64,
}
581
582impl Default for QMCWorkflow {
583    fn default() -> Self {
584        Self {
585            sequence_type: QMCSequenceType::Sobol,
586            scrambling: true,
587            dimensions: 2,
588            n_samples_: 1000,
589            random_seed: None,
590        }
591    }
592}
593
594impl QMCWorkflow {
595    /// Create new QMC workflow
596    pub fn new(dimensions: usize, n_samples_: usize) -> Self {
597        Self {
598            dimensions,
599            n_samples_,
600            ..Default::default()
601        }
602    }
603
604    /// Set sequence type
605    pub fn with_sequence_type(mut self, sequence_type: QMCSequenceType) -> Self {
606        self.sequence_type = sequence_type;
607        self
608    }
609
610    /// Enable or disable scrambling
611    pub fn with_scrambling(mut self, scrambling: bool) -> Self {
612        self.scrambling = scrambling;
613        self
614    }
615
616    /// Set random seed
617    pub fn with_seed(mut self, seed: u64) -> Self {
618        self.random_seed = Some(seed);
619        self
620    }
621
622    /// Generate QMC samples with quality assessment
623    pub fn generate(&self) -> Result<QMCResult> {
624        check_positive(self.dimensions, "dimensions")?;
625        check_positive(self.n_samples_, "n_samples_")?;
626
627        // Generate samples based on sequence type
628        let samples = match self.sequence_type {
629            QMCSequenceType::Sobol => sobol(
630                self.n_samples_,
631                self.dimensions,
632                self.scrambling,
633                self.random_seed,
634            )?,
635            QMCSequenceType::Halton => halton(
636                self.n_samples_,
637                self.dimensions,
638                self.scrambling,
639                self.random_seed,
640            )?,
641            QMCSequenceType::LatinHypercube => {
642                latin_hypercube(self.n_samples_, self.dimensions, self.random_seed)?
643            }
644        };
645
646        // Compute quality metrics
647        let quality_metrics = self.compute_quality_metrics(&samples)?;
648
649        Ok(QMCResult {
650            samples,
651            sequence_type: self.sequence_type,
652            quality_metrics,
653        })
654    }
655
656    /// Compute quality metrics for the sequence
657    fn compute_quality_metrics(&self, samples: &Array2<f64>) -> Result<QMCQualityMetrics> {
658        use crate::qmc::star_discrepancy;
659
660        // Convert to format expected by star_discrepancy
661        let sample_points: Vec<Array1<f64>> = samples
662            .rows()
663            .into_iter()
664            .map(|row| row.to_owned())
665            .collect();
666
667        let samples_view = Array1::from_vec(sample_points);
668        let star_discrepancy = star_discrepancy(&samples_view.view())?;
669
670        // Compute uniformity measure (coefficient of variation of nearest neighbor distances)
671        let uniformity = self.compute_uniformity(samples)?;
672
673        // Compute coverage efficiency
674        let coverage_efficiency = self.compute_coverage_efficiency(samples)?;
675
676        Ok(QMCQualityMetrics {
677            star_discrepancy,
678            uniformity,
679            coverage_efficiency,
680        })
681    }
682
683    /// Compute uniformity measure
684    fn compute_uniformity(&self, samples: &Array2<f64>) -> Result<f64> {
685        let n_samples_ = samples.nrows();
686        let mut min_distances = Array1::zeros(n_samples_);
687
688        // Compute minimum distance to other points for each sample
689        for i in 0..n_samples_ {
690            let mut min_dist = f64::INFINITY;
691            for j in 0..n_samples_ {
692                if i != j {
693                    let mut dist = 0.0;
694                    for k in 0..self.dimensions {
695                        let diff = samples[[i, k]] - samples[[j, k]];
696                        dist += diff * diff;
697                    }
698                    dist = dist.sqrt();
699                    if dist < min_dist {
700                        min_dist = dist;
701                    }
702                }
703            }
704            min_distances[i] = min_dist;
705        }
706
707        // Coefficient of variation of minimum distances
708        let mean_dist = min_distances.mean().unwrap();
709        let var_dist = min_distances.var(1.0);
710        let uniformity = 1.0 / (var_dist.sqrt() / mean_dist); // Inverse CV
711
712        Ok(uniformity)
713    }
714
715    /// Compute coverage efficiency
716    fn compute_coverage_efficiency(&self, samples: &Array2<f64>) -> Result<f64> {
717        // Simple approximation: ratio of actual coverage to expected coverage
718        let n_bins = (self.n_samples_ as f64)
719            .powf(1.0 / self.dimensions as f64)
720            .ceil() as usize;
721        let mut occupied_bins = std::collections::HashSet::new();
722
723        for i in 0..samples.nrows() {
724            let mut bin_id = Vec::new();
725            for j in 0..self.dimensions {
726                let bin = (samples[[i, j]] * n_bins as f64).floor() as usize;
727                bin_id.push(bin.min(n_bins - 1));
728            }
729            occupied_bins.insert(bin_id);
730        }
731
732        let total_bins = n_bins.pow(self.dimensions as u32);
733        let coverage_efficiency = occupied_bins.len() as f64 / total_bins as f64;
734
735        Ok(coverage_efficiency)
736    }
737}
738
/// Configuration for Kaplan-Meier and (optionally) Cox proportional-hazards analysis.
#[derive(Debug, Clone)]
pub struct SurvivalAnalysisWorkflow {
    /// Confidence level handed to the Kaplan-Meier estimator.
    pub confidence_level: f64,
    /// Fit a Cox model when covariates are supplied.
    pub fit_cox_model: bool,
    /// Iteration cap for the Cox model fit.
    pub cox_max_iter: usize,
    /// Convergence tolerance for the Cox model fit.
    pub cox_tolerance: f64,
}

impl Default for SurvivalAnalysisWorkflow {
    /// 95% confidence; Cox model enabled with 100 iterations and tolerance `1e-6`.
    fn default() -> Self {
        let (cox_max_iter, cox_tolerance) = (100, 1e-6);
        Self {
            confidence_level: 0.95,
            fit_cox_model: true,
            cox_max_iter,
            cox_tolerance,
        }
    }
}
762
763/// Comprehensive survival analysis results
764#[derive(Debug, Clone)]
765pub struct SurvivalAnalysisResult {
766    /// Kaplan-Meier estimator
767    pub kaplan_meier: crate::survival::KaplanMeierEstimator,
768    /// Cox proportional hazards model (if requested and covariates provided)
769    pub cox_model: Option<crate::survival::CoxPHModel>,
770    /// Survival summary statistics
771    pub summary_stats: SurvivalSummaryStats,
772}
773
/// Headline numbers summarising a survival data set.
#[derive(Debug, Clone)]
pub struct SurvivalSummaryStats {
    /// Median survival time as reported by the Kaplan-Meier estimator.
    pub median_survival: Option<f64>,
    /// First time at which estimated survival is <= 0.75 (25th percentile of survival
    /// time); `None` if survival never drops that far.
    pub q25_survival: Option<f64>,
    /// First time at which estimated survival is <= 0.25 (75th percentile of survival
    /// time); `None` if survival never drops that far.
    pub q75_survival: Option<f64>,
    /// Fraction of observations that ended in an event.
    pub event_rate: f64,
    /// Fraction of observations that were censored (`1 - event_rate`).
    pub censoring_rate: f64,
}
788
789impl SurvivalAnalysisWorkflow {
790    /// Create new survival analysis workflow
791    pub fn new() -> Self {
792        Self::default()
793    }
794
795    /// Set confidence level
796    pub fn with_confidence_level(mut self, level: f64) -> Self {
797        self.confidence_level = level;
798        self
799    }
800
801    /// Configure Cox model fitting
802    pub fn with_cox_model(mut self, max_iter: usize, tolerance: f64) -> Self {
803        self.fit_cox_model = true;
804        self.cox_max_iter = max_iter;
805        self.cox_tolerance = tolerance;
806        self
807    }
808
809    /// Disable Cox model fitting
810    pub fn without_cox_model(mut self) -> Self {
811        self.fit_cox_model = false;
812        self
813    }
814
815    /// Perform comprehensive survival analysis
816    pub fn analyze(
817        &self,
818        durations: ArrayView1<f64>,
819        events: ArrayView1<bool>,
820        covariates: Option<ArrayView2<f64>>,
821    ) -> Result<SurvivalAnalysisResult> {
822        checkarray_finite(&durations, "durations")?;
823
824        if durations.len() != events.len() {
825            return Err(StatsError::DimensionMismatch(format!(
826                "durations length ({}) must match events length ({})",
827                durations.len(),
828                events.len()
829            )));
830        }
831
832        // Fit Kaplan-Meier estimator
833        let kaplan_meier =
834            KaplanMeierEstimator::fit(durations, events, Some(self.confidence_level))?;
835
836        // Fit Cox model if requested and covariates provided
837        let cox_model = if self.fit_cox_model {
838            if let Some(cov) = covariates {
839                Some(CoxPHModel::fit(
840                    durations,
841                    events,
842                    cov,
843                    Some(self.cox_max_iter),
844                    Some(self.cox_tolerance),
845                )?)
846            } else {
847                None
848            }
849        } else {
850            None
851        };
852
853        // Compute summary statistics
854        let summary_stats = self.compute_summary_stats(&durations, &events, &kaplan_meier)?;
855
856        Ok(SurvivalAnalysisResult {
857            kaplan_meier,
858            cox_model,
859            summary_stats,
860        })
861    }
862
863    /// Compute survival summary statistics
864    fn compute_summary_stats(
865        &self,
866        _durations: &ArrayView1<f64>,
867        events: &ArrayView1<bool>,
868        km: &KaplanMeierEstimator,
869    ) -> Result<SurvivalSummaryStats> {
870        // Event and censoring rates
871        let total_events: usize = events.iter().map(|&e| if e { 1 } else { 0 }).sum();
872        let total_observations = events.len();
873        let event_rate = total_events as f64 / total_observations as f64;
874        let censoring_rate = 1.0 - event_rate;
875
876        // Median survival time (already computed in KM estimator)
877        let median_survival = km.median_survival_time;
878
879        // Percentile survival times
880        let q25_survival = self.find_survival_percentile(km, 0.75)?; // 75% survival = 25th percentile time
881        let q75_survival = self.find_survival_percentile(km, 0.25)?; // 25% survival = 75th percentile time
882
883        Ok(SurvivalSummaryStats {
884            median_survival,
885            q25_survival,
886            q75_survival,
887            event_rate,
888            censoring_rate,
889        })
890    }
891
892    /// Find time at which survival probability equals target
893    fn find_survival_percentile(
894        &self,
895        km: &KaplanMeierEstimator,
896        target_survival: f64,
897    ) -> Result<Option<f64>> {
898        for i in 0..km.survival_function.len() {
899            if km.survival_function[i] <= target_survival {
900                return Ok(Some(km.event_times[i]));
901            }
902        }
903        Ok(None) // Target survival not reached
904    }
905}