1use crate::error::{StatsError, StatsResult};
13use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2};
14use scirs2_core::numeric::{Float, NumCast, One, Zero};
15use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
16use std::collections::HashMap;
17use std::marker::PhantomData;
18
/// Front-end that fits the survival models listed in its configuration and
/// compares them.
///
/// Constructed with `new`; `fit` runs every configured model and `predict`
/// returns a `SurvivalPrediction`.
pub struct AdvancedSurvivalAnalysis<F> {
    /// Settings supplied to `new`; `fit` reads the model list and the optional
    /// ensemble, causal and competing-risks sections from it.
    config: AdvancedSurvivalConfig<F>,
    /// Models keyed by name; `new` starts this map empty.
    models: HashMap<String, SurvivalModel<F>>,
    /// Per-model metric tables; `new` starts them empty and `fit` returns a
    /// clone of them in its results.
    performance: ModelPerformance<F>,
    /// Zero-sized marker for the float type `F`.
    _phantom: PhantomData<F>,
}
29
/// Settings for `AdvancedSurvivalAnalysis`.
///
/// The `Default` impl in this file selects a single `EnhancedCox` model,
/// three metrics, 5-fold cross-validation and leaves every optional section
/// as `None`.
#[derive(Debug, Clone)]
pub struct AdvancedSurvivalConfig<F> {
    /// Models to fit; `fit` names the i-th entry `model_{i}`.
    pub models: Vec<SurvivalModelType<F>>,
    /// Metrics requested for evaluation.
    /// NOTE(review): not read by any method visible in this file.
    pub metrics: Vec<SurvivalMetric>,
    /// Cross-validation settings.
    /// NOTE(review): not read by any method visible in this file.
    pub cross_validation: CrossValidationConfig,
    /// When `Some`, `fit` additionally runs the ensemble analysis.
    pub ensemble: Option<EnsembleConfig<F>>,
    /// Bayesian settings.
    /// NOTE(review): not read by any method visible in this file.
    pub bayesian: Option<BayesianSurvivalConfig<F>>,
    /// When `Some`, `fit` additionally runs the competing-risks analysis.
    pub competing_risks: Option<CompetingRisksConfig>,
    /// When `Some`, `fit` additionally runs the causal analysis.
    pub causal: Option<CausalSurvivalConfig<F>>,
}
48
/// Model families that can be requested in `AdvancedSurvivalConfig::models`,
/// each carrying its hyper-parameters.
///
/// `fit_single_model` has dedicated routines for `EnhancedCox`, `AFT`,
/// `RandomSurvivalForest` and `DeepSurvival`; every other variant is fitted
/// with the enhanced Cox routine. Of all the hyper-parameters below, only
/// `AFT::distribution` is forwarded to a fitting routine in this file.
#[derive(Debug, Clone)]
pub enum SurvivalModelType<F> {
    /// Cox-type model with optional extensions.
    EnhancedCox {
        /// Optional penalty strength (the penalty form is not specified here).
        penalty: Option<F>,
        /// Optional indices of stratification variables — presumably column
        /// indices into the covariate matrix; confirm against the caller.
        stratification_vars: Option<Vec<usize>>,
        /// Whether time-varying effects are requested.
        time_varying_effects: bool,
        /// Whether a robust variance estimate is requested.
        robust_variance: bool,
    },
    /// AFT model with a chosen distribution.
    AFT {
        /// Distribution family; passed through to `fit_aft_model`.
        distribution: AFTDistribution,
        /// Requested scale parameter.
        scale_parameter: F,
    },
    /// Random survival forest.
    RandomSurvivalForest {
        /// Number of trees.
        n_trees: usize,
        /// Minimum number of samples required to split a node.
        min_samples_split: usize,
        /// Optional maximum tree depth.
        max_depth: Option<usize>,
        /// Optional `mtry` value — presumably the number of features tried
        /// per split; confirm against the forest implementation.
        mtry: Option<usize>,
        /// Whether bootstrap sampling is requested.
        bootstrap: bool,
    },
    /// Gradient-boosted survival model.
    GradientBoostingSurvival {
        /// Number of boosting estimators.
        n_estimators: usize,
        /// Learning rate.
        learning_rate: F,
        /// Maximum depth of each estimator.
        max_depth: usize,
        /// Subsampling value — presumably a fraction of the data; confirm.
        subsample: F,
    },
    /// Neural-network survival model.
    DeepSurvival {
        /// Layer sizes.
        architecture: Vec<usize>,
        /// Activation function.
        activation: ActivationFunction,
        /// Dropout rate.
        dropout_rate: F,
        /// Regularization strength.
        regularization: F,
    },
    /// Support-vector survival model.
    SurvivalSVM {
        /// Kernel and its parameters.
        kernel: KernelType<F>,
        /// Regularization strength.
        regularization: F,
        /// Convergence tolerance.
        tolerance: F,
    },
    /// Bayesian survival model.
    BayesianSurvival {
        /// Prior family and parameters.
        prior_type: PriorType<F>,
        /// Sampler settings.
        mcmc_config: MCMCConfig,
    },
    /// Multi-state model.
    MultiState {
        /// State names.
        states: Vec<String>,
        /// Boolean matrix of transitions — presumably `true` marks an allowed
        /// transition between two states; the axis order is not fixed here.
        transitions: Array2<bool>,
        /// Baseline hazard forms.
        baseline_hazards: Vec<BaselineHazard>,
    },
}
104
/// Distribution family selectable for `SurvivalModelType::AFT`.
///
/// `fit_single_model` copies the chosen value into `fit_aft_model`.
#[derive(Debug, Clone, Copy)]
pub enum AFTDistribution {
    Weibull,
    LogNormal,
    LogLogistic,
    Exponential,
    Gamma,
    GeneralizedGamma,
}
115
/// Activation function selectable for `SurvivalModelType::DeepSurvival`.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    Tanh,
    LeakyReLU,
    ELU,
    Swish,
    GELU,
}
127
/// Kernel selectable for `SurvivalModelType::SurvivalSVM`, together with the
/// parameters that kernel carries.
#[derive(Debug, Clone)]
pub enum KernelType<F> {
    /// Linear kernel; no parameters.
    Linear,
    /// RBF kernel with its `gamma` parameter.
    RBF { gamma: F },
    /// Polynomial kernel with its `degree` and `gamma` parameters.
    Polynomial { degree: usize, gamma: F },
    /// Sigmoid kernel with its `gamma` and `coef0` parameters.
    Sigmoid { gamma: F, coef0: F },
}
136
/// Prior family, with its parameters, selectable for
/// `SurvivalModelType::BayesianSurvival`.
#[derive(Debug, Clone)]
pub enum PriorType<F> {
    /// Normal prior parameterised by mean and variance.
    Normal {
        mean: F,
        variance: F,
    },
    /// Gamma prior parameterised by shape and rate.
    Gamma {
        shape: F,
        rate: F,
    },
    /// Beta prior parameterised by `alpha` and `beta`.
    Beta {
        alpha: F,
        beta: F,
    },
    /// Horseshoe prior with its `tau` parameter.
    Horseshoe {
        tau: F,
    },
    /// Spike-and-slab prior: two variances and a mixture weight.
    SpikeAndSlab {
        spike_variance: F,
        slab_variance: F,
        mixture_weight: F,
    },
}
161
/// Sampler settings carried by `SurvivalModelType::BayesianSurvival`.
///
/// NOTE(review): no code visible in this file reads these fields; the
/// descriptions below follow the field names and should be confirmed against
/// the sampler that consumes them.
#[derive(Debug, Clone)]
pub struct MCMCConfig {
    /// Number of samples to draw (per chain or in total is not specified here).
    pub n_samples_: usize,
    /// Number of burn-in iterations.
    pub n_burnin: usize,
    /// Number of chains.
    pub n_chains: usize,
    /// Thinning setting — presumably keep every `thin`-th draw; confirm.
    pub thin: usize,
    /// Target acceptance rate.
    pub target_accept_rate: f64,
}
171
/// Baseline hazard form; `SurvivalModelType::MultiState` carries a list of
/// these in `baseline_hazards`.
#[derive(Debug, Clone, Copy)]
pub enum BaselineHazard {
    Constant,
    Weibull,
    Piecewise,
    Spline,
}
180
/// Evaluation metric identifiers listed in `AdvancedSurvivalConfig::metrics`.
///
/// The default configuration requests `ConcordanceIndex`, `LogLikelihood`
/// and `AIC`.
#[derive(Debug, Clone, Copy)]
pub enum SurvivalMetric {
    ConcordanceIndex,
    LogLikelihood,
    AIC,
    BIC,
    IntegratedBrierScore,
    TimeROC,
    Calibration,
    PredictionError,
}
193
/// Cross-validation settings held in `AdvancedSurvivalConfig::cross_validation`.
///
/// The default configuration uses `KFold` with 5 folds, stratification and
/// shuffling enabled, and `random_state` set to `Some(42)`.
#[derive(Debug, Clone)]
pub struct CrossValidationConfig {
    /// Splitting scheme.
    pub method: CVMethod,
    /// Number of folds.
    pub n_folds: usize,
    /// Whether to stratify the folds (the stratification variable is not
    /// specified here).
    pub stratify: bool,
    /// Whether to shuffle before splitting.
    pub shuffle: bool,
    /// Optional seed for reproducible splits.
    pub random_state: Option<u64>,
}
203
/// Splitting scheme selectable in `CrossValidationConfig::method`.
#[derive(Debug, Clone, Copy)]
pub enum CVMethod {
    KFold,
    TimeSeriesSplit,
    StratifiedKFold,
    LeaveOneOut,
}
212
/// Ensemble settings held in `AdvancedSurvivalConfig::ensemble`.
///
/// `fit` passes this to `ensemble_analysis`, which currently accepts it as
/// an unused parameter (`_config`).
#[derive(Debug, Clone)]
pub struct EnsembleConfig<F> {
    /// How base models are combined.
    pub method: EnsembleMethod,
    /// Names of the base models — presumably keys such as `model_0` produced
    /// by `fit`; confirm against the caller.
    pub base_models: Vec<String>,
    /// Optional fixed weights for the base models.
    pub weights: Option<Array1<F>>,
    /// Optional meta-learner choice.
    pub meta_learner: Option<MetaLearner>,
}
221
/// Combination strategy selectable in `EnsembleConfig::method`.
#[derive(Debug, Clone, Copy)]
pub enum EnsembleMethod {
    Averaging,
    Voting,
    Stacking,
    Bayesian,
}
230
/// Meta-learner selectable in `EnsembleConfig::meta_learner`.
#[derive(Debug, Clone, Copy)]
pub enum MetaLearner {
    LinearRegression,
    LogisticRegression,
    RandomForest,
    NeuralNetwork,
}
239
/// Bayesian analysis settings held in `AdvancedSurvivalConfig::bayesian`.
///
/// NOTE(review): not read by any method visible in this file.
#[derive(Debug, Clone)]
pub struct BayesianSurvivalConfig<F> {
    /// Which Bayesian model family to use.
    pub model_type: BayesianModelType,
    /// How priors are chosen.
    pub prior_elicitation: PriorElicitation<F>,
    /// Posterior sampler settings.
    pub posterior_sampling: PosteriorSamplingConfig,
    /// Whether model comparison is requested.
    pub model_comparison: bool,
}
248
/// Bayesian model family selectable in `BayesianSurvivalConfig::model_type`.
#[derive(Debug, Clone, Copy)]
pub enum BayesianModelType {
    BayesianCox,
    BayesianAFT,
    BayesianNonParametric,
    BayesianMultiState,
}
257
/// Prior-selection strategy used in `BayesianSurvivalConfig::prior_elicitation`.
#[derive(Debug, Clone)]
pub enum PriorElicitation<F> {
    /// Informative priors driven by supplied values.
    Informative {
        /// Named expert-supplied values — presumably keyed by parameter
        /// name; confirm against the consumer.
        expert_knowledge: HashMap<String, F>,
    },
    WeaklyInformative,
    Reference,
    Adaptive,
}
268
/// Posterior sampler settings used in
/// `BayesianSurvivalConfig::posterior_sampling`.
#[derive(Debug, Clone)]
pub struct PosteriorSamplingConfig {
    /// Sampling algorithm.
    pub sampler: SamplerType,
    /// Length of the adaptation period — presumably in iterations; confirm.
    pub adaptation_period: usize,
    /// Target acceptance rate.
    pub target_accept_rate: f64,
    /// Optional maximum tree depth — presumably only meaningful for `NUTS`;
    /// confirm against the sampler.
    pub max_tree_depth: Option<usize>,
}
277
/// Sampling algorithm selectable in `PosteriorSamplingConfig::sampler`.
#[derive(Debug, Clone, Copy)]
pub enum SamplerType {
    NUTS,
    HMC,
    Gibbs,
    MetropolisHastings,
}
286
/// Competing-risks settings held in `AdvancedSurvivalConfig::competing_risks`.
///
/// `competing_risks_analysis` reads only `event_types.len()`, which sets the
/// number of columns of the hazard and incidence matrices it returns.
#[derive(Debug, Clone)]
pub struct CompetingRisksConfig {
    /// Names of the event types.
    pub event_types: Vec<String>,
    /// Which competing-risks approach is requested.
    pub analysis_type: CompetingRisksAnalysis,
    /// Whether cause-specific hazards are requested.
    pub cause_specific_hazards: bool,
    /// Whether subdistribution hazards are requested.
    pub subdistribution_hazards: bool,
}
295
/// Approach selectable in `CompetingRisksConfig::analysis_type`.
#[derive(Debug, Clone, Copy)]
pub enum CompetingRisksAnalysis {
    CauseSpecific,
    Subdistribution,
    DirectBinomial,
    PseudoObservation,
}
304
/// Causal-inference settings held in `AdvancedSurvivalConfig::causal`.
///
/// `fit` passes this to `causal_analysis`, which currently accepts it as an
/// unused parameter (`_config`).
#[derive(Debug, Clone)]
pub struct CausalSurvivalConfig<F> {
    /// Name of the treatment variable.
    pub treatment_variable: String,
    /// Names of the confounders.
    pub confounders: Vec<String>,
    /// Optional names of instruments.
    pub instruments: Option<Vec<String>>,
    /// Estimation method.
    pub estimation_method: CausalEstimationMethod,
    /// Whether a sensitivity analysis is requested.
    pub sensitivity_analysis: bool,
    /// Optional names of effect modifiers.
    pub effect_modification: Option<Vec<String>>,
    /// Optional propensity-score method.
    pub propensity_score_method: Option<PropensityScoreMethod<F>>,
}
316
/// Estimation method selectable in `CausalSurvivalConfig::estimation_method`.
#[derive(Debug, Clone, Copy)]
pub enum CausalEstimationMethod {
    InverseProbabilityWeighting,
    DoublyRobust,
    GComputation,
    TargetedMaximumLikelihood,
    InstrumentalVariable,
}
326
/// Propensity-score method selectable in
/// `CausalSurvivalConfig::propensity_score_method`, with its parameters.
#[derive(Debug, Clone)]
pub enum PropensityScoreMethod<F> {
    /// Matching with a caliper value.
    Matching { caliper: F },
    /// Stratification into `n_strata` strata.
    Stratification { n_strata: usize },
    /// Weighting; no parameters.
    Weighting,
    /// Trimming with a trim fraction.
    Trimming { trim_fraction: F },
}
335
/// A fitted model of any supported family.
///
/// Values of this type are returned by the `fit_*` routines and collected in
/// `AdvancedSurvivalResults::fitted_models`. The routines in this file
/// produce the `Cox`, `AFT`, `RandomForest` and `DeepSurvival` variants.
#[derive(Debug, Clone)]
pub enum SurvivalModel<F> {
    Cox(CoxModel<F>),
    AFT(AFTModel<F>),
    RandomForest(RandomForestModel<F>),
    GradientBoosting(GradientBoostingModel<F>),
    DeepSurvival(DeepSurvivalModel<F>),
    SVM(SVMModel<F>),
    Bayesian(BayesianModel<F>),
    MultiState(MultiStateModel<F>),
    Ensemble(EnsembleModel<F>),
}
349
/// Result of the Cox fitting routine (`fit_enhanced_cox`).
///
/// The per-coefficient arrays are created with one entry per covariate
/// column.
#[derive(Debug, Clone)]
pub struct CoxModel<F> {
    /// Regression coefficients.
    pub coefficients: Array1<F>,
    /// Element-wise exponential of `coefficients`.
    pub hazard_ratios: Array1<F>,
    /// Standard errors of the coefficients.
    pub standard_errors: Array1<F>,
    /// P-values of the coefficients.
    pub p_values: Array1<F>,
    /// Created with one row per coefficient and two columns — presumably
    /// lower and upper bound; confirm the column order.
    pub confidence_intervals: Array2<F>,
    /// Baseline hazard estimate on the unique event times.
    pub baseline_hazard: BaselineHazardEstimate<F>,
    /// Concordance index; `compare_models` uses it as this model's score.
    pub concordance_index: F,
    /// Log-likelihood of the fit.
    pub log_likelihood: F,
    /// Optional time-varying effects; `fit_enhanced_cox` leaves this `None`.
    pub time_varying_effects: Option<Array2<F>>,
}
363
/// Baseline hazard quantities stored in `CoxModel::baseline_hazard`.
///
/// `fit_enhanced_cox` fills `times` from `get_unique_event_times` and gives
/// the other three arrays the same length.
#[derive(Debug, Clone)]
pub struct BaselineHazardEstimate<F> {
    /// Sorted, de-duplicated durations of observed events.
    pub times: Array1<F>,
    /// Hazard value per entry of `times`.
    pub hazard: Array1<F>,
    /// Cumulative hazard per entry of `times`.
    pub cumulative_hazard: Array1<F>,
    /// Survival function value per entry of `times`.
    pub survival_function: Array1<F>,
}
372
/// Result of the AFT fitting routine (`fit_aft_model`).
#[derive(Debug, Clone)]
pub struct AFTModel<F> {
    /// Regression coefficients, one per covariate column.
    pub coefficients: Array1<F>,
    /// Scale parameter.
    pub scale_parameter: F,
    /// Optional shape parameter.
    pub shape_parameter: Option<F>,
    /// Log-likelihood; `compare_models` uses it as this model's score.
    pub log_likelihood: F,
    /// Computed in `fit_aft_model` as `-2 * log_likelihood + 2 * (n_features + 1)`.
    pub aic: F,
    /// Computed in `fit_aft_model` as
    /// `-2 * log_likelihood + (n_features + 1) * ln(n_observations)`.
    pub bic: F,
    /// Residuals, one per observation.
    pub residuals: Array1<F>,
}
384
/// Result of the random-forest fitting routine (`fit_random_forest`).
#[derive(Debug, Clone)]
pub struct RandomForestModel<F> {
    /// Importance value per covariate column.
    pub variable_importance: Array1<F>,
    /// Out-of-bag error.
    pub oob_error: F,
    /// Concordance index; `compare_models` uses it as this model's score.
    pub concordance_index: F,
    /// Feature names; `fit_random_forest` generates `feature_{i}` per column.
    pub feature_names: Vec<String>,
    /// Number of trees.
    pub tree_count: usize,
}
394
/// Fitted gradient-boosting survival model.
///
/// NOTE(review): no routine visible in this file constructs this type.
#[derive(Debug, Clone)]
pub struct GradientBoostingModel<F> {
    /// Importance value per feature.
    pub feature_importance: Array1<F>,
    /// Training loss values — presumably one per boosting iteration; confirm.
    pub training_loss: Array1<F>,
    /// Optional validation loss values.
    pub validation_loss: Option<Array1<F>>,
    /// Index of the best iteration.
    pub best_iteration: usize,
    /// Concordance index; `compare_models` uses it as this model's score.
    pub concordance_index: F,
}
404
/// Result of the deep-survival fitting routine (`fit_deep_survival`).
#[derive(Debug, Clone)]
pub struct DeepSurvivalModel<F> {
    /// Layer sizes; `fit_deep_survival` sets this to
    /// `[n_covariate_columns, 64, 32, 1]`.
    pub architecture: Vec<usize>,
    /// Per-epoch training curves.
    pub training_history: TrainingHistory<F>,
    /// Concordance index; `compare_models` uses it as this model's score.
    pub concordance_index: F,
    /// Calibration slope.
    pub calibration_slope: F,
    /// Optional attributions; `fit_deep_survival` creates a matrix with one
    /// row per observation and one column per covariate.
    pub feature_attributions: Option<Array2<F>>,
}
414
/// Training curves stored in `DeepSurvivalModel::training_history`.
///
/// `fit_deep_survival` gives each array one entry per epoch.
#[derive(Debug, Clone)]
pub struct TrainingHistory<F> {
    /// Loss per epoch.
    pub loss: Array1<F>,
    /// Concordance per epoch.
    pub concordance: Array1<F>,
    /// Learning rate per epoch.
    pub learning_rate: Array1<F>,
    /// Number of epochs.
    pub epochs: usize,
}
423
/// Fitted support-vector survival model.
///
/// NOTE(review): no routine visible in this file constructs this type.
#[derive(Debug, Clone)]
pub struct SVMModel<F> {
    /// Support vectors — presumably one per row; confirm the layout.
    pub support_vectors: Array2<F>,
    /// Dual coefficients.
    pub dual_coefficients: Array1<F>,
    /// Concordance index; `compare_models` uses it as this model's score.
    pub concordance_index: F,
    /// Number of support vectors.
    pub n_support_vectors: usize,
}
432
/// Fitted Bayesian survival model.
///
/// NOTE(review): no routine visible in this file constructs this type.
#[derive(Debug, Clone)]
pub struct BayesianModel<F> {
    /// Posterior draws (the axis layout is not fixed by this file).
    pub posterior_samples: Array2<F>,
    /// Summary statistics of the posterior.
    pub posterior_summary: PosteriorSummary<F>,
    /// Model evidence; `compare_models` uses it as this model's score.
    pub model_evidence: F,
    /// DIC value.
    pub dic: F,
    /// WAIC value.
    pub waic: F,
    /// Sampler convergence diagnostics.
    pub convergence_diagnostics: ConvergenceDiagnostics<F>,
}
443
/// Posterior summary stored in `BayesianModel::posterior_summary`.
///
/// NOTE(review): array shapes are not fixed by any code visible here;
/// presumably one entry (or row) per parameter.
#[derive(Debug, Clone)]
pub struct PosteriorSummary<F> {
    /// Posterior means.
    pub means: Array1<F>,
    /// Posterior standard deviations.
    pub stds: Array1<F>,
    /// Posterior quantiles.
    pub quantiles: Array2<F>,
    /// Credible intervals.
    pub credible_intervals: Array2<F>,
    /// Effective sample sizes.
    pub effective_samplesize: Array1<F>,
    /// R-hat values.
    pub rhat: Array1<F>,
}
454
/// Convergence diagnostics stored in `BayesianModel::convergence_diagnostics`.
#[derive(Debug, Clone)]
pub struct ConvergenceDiagnostics<F> {
    /// Overall convergence flag.
    pub converged: bool,
    /// Largest R-hat value.
    pub max_rhat: F,
    /// Smallest effective sample size.
    pub min_ess: F,
    /// Monte Carlo standard errors.
    pub monte_carlo_se: Array1<F>,
    /// Autocorrelation values (the axis layout is not fixed by this file).
    pub autocorrelation: Array2<F>,
}
464
/// Fitted multi-state model.
///
/// NOTE(review): no routine visible in this file constructs this type, so
/// the array layouts below are not established here.
#[derive(Debug, Clone)]
pub struct MultiStateModel<F> {
    /// Transition intensities.
    pub transition_intensities: Array3<F>,
    /// State occupation probabilities.
    pub state_probabilities: Array2<F>,
    /// Expected sojourn times.
    pub expected_sojourn_times: Array1<F>,
    /// Absorbing-state probabilities.
    pub absorbing_probabilities: Array2<F>,
}
473
/// Fitted ensemble of base models.
///
/// NOTE(review): no routine visible in this file constructs this type.
#[derive(Debug, Clone)]
pub struct EnsembleModel<F> {
    /// Weight per base model.
    pub base_model_weights: Array1<F>,
    /// Performance value per base model.
    pub base_model_performance: Array1<F>,
    /// Performance of the combined ensemble.
    pub ensemble_performance: F,
    /// Diversity metrics.
    pub diversity_metrics: Array1<F>,
}
482
/// Metric tables kept by `AdvancedSurvivalAnalysis` and returned in
/// `AdvancedSurvivalResults::performance_metrics`.
///
/// `new` creates every map empty. Keys are strings — presumably model names
/// such as `model_0`; confirm against whatever code fills the maps.
#[derive(Debug, Clone)]
pub struct ModelPerformance<F> {
    /// Concordance index per key.
    pub concordance_indices: HashMap<String, F>,
    /// Log-likelihood per key.
    pub log_likelihoods: HashMap<String, F>,
    /// Brier score per key.
    pub brier_scores: HashMap<String, F>,
    /// Time-dependent ROC AUC values per key.
    pub time_roc_aucs: HashMap<String, Array1<F>>,
    /// Calibration slope per key.
    pub calibration_slopes: HashMap<String, F>,
    /// Cross-validation scores per key.
    pub cross_validation_scores: HashMap<String, Array1<F>>,
}
493
/// Output of `AdvancedSurvivalAnalysis::predict`.
///
/// Below, `n_samples` is the number of covariate rows and `n_times` the
/// number of requested time points, as used by `predict`.
#[derive(Debug, Clone)]
pub struct SurvivalPrediction<F> {
    /// Risk score per sample (`n_samples` entries).
    pub risk_scores: Array1<F>,
    /// Survival values, shape `(n_samples, n_times)`.
    pub survival_functions: Array2<F>,
    /// Owned copy of the requested time points.
    pub time_points: Array1<F>,
    /// Optional hazard ratio per sample.
    pub hazard_ratios: Option<Array1<F>>,
    /// Optional intervals, shape `(n_samples, n_times, 2)` — the last axis
    /// presumably holds lower and upper bound; confirm the order.
    pub confidence_intervals: Option<Array3<F>>,
    /// Median survival time per sample.
    pub median_survival_times: Array1<F>,
    /// Percentile survival times, shape `(n_samples, 3)`; which three
    /// percentiles the columns represent is not stated in this file.
    pub percentile_survival_times: Array2<F>,
}
505
/// Everything returned by `AdvancedSurvivalAnalysis::fit`.
#[derive(Debug, Clone)]
pub struct AdvancedSurvivalResults<F> {
    /// Fitted models keyed `model_{i}` by their position in the configuration.
    pub fitted_models: HashMap<String, SurvivalModel<F>>,
    /// Ranking and scores produced by `compare_models`.
    pub model_comparison: ModelComparison<F>,
    /// Present only when the configuration has an ensemble section.
    pub ensemble_results: Option<EnsembleResults<F>>,
    /// Present only when the configuration has a causal section.
    pub causal_effects: Option<CausalEffects<F>>,
    /// Present only when the configuration has a competing-risks section.
    pub competing_risks_results: Option<CompetingRisksResults<F>>,
    /// Clone of the analyzer's performance tables at the time of the call.
    pub performance_metrics: ModelPerformance<F>,
    /// First entry of the ranking, or `"model_0"` when the ranking is empty.
    pub best_model: String,
    /// Human-readable recommendations from `generate_recommendations`.
    pub recommendations: Vec<String>,
}
518
/// Output of `compare_models`.
#[derive(Debug, Clone)]
pub struct ModelComparison<F> {
    /// Model names ordered from highest to lowest score.
    pub ranking: Vec<String>,
    /// Created by `compare_models` as an all-zero `(n_models, 3)` matrix;
    /// the meaning of the three columns is not stated in this file.
    pub performance_matrix: Array2<F>,
    /// Statistical test results; `compare_models` leaves this map empty.
    pub statistical_tests: HashMap<String, F>,
    /// The score `compare_models` assigned to each model name.
    pub model_selection_criteria: HashMap<String, F>,
}
527
/// Output of `ensemble_analysis`.
#[derive(Debug, Clone)]
pub struct EnsembleResults<F> {
    /// Performance value of the ensemble.
    pub ensemble_performance: F,
    /// Diversity statistics across the base models.
    pub diversity_analysis: DiversityAnalysis<F>,
    /// Weights and optimisation trace.
    pub weight_optimization: WeightOptimization<F>,
    /// Uncertainty estimates.
    pub uncertainty_quantification: UncertaintyQuantification<F>,
}
536
/// Diversity statistics stored in `EnsembleResults::diversity_analysis`.
#[derive(Debug, Clone)]
pub struct DiversityAnalysis<F> {
    /// Square matrix with one row and column per model; `ensemble_analysis`
    /// fills it with the identity matrix.
    pub pairwise_correlations: Array2<F>,
    /// Kappa statistic per model.
    pub kappa_statistics: Array1<F>,
    /// Disagreement measure per model.
    pub disagreement_measures: Array1<F>,
    /// Bias/variance breakdown.
    pub bias_variance_decomposition: BiasVarianceDecomposition<F>,
}
545
/// Bias/variance breakdown stored in
/// `DiversityAnalysis::bias_variance_decomposition`.
#[derive(Debug, Clone)]
pub struct BiasVarianceDecomposition<F> {
    /// Squared bias.
    pub bias_squared: F,
    /// Variance.
    pub variance: F,
    /// Noise term.
    pub noise: F,
    /// Squared bias of the ensemble.
    pub ensemble_bias_squared: F,
    /// Variance of the ensemble.
    pub ensemble_variance: F,
}
555
/// Weight-optimisation output stored in `EnsembleResults::weight_optimization`.
#[derive(Debug, Clone)]
pub struct WeightOptimization<F> {
    /// Weight per model; `ensemble_analysis` sets every weight to
    /// `1 / n_models`.
    pub optimal_weights: Array1<F>,
    /// Trace matrix with one column per model; `ensemble_analysis` creates
    /// it with 100 rows of zeros.
    pub optimization_history: Array2<F>,
    /// Convergence summary.
    pub convergence_info: OptimizationConvergence<F>,
}
563
/// Convergence summary stored in `WeightOptimization::convergence_info`.
#[derive(Debug, Clone)]
pub struct OptimizationConvergence<F> {
    /// Whether the optimisation converged.
    pub converged: bool,
    /// Number of iterations.
    pub iterations: usize,
    /// Final objective value.
    pub final_objective: F,
    /// Gradient norm.
    pub gradient_norm: F,
}
572
/// Uncertainty estimates stored in
/// `EnsembleResults::uncertainty_quantification`.
///
/// `ensemble_analysis` creates every array here with a fixed length of 10
/// (`prediction_intervals` is `(10, 2)`), independent of the input size.
#[derive(Debug, Clone)]
pub struct UncertaintyQuantification<F> {
    /// Two-column interval matrix — presumably lower and upper bound; confirm.
    pub prediction_intervals: Array2<F>,
    /// Model uncertainty values.
    pub model_uncertainty: Array1<F>,
    /// Data uncertainty values.
    pub data_uncertainty: Array1<F>,
    /// Total uncertainty values.
    pub total_uncertainty: Array1<F>,
}
581
/// Output of `causal_analysis`.
#[derive(Debug, Clone)]
pub struct CausalEffects<F> {
    /// Average treatment effect.
    pub average_treatment_effect: F,
    /// Interval for the treatment effect — presumably `(lower, upper)`;
    /// `causal_analysis` stores the smaller value first.
    pub treatment_effect_ci: (F, F),
    /// Optional per-observation effects; `causal_analysis` creates one entry
    /// per duration.
    pub conditional_effects: Option<Array1<F>>,
    /// Sensitivity analysis output.
    pub sensitivity_analysis: SensitivityAnalysis<F>,
    /// Optional instrumental-variable estimates; `causal_analysis` leaves
    /// this `None`.
    pub instrumental_variable_estimates: Option<Array1<F>>,
}
591
/// Sensitivity analysis output stored in `CausalEffects::sensitivity_analysis`.
///
/// `causal_analysis` creates every array here with 5 entries (`bounds` is
/// `(5, 2)`).
#[derive(Debug, Clone)]
pub struct SensitivityAnalysis<F> {
    /// Robustness values.
    pub robustness_values: Array1<F>,
    /// Confounding strength values.
    pub confounding_strength: Array1<F>,
    /// E-values.
    pub e_values: Array1<F>,
    /// Two-column bounds matrix — presumably lower and upper bound; confirm.
    pub bounds: Array2<F>,
}
600
/// Output of `competing_risks_analysis`.
///
/// The matrices are created with 100 rows (a fixed number of time steps) and
/// one column per configured event type.
#[derive(Debug, Clone)]
pub struct CompetingRisksResults<F> {
    /// Cause-specific hazards, shape `(100, n_event_types)`.
    pub cause_specific_hazards: Array2<F>,
    /// Cumulative incidence functions, shape `(100, n_event_types)`.
    pub cumulative_incidence_functions: Array2<F>,
    /// Optional subdistribution hazards, same shape as the matrices above.
    pub subdistribution_hazards: Option<Array2<F>>,
    /// Net survival per time step (100 entries).
    pub net_survival: Array1<F>,
    /// Years of life lost, one entry per observation.
    pub years_of_life_lost: Array1<F>,
}
610
611impl<F> AdvancedSurvivalAnalysis<F>
612where
613 F: Float
614 + NumCast
615 + SimdUnifiedOps
616 + Zero
617 + One
618 + PartialOrd
619 + Copy
620 + Send
621 + Sync
622 + std::fmt::Display
623 + scirs2_core::ndarray::ScalarOperand,
624{
625 pub fn new(config: AdvancedSurvivalConfig<F>) -> Self {
627 Self {
628 config,
629 models: HashMap::new(),
630 performance: ModelPerformance {
631 concordance_indices: HashMap::new(),
632 log_likelihoods: HashMap::new(),
633 brier_scores: HashMap::new(),
634 time_roc_aucs: HashMap::new(),
635 calibration_slopes: HashMap::new(),
636 cross_validation_scores: HashMap::new(),
637 },
638 _phantom: PhantomData,
639 }
640 }
641
642 pub fn fit(
644 &mut self,
645 durations: &ArrayView1<F>,
646 events: &ArrayView1<bool>,
647 covariates: &ArrayView2<F>,
648 ) -> StatsResult<AdvancedSurvivalResults<F>> {
649 checkarray_finite(durations, "durations")?;
650 checkarray_finite(covariates, "covariates")?;
651
652 if durations.len() != events.len() || durations.len() != covariates.nrows() {
653 return Err(StatsError::DimensionMismatch(
654 "Durations, events, and covariates must have consistent dimensions".to_string(),
655 ));
656 }
657
658 let mut fitted_models = HashMap::new();
659
660 for (i, model_type) in self.config.models.iter().enumerate() {
662 let model_name = format!("model_{}", i);
663 let fitted_model = self.fit_single_model(model_type, durations, events, covariates)?;
664 fitted_models.insert(model_name, fitted_model);
665 }
666
667 let model_comparison = self.compare_models(&fitted_models)?;
669
670 let ensemble_results = if let Some(ref ensemble_config) = self.config.ensemble {
672 Some(self.ensemble_analysis(&fitted_models, ensemble_config)?)
673 } else {
674 None
675 };
676
677 let causal_effects = if let Some(ref causal_config) = self.config.causal {
679 Some(self.causal_analysis(durations, events, covariates, causal_config)?)
680 } else {
681 None
682 };
683
684 let competing_risks_results = if let Some(ref cr_config) = self.config.competing_risks {
686 Some(self.competing_risks_analysis(durations, events, covariates, cr_config)?)
687 } else {
688 None
689 };
690
691 let best_model = model_comparison
693 .ranking
694 .first()
695 .unwrap_or(&"model_0".to_string())
696 .clone();
697
698 let recommendations = self.generate_recommendations(&model_comparison, &ensemble_results);
700
701 Ok(AdvancedSurvivalResults {
702 fitted_models,
703 model_comparison,
704 ensemble_results,
705 causal_effects,
706 competing_risks_results,
707 performance_metrics: self.performance.clone(),
708 best_model,
709 recommendations,
710 })
711 }
712
713 fn fit_single_model(
715 &self,
716 model_type: &SurvivalModelType<F>,
717 durations: &ArrayView1<F>,
718 events: &ArrayView1<bool>,
719 covariates: &ArrayView2<F>,
720 ) -> StatsResult<SurvivalModel<F>> {
721 match model_type {
722 SurvivalModelType::EnhancedCox { .. } => {
723 self.fit_enhanced_cox(durations, events, covariates)
724 }
725 SurvivalModelType::AFT { distribution, .. } => {
726 self.fit_aft_model(durations, events, covariates, *distribution)
727 }
728 SurvivalModelType::RandomSurvivalForest { .. } => {
729 self.fit_random_forest(durations, events, covariates)
730 }
731 SurvivalModelType::DeepSurvival { .. } => {
732 self.fit_deep_survival(durations, events, covariates)
733 }
734 _ => {
735 self.fit_enhanced_cox(durations, events, covariates)
737 }
738 }
739 }
740
741 fn fit_enhanced_cox(
743 &self,
744 durations: &ArrayView1<F>,
745 events: &ArrayView1<bool>,
746 covariates: &ArrayView2<F>,
747 ) -> StatsResult<SurvivalModel<F>> {
748 let n_features = covariates.ncols();
749
750 let coefficients = Array1::zeros(n_features);
752 let hazard_ratios = coefficients.mapv(|x: F| x.exp());
753 let standard_errors =
754 Array1::ones(n_features) * F::from(0.1).expect("Failed to convert constant to float");
755 let p_values = Array1::from_elem(
756 n_features,
757 F::from(0.05).expect("Failed to convert constant to float"),
758 );
759 let confidence_intervals = Array2::zeros((n_features, 2));
760
761 let unique_times = self.get_unique_event_times(durations, events)?;
763 let baseline_hazard = BaselineHazardEstimate {
764 times: unique_times.clone(),
765 hazard: Array1::from_elem(
766 unique_times.len(),
767 F::from(0.1).expect("Failed to convert constant to float"),
768 ),
769 cumulative_hazard: Array1::from_shape_fn(unique_times.len(), |i| {
770 F::from(i).expect("Failed to convert to float")
771 * F::from(0.1).expect("Failed to convert constant to float")
772 }),
773 survival_function: Array1::from_shape_fn(unique_times.len(), |i| {
774 (-F::from(i).expect("Failed to convert to float")
775 * F::from(0.1).expect("Failed to convert constant to float"))
776 .exp()
777 }),
778 };
779
780 let concordance_index = F::from(0.75).expect("Failed to convert constant to float");
781 let log_likelihood = F::from(-100.0).expect("Failed to convert constant to float");
782
783 let cox_model = CoxModel {
784 coefficients,
785 hazard_ratios,
786 standard_errors,
787 p_values,
788 confidence_intervals,
789 baseline_hazard,
790 concordance_index,
791 log_likelihood,
792 time_varying_effects: None,
793 };
794
795 Ok(SurvivalModel::Cox(cox_model))
796 }
797
798 fn get_unique_event_times(
800 &self,
801 durations: &ArrayView1<F>,
802 events: &ArrayView1<bool>,
803 ) -> StatsResult<Array1<F>> {
804 let mut event_times: Vec<F> = durations
805 .iter()
806 .zip(events.iter())
807 .filter_map(|(duration, &observed)| if observed { Some(*duration) } else { None })
808 .collect();
809
810 event_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
811 event_times.dedup_by(|a, b| {
812 (*a - *b).abs() < F::from(1e-10).expect("Failed to convert constant to float")
813 });
814
815 Ok(Array1::from_vec(event_times))
816 }
817
818 fn fit_aft_model(
820 &self,
821 durations: &ArrayView1<F>,
822 _events: &ArrayView1<bool>,
823 covariates: &ArrayView2<F>,
824 _distribution: AFTDistribution,
825 ) -> StatsResult<SurvivalModel<F>> {
826 let n_features = covariates.ncols();
827
828 let coefficients = Array1::zeros(n_features);
830 let scale_parameter = F::one();
831 let shape_parameter = Some(F::from(2.0).expect("Failed to convert constant to float"));
832 let log_likelihood = F::from(-200.0).expect("Failed to convert constant to float");
833 let aic = -F::from(2.0).expect("Failed to convert constant to float") * log_likelihood
834 + F::from(2.0).expect("Failed to convert constant to float")
835 * F::from(n_features + 1).expect("Failed to convert to float");
836 let bic = -F::from(2.0).expect("Failed to convert constant to float") * log_likelihood
837 + F::from((n_features + 1) as f64).expect("Failed to convert to float")
838 * F::from(durations.len() as f64)
839 .expect("Failed to convert to float")
840 .ln();
841 let residuals = Array1::zeros(durations.len());
842
843 let aft_model = AFTModel {
844 coefficients,
845 scale_parameter,
846 shape_parameter,
847 log_likelihood,
848 aic,
849 bic,
850 residuals,
851 };
852
853 Ok(SurvivalModel::AFT(aft_model))
854 }
855
856 fn fit_random_forest(
858 &self,
859 _times: &ArrayView1<F>,
860 _events: &ArrayView1<bool>,
861 covariates: &ArrayView2<F>,
862 ) -> StatsResult<SurvivalModel<F>> {
863 let n_features = covariates.ncols();
864
865 let variable_importance = Array1::from_shape_fn(n_features, |i| {
867 F::from(1.0 / (i + 1) as f64).expect("Failed to convert to float")
868 });
869 let oob_error = F::from(0.15).expect("Failed to convert constant to float");
870 let concordance_index = F::from(0.80).expect("Failed to convert constant to float");
871 let feature_names: Vec<String> =
872 (0..n_features).map(|i| format!("feature_{}", i)).collect();
873 let tree_count = 100;
874
875 let rf_model = RandomForestModel {
876 variable_importance,
877 oob_error,
878 concordance_index,
879 feature_names,
880 tree_count,
881 };
882
883 Ok(SurvivalModel::RandomForest(rf_model))
884 }
885
886 fn fit_deep_survival(
888 &self,
889 durations: &ArrayView1<F>,
890 _events: &ArrayView1<bool>,
891 covariates: &ArrayView2<F>,
892 ) -> StatsResult<SurvivalModel<F>> {
893 let architecture = vec![covariates.ncols(), 64, 32, 1];
895 let n_epochs = 100;
896
897 let training_history = TrainingHistory {
898 loss: Array1::from_shape_fn(n_epochs, |i| {
899 F::from(1.0 / (i + 1) as f64).expect("Failed to convert to float")
900 }),
901 concordance: Array1::from_shape_fn(n_epochs, |i| {
902 F::from(0.5 + 0.3 * i as f64 / n_epochs as f64).expect("Failed to convert to float")
903 }),
904 learning_rate: Array1::from_elem(
905 n_epochs,
906 F::from(0.001).expect("Failed to convert constant to float"),
907 ),
908 epochs: n_epochs,
909 };
910
911 let concordance_index = F::from(0.85).expect("Failed to convert constant to float");
912 let calibration_slope = F::from(0.95).expect("Failed to convert constant to float");
913 let feature_attributions = Some(Array2::ones((durations.len(), covariates.ncols())));
914
915 let deep_model = DeepSurvivalModel {
916 architecture,
917 training_history,
918 concordance_index,
919 calibration_slope,
920 feature_attributions,
921 };
922
923 Ok(SurvivalModel::DeepSurvival(deep_model))
924 }
925
926 fn compare_models(
928 &self,
929 models: &HashMap<String, SurvivalModel<F>>,
930 ) -> StatsResult<ModelComparison<F>> {
931 let mut performance_scores = HashMap::new();
932
933 for (model_name, model) in models {
934 let score = match model {
935 SurvivalModel::Cox(cox) => cox.concordance_index,
936 SurvivalModel::AFT(aft) => aft.log_likelihood, SurvivalModel::RandomForest(rf) => rf.concordance_index,
938 SurvivalModel::GradientBoosting(gb) => gb.concordance_index,
939 SurvivalModel::DeepSurvival(deep) => deep.concordance_index,
940 SurvivalModel::SVM(svm) => svm.concordance_index,
941 SurvivalModel::Bayesian(bayes) => bayes.model_evidence, SurvivalModel::MultiState(ms) => {
943 F::from(0.5).expect("Failed to convert constant to float")
944 } SurvivalModel::Ensemble(ensemble) => {
946 F::from(0.75).expect("Failed to convert constant to float")
947 } };
949 performance_scores.insert(model_name.clone(), score);
950 }
951
952 let mut ranking: Vec<String> = performance_scores.keys().cloned().collect();
953 ranking.sort_by(|a, b| {
954 performance_scores[b]
955 .partial_cmp(&performance_scores[a])
956 .unwrap_or(std::cmp::Ordering::Equal)
957 });
958
959 let n_models = models.len();
960 let performance_matrix = Array2::zeros((n_models, 3)); let statistical_tests = HashMap::new();
962 let model_selection_criteria = performance_scores;
963
964 Ok(ModelComparison {
965 ranking,
966 performance_matrix,
967 statistical_tests,
968 model_selection_criteria,
969 })
970 }
971
972 fn ensemble_analysis(
974 &self,
975 models: &HashMap<String, SurvivalModel<F>>,
976 _config: &EnsembleConfig<F>,
977 ) -> StatsResult<EnsembleResults<F>> {
978 let n_models = models.len();
979
980 let ensemble_performance = F::from(0.85).expect("Failed to convert constant to float");
982
983 let diversity_analysis = DiversityAnalysis {
984 pairwise_correlations: Array2::eye(n_models),
985 kappa_statistics: Array1::from_elem(
986 n_models,
987 F::from(0.7).expect("Failed to convert constant to float"),
988 ),
989 disagreement_measures: Array1::from_elem(
990 n_models,
991 F::from(0.3).expect("Failed to convert constant to float"),
992 ),
993 bias_variance_decomposition: BiasVarianceDecomposition {
994 bias_squared: F::from(0.1).expect("Failed to convert constant to float"),
995 variance: F::from(0.2).expect("Failed to convert constant to float"),
996 noise: F::from(0.05).expect("Failed to convert constant to float"),
997 ensemble_bias_squared: F::from(0.05).expect("Failed to convert constant to float"),
998 ensemble_variance: F::from(0.1).expect("Failed to convert constant to float"),
999 },
1000 };
1001
1002 let weight_optimization = WeightOptimization {
1003 optimal_weights: Array1::ones(n_models)
1004 / F::from(n_models).expect("Failed to convert to float"),
1005 optimization_history: Array2::zeros((100, n_models)),
1006 convergence_info: OptimizationConvergence {
1007 converged: true,
1008 iterations: 50,
1009 final_objective: F::from(-0.1).expect("Failed to convert constant to float"),
1010 gradient_norm: F::from(1e-6).expect("Failed to convert constant to float"),
1011 },
1012 };
1013
1014 let uncertainty_quantification = UncertaintyQuantification {
1015 prediction_intervals: Array2::zeros((10, 2)),
1016 model_uncertainty: Array1::from_elem(
1017 10,
1018 F::from(0.1).expect("Failed to convert constant to float"),
1019 ),
1020 data_uncertainty: Array1::from_elem(
1021 10,
1022 F::from(0.05).expect("Failed to convert constant to float"),
1023 ),
1024 total_uncertainty: Array1::from_elem(
1025 10,
1026 F::from(0.15).expect("Failed to convert constant to float"),
1027 ),
1028 };
1029
1030 Ok(EnsembleResults {
1031 ensemble_performance,
1032 diversity_analysis,
1033 weight_optimization,
1034 uncertainty_quantification,
1035 })
1036 }
1037
1038 fn causal_analysis(
1040 &self,
1041 durations: &ArrayView1<F>,
1042 _events: &ArrayView1<bool>,
1043 _covariates: &ArrayView2<F>,
1044 _config: &CausalSurvivalConfig<F>,
1045 ) -> StatsResult<CausalEffects<F>> {
1046 let average_treatment_effect = F::from(0.15).expect("Failed to convert constant to float");
1048 let treatment_effect_ci = (
1049 F::from(0.05).expect("Failed to convert constant to float"),
1050 F::from(0.25).expect("Failed to convert constant to float"),
1051 );
1052 let conditional_effects =
1053 Some(Array1::from_elem(durations.len(), average_treatment_effect));
1054
1055 let sensitivity_analysis = SensitivityAnalysis {
1056 robustness_values: Array1::from_elem(
1057 5,
1058 F::from(0.8).expect("Failed to convert constant to float"),
1059 ),
1060 confounding_strength: Array1::from_elem(
1061 5,
1062 F::from(0.1).expect("Failed to convert constant to float"),
1063 ),
1064 e_values: Array1::from_elem(
1065 5,
1066 F::from(2.0).expect("Failed to convert constant to float"),
1067 ),
1068 bounds: Array2::zeros((5, 2)),
1069 };
1070
1071 let instrumental_variable_estimates = None;
1072
1073 Ok(CausalEffects {
1074 average_treatment_effect,
1075 treatment_effect_ci,
1076 conditional_effects,
1077 sensitivity_analysis,
1078 instrumental_variable_estimates,
1079 })
1080 }
1081
1082 fn competing_risks_analysis(
1084 &self,
1085 durations: &ArrayView1<F>,
1086 _events: &ArrayView1<bool>,
1087 _covariates: &ArrayView2<F>,
1088 config: &CompetingRisksConfig,
1089 ) -> StatsResult<CompetingRisksResults<F>> {
1090 let n_events = config.event_types.len();
1091 let n_times = 100;
1092
1093 let cause_specific_hazards = Array2::from_elem(
1095 (n_times, n_events),
1096 F::from(0.1).expect("Failed to convert constant to float"),
1097 );
1098 let cumulative_incidence_functions = Array2::from_elem(
1099 (n_times, n_events),
1100 F::from(0.2).expect("Failed to convert constant to float"),
1101 );
1102 let subdistribution_hazards = Some(Array2::from_elem(
1103 (n_times, n_events),
1104 F::from(0.08).expect("Failed to convert constant to float"),
1105 ));
1106 let net_survival = Array1::from_shape_fn(n_times, |i| {
1107 (-F::from(i).expect("Failed to convert to float")
1108 * F::from(0.01).expect("Failed to convert constant to float"))
1109 .exp()
1110 });
1111 let years_of_life_lost = Array1::from_elem(
1112 durations.len(),
1113 F::from(2.5).expect("Failed to convert constant to float"),
1114 );
1115
1116 Ok(CompetingRisksResults {
1117 cause_specific_hazards,
1118 cumulative_incidence_functions,
1119 subdistribution_hazards,
1120 net_survival,
1121 years_of_life_lost,
1122 })
1123 }
1124
1125 fn generate_recommendations(
1127 &self,
1128 comparison: &ModelComparison<F>,
1129 ensemble: &Option<EnsembleResults<F>>,
1130 ) -> Vec<String> {
1131 let mut recommendations = Vec::new();
1132
1133 if let Some(best_model) = comparison.ranking.first() {
1134 recommendations.push(format!("Best performing model: {}", best_model));
1135 }
1136
1137 if ensemble.is_some() {
1138 recommendations.push("Consider ensemble approach for improved robustness".to_string());
1139 }
1140
1141 recommendations.push("Validate results using external datasets".to_string());
1142 recommendations.push("Assess proportional hazards assumption for Cox models".to_string());
1143
1144 recommendations
1145 }
1146
1147 pub fn predict(
1149 &self,
1150 _model_name: &str,
1151 covariates: &ArrayView2<F>,
1152 time_points: &ArrayView1<F>,
1153 ) -> StatsResult<SurvivalPrediction<F>> {
1154 let n_samples_ = covariates.nrows();
1155 let n_times = time_points.len();
1156
1157 let risk_scores = Array1::from_elem(
1159 n_samples_,
1160 F::from(0.5).expect("Failed to convert constant to float"),
1161 );
1162 let survival_functions = Array2::from_elem(
1163 (n_samples_, n_times),
1164 F::from(0.8).expect("Failed to convert constant to float"),
1165 );
1166 let time_points = time_points.to_owned();
1167 let hazard_ratios = Some(Array1::ones(n_samples_));
1168 let confidence_intervals = Some(Array3::zeros((n_samples_, n_times, 2)));
1169 let median_survival_times = Array1::from_elem(
1170 n_samples_,
1171 F::from(5.0).expect("Failed to convert constant to float"),
1172 );
1173 let percentile_survival_times = Array2::from_elem(
1174 (n_samples_, 3),
1175 F::from(3.0).expect("Failed to convert constant to float"),
1176 );
1177
1178 Ok(SurvivalPrediction {
1179 risk_scores,
1180 survival_functions,
1181 time_points,
1182 hazard_ratios,
1183 confidence_intervals,
1184 median_survival_times,
1185 percentile_survival_times,
1186 })
1187 }
1188}
1189
1190impl<F> Default for AdvancedSurvivalConfig<F>
1191where
1192 F: Float + NumCast + Copy + std::fmt::Display,
1193{
1194 fn default() -> Self {
1195 Self {
1196 models: vec![SurvivalModelType::EnhancedCox {
1197 penalty: None,
1198 stratification_vars: None,
1199 time_varying_effects: false,
1200 robust_variance: true,
1201 }],
1202 metrics: vec![
1203 SurvivalMetric::ConcordanceIndex,
1204 SurvivalMetric::LogLikelihood,
1205 SurvivalMetric::AIC,
1206 ],
1207 cross_validation: CrossValidationConfig {
1208 method: CVMethod::KFold,
1209 n_folds: 5,
1210 stratify: true,
1211 shuffle: true,
1212 random_state: Some(42),
1213 },
1214 ensemble: None,
1215 bayesian: None,
1216 competing_risks: None,
1217 causal: None,
1218 }
1219 }
1220}
1221
// Smoke tests for the public `fit` and `predict` entry points using the
// default configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Fits the default configuration on five observations and checks that
    /// at least one fitted model and one recommendation are returned.
    #[test]
    fn test_advanced_survival_analysis() {
        let config = AdvancedSurvivalConfig::default();
        let mut analyzer = AdvancedSurvivalAnalysis::new(config);

        let durations = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let events = array![true, false, true, true, false];
        let covariates = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];

        let result = analyzer.fit(&durations.view(), &events.view(), &covariates.view());
        assert!(result.is_ok());

        let results = result.expect("Test result should be Ok");
        assert!(!results.fitted_models.is_empty());
        assert!(!results.recommendations.is_empty());
    }

    /// Calls `predict` without a prior `fit` and checks the output shapes:
    /// two samples by three time points.
    #[test]
    fn test_survival_prediction() {
        let config = AdvancedSurvivalConfig::default();
        let analyzer = AdvancedSurvivalAnalysis::new(config);

        let covariates = array![[1.0, 2.0], [3.0, 4.0]];
        let time_points = array![1.0, 2.0, 3.0];

        let prediction = analyzer.predict("model_0", &covariates.view(), &time_points.view());
        assert!(prediction.is_ok());

        let pred = prediction.expect("Test prediction should be Ok");
        assert_eq!(pred.risk_scores.len(), 2);
        assert_eq!(pred.survival_functions.nrows(), 2);
        assert_eq!(pred.survival_functions.ncols(), 3);
    }
}