Skip to main content

scirs2_stats/
survival_api.rs

1//! Canonical Survival Analysis API
2//!
3//! Wraps the core survival analysis types in `survival::*` with a standardised
4//! interface that matches the SciRS2 public API requirements:
5//!
6//! - [`KMCurve`]       – Kaplan-Meier result with `survival_function(t)` and
7//!                       `confidence_interval(t, alpha)` (Greenwood formula)
8//! - [`NACurve`]       – Nelson-Aalen result with `survival_function(t)` and
9//!                       `confidence_interval(t, alpha)`
10//! - [`log_rank_test`] – two-sample log-rank test → `(statistic, p_value)`
11//! - [`CoxPHModel`]    – Cox proportional hazards fitted via Newton-Raphson
12
13use crate::error::{StatsError, StatsResult};
14use crate::survival::{CoxPH, KaplanMeier, NelsonAalen};
15use scirs2_core::ndarray::Array2;
16
17// ---------------------------------------------------------------------------
// Normal quantile (Beasley-Springer rational approximation, Algorithm AS 111)
19// ---------------------------------------------------------------------------
20
/// Standard normal quantile function Φ⁻¹(p).
///
/// Implements the Beasley–Springer rational approximation (Algorithm AS 111):
/// one rational function of `(p - 0.5)²` in the central region
/// `|p - 0.5| ≤ 0.42`, and a rational function of `sqrt(-ln(tail))` in the
/// tails, where `tail` is the smaller of `p` and `1 - p`.
///
/// The coefficients are the published AS 111 values. The previous constants
/// deviated from them in the third/fourth significant digit, which shifted
/// e.g. the 97.5 % quantile from 1.95996 to roughly 1.95973.
///
/// # Arguments
/// * `p` – probability; clamped to `[1e-15, 1 - 1e-15]` before evaluation so
///   the logarithm in the tail branch stays finite.
///
/// # Returns
/// The value `z` such that `P(Z ≤ z) ≈ p` for a standard normal `Z`.
fn norm_ppf(p: f64) -> f64 {
    // Boundary between the central and the tail approximation.
    const SPLIT: f64 = 0.42;
    // Central region: numerator (A) and denominator (B) coefficients,
    // lowest order first; the denominator's constant term is 1.
    const A: [f64; 4] = [
        2.506_628_238_84,
        -18.615_000_625_29,
        41.391_197_735_34,
        -25.441_060_496_37,
    ];
    const B: [f64; 4] = [
        -8.473_510_930_90,
        23.083_367_437_43,
        -21.062_241_018_26,
        3.130_829_098_33,
    ];
    // Tail region: numerator (C) and denominator (D) coefficients,
    // lowest order first; the denominator's constant term is 1.
    const C: [f64; 4] = [
        -2.787_189_311_38,
        -2.297_964_791_34,
        4.850_141_271_35,
        2.321_212_768_58,
    ];
    const D: [f64; 2] = [3.543_889_247_62, 1.637_067_818_97];

    let p = p.clamp(1e-15, 1.0 - 1e-15);
    let q = p - 0.5;

    if q.abs() <= SPLIT {
        let r = q * q;
        let numerator = ((A[3] * r + A[2]) * r + A[1]) * r + A[0];
        let denominator = (((B[3] * r + B[2]) * r + B[1]) * r + B[0]) * r + 1.0;
        q * (numerator / denominator)
    } else {
        // Work with the smaller tail probability and restore the sign afterwards.
        let tail = if q < 0.0 { p } else { 1.0 - p };
        let r = (-tail.ln()).sqrt();
        let numerator = ((C[3] * r + C[2]) * r + C[1]) * r + C[0];
        let denominator = (D[1] * r + D[0]) * r + 1.0;
        let x = numerator / denominator;
        if q < 0.0 {
            -x
        } else {
            x
        }
    }
}
40
41// ---------------------------------------------------------------------------
42// KMCurve
43// ---------------------------------------------------------------------------
44
45/// The result of fitting the Kaplan-Meier estimator.
46///
47/// Exposes `survival_function(t)` and `confidence_interval(t, alpha)`.
48pub struct KMCurve {
49    km: KaplanMeier,
50}
51
52impl KMCurve {
53    /// Fit the Kaplan-Meier estimator.
54    ///
55    /// # Arguments
56    /// * `times`  – observed event/censoring times (must be ≥ 0, finite).
57    /// * `events` – `true` if the observation is an actual event (uncensored).
58    pub fn fit(times: &[f64], events: &[bool]) -> StatsResult<Self> {
59        let km = KaplanMeier::fit(times, events)?;
60        Ok(Self { km })
61    }
62
63    /// Evaluate the Kaplan-Meier survival function S(t) = P(T > t).
64    pub fn survival_function(&self, t: f64) -> f64 {
65        self.km.survival_at(t)
66    }
67
68    /// Compute a pointwise Greenwood confidence interval for S(t).
69    ///
70    /// Uses the log-log transform for better small-sample coverage.
71    ///
72    /// # Arguments
73    /// * `t`     – time at which to evaluate the CI.
74    /// * `alpha` – significance level (e.g. 0.05 for a 95% CI).
75    ///
76    /// # Returns
77    /// `(lower, upper)` – both in \[0, 1\].
78    pub fn confidence_interval(&self, t: f64, alpha: f64) -> StatsResult<(f64, f64)> {
79        if alpha <= 0.0 || alpha >= 1.0 {
80            return Err(StatsError::InvalidArgument(format!(
81                "alpha must be in (0, 1), got {alpha}"
82            )));
83        }
84        let s = self.survival_function(t);
85        if s <= 0.0 || s >= 1.0 {
86            return Ok((s.clamp(0.0, 1.0), s.clamp(0.0, 1.0)));
87        }
88
89        // Greenwood cumulative variance Σ d_k / (n_k (n_k - d_k)) up to time t
90        let greenwood: f64 = self
91            .km
92            .times
93            .iter()
94            .enumerate()
95            .take_while(|(_, &tk)| tk <= t)
96            .map(|(k, _)| {
97                let n_k = self.km.n_at_risk[k] as f64;
98                let d_k = self.km.n_events[k] as f64;
99                if n_k > d_k {
100                    d_k / (n_k * (n_k - d_k))
101                } else {
102                    0.0
103                }
104            })
105            .sum();
106
107        if greenwood == 0.0 {
108            return Ok((s, s));
109        }
110
111        let z = norm_ppf(1.0 - alpha / 2.0);
112        let ln_s = s.ln();
113        let se_ll = (greenwood / (ln_s * ln_s)).sqrt();
114        let log_log_s = (-ln_s).ln();
115
116        let ll_lo = log_log_s - z * se_ll;
117        let ll_hi = log_log_s + z * se_ll;
118
119        // Back-transform: S = exp(-exp(θ)) is *decreasing* in θ
120        let lower = (-ll_hi.exp()).exp().clamp(0.0, 1.0);
121        let upper = (-ll_lo.exp()).exp().clamp(0.0, 1.0);
122
123        Ok((lower.min(upper), lower.max(upper)))
124    }
125
126    /// Median survival time (smallest t with S(t) ≤ 0.5).
127    pub fn median_survival(&self) -> Option<f64> {
128        self.km.median_survival()
129    }
130
131    /// Mean survival time.
132    pub fn mean_survival(&self) -> f64 {
133        self.km.mean_survival()
134    }
135}
136
137// ---------------------------------------------------------------------------
138// NACurve
139// ---------------------------------------------------------------------------
140
141/// The result of fitting the Nelson-Aalen estimator.
142///
143/// Stores hazard increments and at-risk counts for variance computation.
144pub struct NACurve {
145    na: NelsonAalen,
146    /// Hazard increments Δ H(t_k) = d_k / n_k at each event time.
147    hazard_increments: Vec<f64>,
148    /// Number at risk n_k at each event time.
149    at_risk: Vec<usize>,
150}
151
152impl NACurve {
153    /// Fit the Nelson-Aalen estimator.
154    ///
155    /// # Arguments
156    /// * `times`  – observed times (must be ≥ 0, finite).
157    /// * `events` – `true` if the event occurred (uncensored).
158    pub fn fit(times: &[f64], events: &[bool]) -> StatsResult<Self> {
159        if times.is_empty() {
160            return Err(StatsError::InvalidArgument(
161                "times array cannot be empty".to_string(),
162            ));
163        }
164        if times.len() != events.len() {
165            return Err(StatsError::InvalidArgument(format!(
166                "times ({}) and events ({}) must have equal length",
167                times.len(),
168                events.len()
169            )));
170        }
171        for (i, &t) in times.iter().enumerate() {
172            if !t.is_finite() {
173                return Err(StatsError::InvalidArgument(format!(
174                    "times[{i}] is not finite: {t}"
175                )));
176            }
177            if t < 0.0 {
178                return Err(StatsError::InvalidArgument(format!(
179                    "times[{i}] is negative: {t}"
180                )));
181            }
182        }
183
184        // Sort by time (ties: events before censored)
185        let mut pairs: Vec<(f64, bool)> =
186            times.iter().copied().zip(events.iter().copied()).collect();
187        pairs.sort_by(|a, b| {
188            a.0.partial_cmp(&b.0)
189                .unwrap_or(std::cmp::Ordering::Equal)
190                .then(b.1.cmp(&a.1))
191        });
192
193        let n = pairs.len();
194        let mut at_risk_count = n;
195        let mut hazard_increments = Vec::new();
196        let mut at_risk_vec = Vec::new();
197        let mut idx = 0;
198
199        while idx < pairs.len() {
200            let t = pairs[idx].0;
201            let mut d = 0_usize;
202            let mut c = 0_usize;
203            while idx < pairs.len() && (pairs[idx].0 - t).abs() < 1e-12 {
204                if pairs[idx].1 {
205                    d += 1;
206                } else {
207                    c += 1;
208                }
209                idx += 1;
210            }
211            if d > 0 && at_risk_count > 0 {
212                hazard_increments.push((d as f64) / (at_risk_count as f64));
213                at_risk_vec.push(at_risk_count);
214            }
215            at_risk_count -= d + c;
216        }
217
218        let na = NelsonAalen::fit(times, events)?;
219        Ok(Self {
220            na,
221            hazard_increments,
222            at_risk: at_risk_vec,
223        })
224    }
225
226    /// Evaluate the survival function S(t) = exp(−Ĥ(t)).
227    pub fn survival_function(&self, t: f64) -> f64 {
228        self.na.survival_at(t)
229    }
230
231    /// Evaluate the cumulative hazard Ĥ(t).
232    pub fn cumulative_hazard(&self, t: f64) -> f64 {
233        self.na.hazard_at(t)
234    }
235
236    /// Compute a pointwise confidence interval for S(t).
237    ///
238    /// Uses the log-transform CI for the cumulative hazard.
239    ///
240    /// # Arguments
241    /// * `t`     – evaluation time.
242    /// * `alpha` – significance level (e.g. 0.05 → 95% CI).
243    ///
244    /// # Returns
245    /// `(lower, upper)`.
246    pub fn confidence_interval(&self, t: f64, alpha: f64) -> StatsResult<(f64, f64)> {
247        if alpha <= 0.0 || alpha >= 1.0 {
248            return Err(StatsError::InvalidArgument(format!(
249                "alpha must be in (0, 1), got {alpha}"
250            )));
251        }
252        let s = self.survival_function(t);
253        if s <= 0.0 || s >= 1.0 {
254            return Ok((s.clamp(0.0, 1.0), s.clamp(0.0, 1.0)));
255        }
256
257        // Variance of Ĥ(t): Var[Ĥ(t)] ≈ Σ_{t_k ≤ t} d_k / n_k²
258        // = Σ (increment_k) / n_k
259        let var_h: f64 = self
260            .na
261            .times
262            .iter()
263            .enumerate()
264            .take_while(|(_, &tk)| tk <= t)
265            .map(|(k, _)| {
266                if k < self.at_risk.len() && self.at_risk[k] > 0 {
267                    self.hazard_increments[k] / self.at_risk[k] as f64
268                } else {
269                    0.0
270                }
271            })
272            .sum();
273
274        if var_h == 0.0 {
275            return Ok((s, s));
276        }
277
278        let h = -s.ln();
279        let z = norm_ppf(1.0 - alpha / 2.0);
280        let se = var_h.sqrt();
281
282        // Log-transform CI for H: (H / c, H * c) where c = exp(z se / H)
283        let c = (z * se / h).exp();
284        let h_lo = h / c;
285        let h_hi = h * c;
286
287        let upper = (-h_lo).exp().clamp(0.0, 1.0);
288        let lower = (-h_hi).exp().clamp(0.0, 1.0);
289
290        Ok((lower.min(upper), lower.max(upper)))
291    }
292}
293
294// ---------------------------------------------------------------------------
295// log_rank_test
296// ---------------------------------------------------------------------------
297
298/// Log-rank test comparing survival between two independent groups.
299///
300/// # Arguments
301/// * `group1_times`  – observed times for group 1.
302/// * `group1_events` – event indicators for group 1 (`true` = event).
303/// * `group2_times`  – observed times for group 2.
304/// * `group2_events` – event indicators for group 2.
305///
306/// # Returns
307/// `(statistic, p_value)` – chi-squared test statistic and two-sided p-value.
308pub fn log_rank_test(
309    group1_times: &[f64],
310    group1_events: &[bool],
311    group2_times: &[f64],
312    group2_events: &[bool],
313) -> StatsResult<(f64, f64)> {
314    let result =
315        KaplanMeier::log_rank_test(group1_times, group1_events, group2_times, group2_events)?;
316    Ok(result)
317}
318
319// ---------------------------------------------------------------------------
320// CoxPHModel
321// ---------------------------------------------------------------------------
322
323/// Cox proportional hazards model.
324///
325/// Wraps [`CoxPH`] with a slice-based interface.
326pub struct CoxPHModel {
327    inner: CoxPH,
328}
329
330impl CoxPHModel {
331    /// Fit the Cox PH model via Newton-Raphson partial likelihood optimisation.
332    ///
333    /// # Arguments
334    /// * `times`      – observed event/censoring times.
335    /// * `events`     – event indicators.
336    /// * `covariates` – n_samples × n_features covariate matrix.
337    pub fn fit(times: &[f64], events: &[bool], covariates: &Array2<f64>) -> StatsResult<Self> {
338        let inner = CoxPH::fit(times, events, covariates)?;
339        Ok(Self { inner })
340    }
341
342    /// Log hazard-ratio coefficients β.
343    pub fn coefficients(&self) -> Vec<f64> {
344        self.inner.coefficients.iter().copied().collect()
345    }
346
347    /// Standard errors of the coefficients.
348    pub fn standard_errors(&self) -> Vec<f64> {
349        self.inner.std_errors.iter().copied().collect()
350    }
351
352    /// Two-sided Wald test p-values.
353    pub fn p_values(&self) -> Vec<f64> {
354        self.inner.p_values.iter().copied().collect()
355    }
356
357    /// Hazard ratios exp(β).
358    pub fn hazard_ratios(&self) -> Vec<f64> {
359        self.inner.hazard_ratio().iter().copied().collect()
360    }
361
362    /// Predict risk score exp(Xβ) for a single observation given as a slice.
363    pub fn predict_risk(&self, x: &[f64]) -> f64 {
364        use scirs2_core::ndarray::Array1;
365        let arr = Array1::from_vec(x.to_vec());
366        self.inner.predict_risk(&arr)
367    }
368
369    /// Concordance index (C-statistic) evaluated on provided data.
370    ///
371    /// Requires the covariate matrix used at prediction time.
372    pub fn concordance_index(
373        &self,
374        times: &[f64],
375        events: &[bool],
376        covariates: &Array2<f64>,
377    ) -> f64 {
378        self.inner.concordance_index(times, events, covariates)
379    }
380
381    /// Partial log-likelihood at convergence.
382    pub fn log_likelihood(&self) -> f64 {
383        self.inner.log_likelihood
384    }
385
386    /// Number of Newton-Raphson iterations performed.
387    pub fn n_iterations(&self) -> usize {
388        self.inner.n_iter
389    }
390}
391
392// ---------------------------------------------------------------------------
393// Tests
394// ---------------------------------------------------------------------------
395
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Ten observations at t = 1..=10 with a mix of events and censorings.
    fn sample_data() -> (Vec<f64>, Vec<bool>) {
        let times: Vec<f64> = (1..=10_i32).map(f64::from).collect();
        let events = vec![
            true, true, false, true, false, true, true, false, true, true,
        ];
        (times, events)
    }

    /// Kaplan-Meier curve fitted to `sample_data`.
    fn fitted_km() -> KMCurve {
        let (times, events) = sample_data();
        KMCurve::fit(&times, &events).expect("fit failed")
    }

    /// Nelson-Aalen curve fitted to `sample_data`.
    fn fitted_na() -> NACurve {
        let (times, events) = sample_data();
        NACurve::fit(&times, &events).expect("fit failed")
    }

    /// Eight observations with a single covariate, shared by all Cox tests.
    fn cox_fixture() -> (Vec<f64>, Vec<bool>, Array2<f64>) {
        let times: Vec<f64> = (1..=8_i32).map(f64::from).collect();
        let events = vec![true, true, false, true, false, true, true, false];
        let covariates =
            Array2::from_shape_vec((8, 1), vec![0.1, 0.5, 0.2, 0.8, 0.3, 0.9, 0.4, 0.7])
                .expect("array failed");
        (times, events, covariates)
    }

    /// Cox model fitted to `cox_fixture`.
    fn fitted_cox() -> CoxPHModel {
        let (times, events, covariates) = cox_fixture();
        CoxPHModel::fit(&times, &events, &covariates).expect("fit failed")
    }

    // ----- KMCurve -----

    #[test]
    fn test_kmcurve_survival_starts_at_one() {
        assert_eq!(fitted_km().survival_function(0.0), 1.0);
    }

    #[test]
    fn test_kmcurve_survival_bounded() {
        let km = fitted_km();
        for t in [0.0, 1.5, 5.0, 10.0, 20.0] {
            let s = km.survival_function(t);
            assert!((0.0..=1.0).contains(&s), "S({t}) = {s} out of [0,1]");
        }
    }

    #[test]
    fn test_kmcurve_survival_non_increasing() {
        let km = fitted_km();
        let mut prev = 1.0_f64;
        for t in [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 20.0] {
            let s = km.survival_function(t);
            assert!(s <= prev + 1e-12, "S({t}) = {s} > S_prev = {prev}");
            prev = s;
        }
    }

    #[test]
    fn test_kmcurve_confidence_interval_ordering() {
        let km = fitted_km();
        for t in [2.0, 5.0, 8.0] {
            let (lo, hi) = km.confidence_interval(t, 0.05).expect("CI failed");
            assert!(lo <= hi + 1e-10, "lo > hi at t={t}: {lo} {hi}");
            assert!(lo >= 0.0 && hi <= 1.0);
        }
    }

    #[test]
    fn test_kmcurve_ci_invalid_alpha() {
        let km = fitted_km();
        assert!(km.confidence_interval(3.0, 0.0).is_err());
        assert!(km.confidence_interval(3.0, 1.0).is_err());
    }

    // ----- NACurve -----

    #[test]
    fn test_nacurve_survival_starts_at_one() {
        assert_eq!(fitted_na().survival_function(0.0), 1.0);
    }

    #[test]
    fn test_nacurve_survival_bounded() {
        let na = fitted_na();
        for t in [0.0, 2.5, 6.0, 12.0] {
            let s = na.survival_function(t);
            assert!((0.0..=1.0).contains(&s), "S({t}) = {s} out of [0,1]");
        }
    }

    #[test]
    fn test_nacurve_confidence_interval_ordering() {
        let na = fitted_na();
        let (lo, hi) = na.confidence_interval(5.0, 0.05).expect("CI failed");
        assert!(lo <= hi + 1e-10, "lo > hi: {lo} {hi}");
        assert!(lo >= 0.0 && hi <= 1.0);
    }

    #[test]
    fn test_nacurve_ci_invalid_alpha() {
        let na = fitted_na();
        assert!(na.confidence_interval(3.0, 0.0).is_err());
        assert!(na.confidence_interval(3.0, 1.5).is_err());
    }

    // ----- log_rank_test -----

    #[test]
    fn test_log_rank_different_groups() {
        // Group 1 fails entirely before group 2 starts failing.
        let times1: Vec<f64> = (1..=5_i32).map(f64::from).collect();
        let times2: Vec<f64> = (6..=10_i32).map(f64::from).collect();
        let events1 = vec![true; 5];
        let events2 = vec![true; 5];
        let (stat, p) =
            log_rank_test(&times1, &events1, &times2, &events2).expect("log_rank_test failed");
        assert!(stat >= 0.0, "statistic should be non-negative");
        assert!((0.0..=1.0).contains(&p), "p-value out of range: {p}");
        assert!(p < 0.05, "expected significant difference, p = {p}");
    }

    #[test]
    fn test_log_rank_identical_groups() {
        let times: Vec<f64> = (1..=6_i32).map(f64::from).collect();
        let events = vec![true, true, false, true, false, true];
        let (stat, p) =
            log_rank_test(&times, &events, &times, &events).expect("log_rank_test failed");
        assert!(stat < 1e-10, "identical groups: stat={stat}");
        assert!(p > 0.5, "identical groups should have large p={p}");
    }

    // ----- CoxPHModel -----

    #[test]
    fn test_coxph_fit_and_coefficients() {
        let model = fitted_cox();
        assert_eq!(model.coefficients().len(), 1);
        assert!(model.coefficients()[0].is_finite());
    }

    #[test]
    fn test_coxph_log_likelihood_finite() {
        assert!(fitted_cox().log_likelihood().is_finite());
    }

    #[test]
    fn test_coxph_hazard_ratios_positive() {
        for hr in fitted_cox().hazard_ratios() {
            assert!(hr > 0.0, "HR should be positive, got {hr}");
        }
    }

    #[test]
    fn test_coxph_predict_risk_positive() {
        let risk = fitted_cox().predict_risk(&[0.5]);
        assert!(risk > 0.0, "risk should be positive, got {risk}");
    }

    #[test]
    fn test_coxph_concordance_index_in_range() {
        let (times, events, covariates) = cox_fixture();
        let model = CoxPHModel::fit(&times, &events, &covariates).expect("fit failed");
        let ci = model.concordance_index(&times, &events, &covariates);
        assert!((0.0..=1.0).contains(&ci), "C-index out of [0,1]: {ci}");
    }
}