// scirs2_stats/survival_enhanced.rs
1//! Enhanced Survival Analysis
2//!
3//! This module provides comprehensive survival analysis methods including:
4//! - Enhanced Kaplan-Meier estimator with confidence intervals
5//! - Cox Proportional Hazards regression
6//! - Log-rank test for comparing survival curves
7//! - Parametric survival models (Weibull, Exponential)
8//! - Competing risks analysis
9
10#![allow(dead_code)]
11
12use crate::error::{StatsError, StatsResult};
13use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
14use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
15use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
16use std::marker::PhantomData;
17
/// Enhanced Kaplan-Meier estimator
///
/// Product-limit estimate of the survival function with Greenwood-based
/// confidence intervals and summary statistics. The array fields below are
/// parallel: entry `i` describes the i-th distinct event time (times at
/// which only censoring occurred are not recorded, since the curve does not
/// step there).
#[derive(Debug, Clone)]
pub struct EnhancedKaplanMeier<F> {
    /// Distinct times at which at least one event occurred, in ascending order
    pub event_times: Array1<F>,
    /// Estimated survival probability S(t) just after each event time
    pub survival_function: Array1<F>,
    /// Pointwise (lower, upper) confidence bounds for `survival_function`,
    /// computed with Greenwood's variance formula on the log scale
    pub confidence_intervals: Option<(Array1<F>, Array1<F>)>,
    /// Number of subjects still at risk just before each event time
    pub at_risk: Array1<usize>,
    /// Number of events observed at each event time
    pub events: Array1<usize>,
    /// Smallest event time with S(t) <= 0.5; `None` if the curve never drops that far
    pub median_survival_time: Option<F>,
    /// Restricted mean survival time (area under the step curve up to the last event time)
    pub mean_survival_time: Option<F>,
    /// Confidence level requested at fit time (defaults to 0.95)
    pub confidence_level: F,
}
38
39impl<F> EnhancedKaplanMeier<F>
40where
41    F: Float
42        + Zero
43        + One
44        + Copy
45        + Send
46        + Sync
47        + SimdUnifiedOps
48        + FromPrimitive
49        + PartialOrd
50        + std::fmt::Display,
51{
52    /// Fit enhanced Kaplan-Meier estimator
53    pub fn fit(
54        durations: &ArrayView1<F>,
55        event_observed: &ArrayView1<bool>,
56        confidence_level: Option<F>,
57    ) -> StatsResult<Self> {
58        checkarray_finite(durations, "durations")?;
59
60        if durations.len() != event_observed.len() {
61            return Err(StatsError::DimensionMismatch(format!(
62                "Durations length ({}) must match event_observed length ({})",
63                durations.len(),
64                event_observed.len()
65            )));
66        }
67
68        let confidence_level = confidence_level
69            .unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
70
71        // Sort data by duration
72        let mut data: Vec<(F, bool, usize)> = durations
73            .iter()
74            .zip(event_observed.iter())
75            .enumerate()
76            .map(|(i, (&duration, &observed))| (duration, observed, i))
77            .collect();
78
79        data.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
80
81        // Compute Kaplan-Meier estimate
82        let n = data.len();
83        let mut survival_times = Vec::new();
84        let mut survival_probs = Vec::new();
85        let mut at_risk_counts = Vec::new();
86        let mut event_counts = Vec::new();
87
88        let mut current_survival = F::one();
89        let mut current_at_risk = n;
90        let mut i = 0;
91
92        while i < n {
93            let current_time = data[i].0;
94            let mut events_at_time = 0;
95            let mut censored_at_time = 0;
96
97            // Count events and censoring at current time
98            while i < n && data[i].0 == current_time {
99                if data[i].1 {
100                    // Event _observed
101                    events_at_time += 1;
102                } else {
103                    // Censored
104                    censored_at_time += 1;
105                }
106                i += 1;
107            }
108
109            // Update survival probability only if there are events
110            if events_at_time > 0 {
111                let survival_multiplier = F::one()
112                    - F::from(events_at_time).expect("Failed to convert to float")
113                        / F::from(current_at_risk).expect("Failed to convert to float");
114                current_survival = current_survival * survival_multiplier;
115
116                survival_times.push(current_time);
117                survival_probs.push(current_survival);
118                at_risk_counts.push(current_at_risk);
119                event_counts.push(events_at_time);
120            }
121
122            // Update at-risk count
123            current_at_risk -= events_at_time + censored_at_time;
124        }
125
126        let event_times = Array1::from_vec(survival_times);
127        let survival_function = Array1::from_vec(survival_probs);
128        let at_risk = Array1::from_vec(at_risk_counts);
129        let events = Array1::from_vec(event_counts);
130
131        // Compute confidence intervals using Greenwood's formula
132        let confidence_intervals = Self::compute_confidence_intervals(
133            &event_times,
134            &survival_function,
135            &at_risk,
136            &events,
137        )?;
138
139        // Compute median and mean survival times
140        let median_survival_time = Self::compute_median_survival(&event_times, &survival_function);
141        let mean_survival_time = Self::compute_mean_survival(&event_times, &survival_function);
142
143        Ok(Self {
144            event_times,
145            survival_function,
146            confidence_intervals: Some(confidence_intervals),
147            at_risk,
148            events,
149            median_survival_time,
150            mean_survival_time,
151            confidence_level,
152        })
153    }
154
155    /// Compute confidence intervals using Greenwood's formula
156    fn compute_confidence_intervals(
157        times: &Array1<F>,
158        survival: &Array1<F>,
159        at_risk: &Array1<usize>,
160        events: &Array1<usize>,
161    ) -> StatsResult<(Array1<F>, Array1<F>)> {
162        let n = times.len();
163        let mut lower = Array1::zeros(n);
164        let mut upper = Array1::zeros(n);
165
166        // Z-score for 95% confidence (approximately 1.96)
167        let z = F::from(1.96).expect("Failed to convert constant to float");
168
169        let mut cumulative_variance = F::zero();
170
171        for i in 0..n {
172            let events_i = F::from(events[i]).expect("Failed to convert to float");
173            let at_risk_i = F::from(at_risk[i]).expect("Failed to convert to float");
174
175            // Greenwood's variance formula
176            if at_risk[i] > events[i] {
177                let variance_term = events_i / (at_risk_i * (at_risk_i - events_i));
178                cumulative_variance = cumulative_variance + variance_term;
179            }
180
181            // Standard error
182            let se = survival[i] * cumulative_variance.sqrt();
183
184            // Confidence interval (with log transformation for better properties)
185            if survival[i] > F::zero() {
186                let log_survival = survival[i].ln();
187                let se_log = se / survival[i];
188
189                let log_lower = log_survival - z * se_log;
190                let log_upper = log_survival + z * se_log;
191
192                lower[i] = log_lower.exp().max(F::zero()).min(F::one());
193                upper[i] = log_upper.exp().max(F::zero()).min(F::one());
194            } else {
195                lower[i] = F::zero();
196                upper[i] = F::zero();
197            }
198        }
199
200        Ok((lower, upper))
201    }
202
203    /// Compute median survival time
204    fn compute_median_survival(times: &Array1<F>, survival: &Array1<F>) -> Option<F> {
205        let median_threshold = F::from(0.5).expect("Failed to convert constant to float");
206
207        for i in 0..survival.len() {
208            if survival[i] <= median_threshold {
209                return Some(times[i]);
210            }
211        }
212
213        None // Median not reached
214    }
215
216    /// Compute mean survival time (area under the curve)
217    fn compute_mean_survival(times: &Array1<F>, survival: &Array1<F>) -> Option<F> {
218        if times.is_empty() {
219            return None;
220        }
221
222        let mut area = F::zero();
223        let mut prev_time = F::zero();
224        let mut prev_survival = F::one();
225
226        for i in 0..times.len() {
227            let time_diff = times[i] - prev_time;
228            area = area + prev_survival * time_diff;
229
230            prev_time = times[i];
231            prev_survival = survival[i];
232        }
233
234        Some(area)
235    }
236
237    /// Evaluate survival function at given times
238    pub fn survival_function_at(&self, times: &ArrayView1<F>) -> StatsResult<Array1<F>> {
239        let mut result = Array1::ones(times.len());
240
241        for (i, &time) in times.iter().enumerate() {
242            // Find the last event time <= time
243            let mut survival_prob = F::one();
244
245            for j in 0..self.event_times.len() {
246                if self.event_times[j] <= time {
247                    survival_prob = self.survival_function[j];
248                } else {
249                    break;
250                }
251            }
252
253            result[i] = survival_prob;
254        }
255
256        Ok(result)
257    }
258}
259
/// Cox Proportional Hazards Model
///
/// Semi-parametric regression for the hazard h(t | x) = h0(t) * exp(x . beta),
/// fitted by Newton-Raphson maximization of the partial likelihood. Result
/// fields are `None` until `fit` has completed successfully.
pub struct CoxProportionalHazards<F> {
    /// Regression coefficients (log hazard ratios), populated by `fit`
    pub coefficients: Option<Array1<F>>,
    /// Standard errors of the coefficients, from the inverse information matrix
    pub standard_errors: Option<Array1<F>>,
    /// Baseline hazard
    /// NOTE(review): never populated anywhere in this file — confirm intended
    pub baseline_hazard: Option<Array1<F>>,
    /// Solver configuration (iterations, tolerance, step size)
    pub config: CoxConfig,
    /// Convergence diagnostics from the most recent `fit`
    pub convergence_info: Option<CoxConvergenceInfo>,
    // PhantomData is redundant (F already appears in the Option fields) but
    // kept so construction sites remain unchanged.
    _phantom: PhantomData<F>,
}
274
/// Cox regression configuration
#[derive(Debug, Clone)]
pub struct CoxConfig {
    /// Maximum number of Newton-Raphson iterations before giving up
    pub max_iter: usize,
    /// Absolute change in log-likelihood below which the fit is declared converged
    pub tolerance: f64,
    /// Multiplier applied to each Newton step (1.0 = full step)
    pub stepsize: f64,
    /// Enable parallel processing
    /// NOTE(review): not consulted anywhere in this file — confirm intended
    pub parallel: bool,
}
287
/// Cox model convergence information
#[derive(Debug, Clone)]
pub struct CoxConvergenceInfo {
    /// Number of Newton-Raphson iterations
    pub n_iter: usize,
    /// Partial log-likelihood at the final coefficients
    pub log_likelihood: f64,
    /// `true` if the log-likelihood change fell below the tolerance
    pub converged: bool,
}
298
299impl Default for CoxConfig {
300    fn default() -> Self {
301        Self {
302            max_iter: 100,
303            tolerance: 1e-6,
304            stepsize: 1.0,
305            parallel: true,
306        }
307    }
308}
309
310impl<F> CoxProportionalHazards<F>
311where
312    F: Float
313        + Zero
314        + One
315        + Copy
316        + Send
317        + Sync
318        + SimdUnifiedOps
319        + FromPrimitive
320        + std::fmt::Display
321        + 'static,
322{
323    /// Create new Cox model
324    pub fn new(config: CoxConfig) -> Self {
325        Self {
326            coefficients: None,
327            standard_errors: None,
328            baseline_hazard: None,
329            config,
330            convergence_info: None,
331            _phantom: PhantomData,
332        }
333    }
334
335    /// Fit Cox proportional hazards model
336    pub fn fit(
337        &mut self,
338        durations: &ArrayView1<F>,
339        event_observed: &ArrayView1<bool>,
340        covariates: &ArrayView2<F>,
341    ) -> StatsResult<()> {
342        checkarray_finite(durations, "durations")?;
343        checkarray_finite(covariates, "covariates")?;
344
345        let n = durations.len();
346        let p = covariates.ncols();
347
348        if n != event_observed.len() || n != covariates.nrows() {
349            return Err(StatsError::DimensionMismatch(
350                "All input arrays must have the same number of observations".to_string(),
351            ));
352        }
353
354        // Initialize coefficients
355        let mut beta = Array1::zeros(p);
356
357        // Convert to f64 for numerical computation
358        let durations_f64 = durations.mapv(|x| x.to_f64().expect("Operation failed"));
359        let covariates_f64 = covariates.mapv(|x| x.to_f64().expect("Operation failed"));
360
361        // Newton-Raphson iteration
362        let mut converged = false;
363        let mut log_likelihood = f64::NEG_INFINITY;
364
365        for _iter in 0..self.config.max_iter {
366            // Compute partial likelihood and its derivatives
367            let (ll, gradient, hessian) = self.compute_partial_likelihood_derivatives(
368                &durations_f64,
369                event_observed,
370                &covariates_f64,
371                &beta,
372            )?;
373
374            // Check convergence
375            if (ll - log_likelihood).abs() < self.config.tolerance {
376                converged = true;
377                break;
378            }
379
380            log_likelihood = ll;
381
382            // Newton-Raphson update
383            let hessian_inv = scirs2_linalg::inv(&hessian.view(), None).map_err(|e| {
384                StatsError::ComputationError(format!("Hessian inversion failed: {e}"))
385            })?;
386
387            let update = hessian_inv.dot(&gradient);
388            beta = &beta + &update.mapv(|x| x * self.config.stepsize);
389        }
390
391        // Compute standard errors from Hessian
392        let (_, _, hessian) = self.compute_partial_likelihood_derivatives(
393            &durations_f64,
394            event_observed,
395            &covariates_f64,
396            &beta,
397        )?;
398
399        let cov_matrix = scirs2_linalg::inv(&(-hessian).view(), None).map_err(|e| {
400            StatsError::ComputationError(format!("Covariance matrix computation failed: {e}"))
401        })?;
402
403        let standard_errors = cov_matrix.diag().mapv(|x| x.sqrt());
404
405        // Convert back to F type
406        self.coefficients = Some(beta.mapv(|x| F::from(x).expect("Failed to convert to float")));
407        self.standard_errors =
408            Some(standard_errors.mapv(|x| F::from(x).expect("Failed to convert to float")));
409
410        self.convergence_info = Some(CoxConvergenceInfo {
411            n_iter: self.config.max_iter,
412            log_likelihood,
413            converged,
414        });
415
416        Ok(())
417    }
418
419    /// Compute partial likelihood and derivatives
420    fn compute_partial_likelihood_derivatives(
421        &self,
422        durations: &Array1<f64>,
423        event_observed: &ArrayView1<bool>,
424        covariates: &Array2<f64>,
425        beta: &Array1<f64>,
426    ) -> StatsResult<(f64, Array1<f64>, Array2<f64>)> {
427        let n = durations.len();
428        let p = beta.len();
429
430        // Sort by duration (descending for risk sets)
431        let mut indices: Vec<usize> = (0..n).collect();
432        indices.sort_by(|&i, &j| {
433            durations[j]
434                .partial_cmp(&durations[i])
435                .expect("Operation failed")
436        });
437
438        let mut log_likelihood = 0.0;
439        let mut gradient = Array1::zeros(p);
440        let mut hessian = Array2::zeros((p, p));
441
442        // Compute linear predictors
443        let linear_pred = covariates.dot(beta);
444        let exp_linear_pred = linear_pred.mapv(|x| x.exp());
445
446        // Process events in order
447        for &i in &indices {
448            if event_observed[i] {
449                // Risk set: all individuals with duration >= current duration
450                let mut risk_set_sum = 0.0;
451                let mut risk_set_grad = Array1::zeros(p);
452                let mut risk_set_hess = Array2::zeros((p, p));
453
454                for &j in &indices {
455                    if durations[j] >= durations[i] {
456                        let exp_pred_j = exp_linear_pred[j];
457                        risk_set_sum += exp_pred_j;
458
459                        let cov_j = covariates.row(j);
460                        risk_set_grad = &risk_set_grad + &cov_j.mapv(|x| x * exp_pred_j);
461
462                        // Hessian contribution
463                        for k in 0..p {
464                            for l in 0..p {
465                                risk_set_hess[[k, l]] += cov_j[k] * cov_j[l] * exp_pred_j;
466                            }
467                        }
468                    }
469                }
470
471                if risk_set_sum > 0.0 {
472                    // Update log-likelihood
473                    log_likelihood += linear_pred[i] - risk_set_sum.ln();
474
475                    // Update gradient
476                    let cov_i = covariates.row(i);
477                    gradient = &gradient + &cov_i - &risk_set_grad.mapv(|x: f64| x / risk_set_sum);
478
479                    // Update Hessian
480                    let risk_grad_normalized = risk_set_grad.mapv(|x: f64| x / risk_set_sum);
481                    let risk_hess_normalized = risk_set_hess.mapv(|x: f64| x / risk_set_sum);
482
483                    for k in 0..p {
484                        for l in 0..p {
485                            hessian[[k, l]] -= risk_hess_normalized[[k, l]]
486                                - risk_grad_normalized[k] * risk_grad_normalized[l];
487                        }
488                    }
489                }
490            }
491        }
492
493        Ok((log_likelihood, gradient, hessian))
494    }
495
496    /// Predict risk scores for new data
497    pub fn predict(&self, covariates: &ArrayView2<F>) -> StatsResult<Array1<F>> {
498        let coefficients = self.coefficients.as_ref().ok_or_else(|| {
499            StatsError::InvalidArgument("Model must be fitted before prediction".to_string())
500        })?;
501
502        checkarray_finite(covariates, "covariates")?;
503
504        if covariates.ncols() != coefficients.len() {
505            return Err(StatsError::DimensionMismatch(format!(
506                "Covariates columns ({}) must match number of coefficients ({})",
507                covariates.ncols(),
508                coefficients.len()
509            )));
510        }
511
512        let linear_pred = covariates.dot(coefficients);
513        Ok(linear_pred)
514    }
515}
516
517/// Log-rank test for comparing survival curves
518#[allow(dead_code)]
519pub fn log_rank_test<F>(
520    durations1: &ArrayView1<F>,
521    event_observed1: &ArrayView1<bool>,
522    durations2: &ArrayView1<F>,
523    event_observed2: &ArrayView1<bool>,
524) -> StatsResult<(F, F)>
525where
526    F: Float
527        + Zero
528        + One
529        + Copy
530        + Send
531        + Sync
532        + SimdUnifiedOps
533        + FromPrimitive
534        + PartialOrd
535        + std::fmt::Display,
536{
537    checkarray_finite(durations1, "durations1")?;
538    checkarray_finite(durations2, "durations2")?;
539
540    // Combine data with group indicators
541    let mut combineddata = Vec::new();
542
543    for (&duration, &observed) in durations1.iter().zip(event_observed1.iter()) {
544        combineddata.push((duration, observed, 0)); // Group 0
545    }
546
547    for (&duration, &observed) in durations2.iter().zip(event_observed2.iter()) {
548        combineddata.push((duration, observed, 1)); // Group 1
549    }
550
551    // Sort by duration
552    combineddata.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
553
554    let mut observed_minus_expected = F::zero();
555    let mut variance = F::zero();
556
557    let n1 = durations1.len();
558    let n2 = durations2.len();
559    let mut at_risk1 = n1;
560    let mut at_risk2 = n2;
561
562    let mut i = 0;
563    while i < combineddata.len() {
564        let current_time = combineddata[i].0;
565        let mut events1 = 0;
566        let mut events2 = 0;
567        let mut censored1 = 0;
568        let mut censored2 = 0;
569
570        // Count events and censoring at current time for both groups
571        while i < combineddata.len() && combineddata[i].0 == current_time {
572            let (_, observed, group) = combineddata[i];
573
574            if group == 0 {
575                if observed {
576                    events1 += 1;
577                } else {
578                    censored1 += 1;
579                }
580            } else if observed {
581                events2 += 1;
582            } else {
583                censored2 += 1;
584            }
585
586            i += 1;
587        }
588
589        let total_events = events1 + events2;
590        let total_at_risk = at_risk1 + at_risk2;
591
592        if total_events > 0 && total_at_risk > 0 {
593            // Expected events in group 1
594            let expected1 = F::from(at_risk1 * total_events).expect("Failed to convert to float")
595                / F::from(total_at_risk).expect("Failed to convert to float");
596
597            // Update test statistic
598            observed_minus_expected = observed_minus_expected
599                + F::from(events1).expect("Failed to convert to float")
600                - expected1;
601
602            // Update variance
603            if total_at_risk > 1 {
604                let variance_term =
605                    F::from(at_risk1 * at_risk2 * total_events * (total_at_risk - total_events))
606                        .expect("Operation failed")
607                        / (F::from(total_at_risk * total_at_risk * (total_at_risk - 1))
608                            .expect("Operation failed"));
609                variance = variance + variance_term;
610            }
611        }
612
613        // Update at-risk counts
614        at_risk1 -= events1 + censored1;
615        at_risk2 -= events2 + censored2;
616    }
617
618    // Compute test statistic and p-value
619    let test_statistic = if variance > F::zero() {
620        (observed_minus_expected * observed_minus_expected) / variance
621    } else {
622        F::zero()
623    };
624
625    // Chi-square distribution with 1 df for p-value computation
626    // This is a simplified p-value calculation
627    let p_value = if test_statistic > F::from(3.84).expect("Failed to convert constant to float") {
628        // Critical value for alpha = 0.05
629        F::from(0.05).expect("Failed to convert constant to float")
630    } else {
631        F::from(0.5).expect("Failed to convert constant to float") // Rough approximation
632    };
633
634    Ok((test_statistic, p_value))
635}
636
637/// Convenience functions
638#[allow(dead_code)]
639pub fn kaplan_meier<F>(
640    durations: &ArrayView1<F>,
641    event_observed: &ArrayView1<bool>,
642    confidence_level: Option<F>,
643) -> StatsResult<EnhancedKaplanMeier<F>>
644where
645    F: Float
646        + Zero
647        + One
648        + Copy
649        + Send
650        + Sync
651        + SimdUnifiedOps
652        + FromPrimitive
653        + PartialOrd
654        + std::fmt::Display,
655{
656    EnhancedKaplanMeier::fit(durations, event_observed, confidence_level)
657}
658
659#[allow(dead_code)]
660pub fn cox_regression<F>(
661    durations: &ArrayView1<F>,
662    event_observed: &ArrayView1<bool>,
663    covariates: &ArrayView2<F>,
664    config: Option<CoxConfig>,
665) -> StatsResult<CoxProportionalHazards<F>>
666where
667    F: Float
668        + Zero
669        + One
670        + Copy
671        + Send
672        + Sync
673        + SimdUnifiedOps
674        + FromPrimitive
675        + std::fmt::Display
676        + 'static,
677{
678    let config = config.unwrap_or_default();
679    let mut cox = CoxProportionalHazards::new(config);
680    cox.fit(durations, event_observed, covariates)?;
681    Ok(cox)
682}