1#![allow(dead_code)]
11
12use crate::error::{StatsError, StatsResult};
13use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
14use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
15use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
16use std::marker::PhantomData;
17
18#[derive(Debug, Clone)]
20pub struct EnhancedKaplanMeier<F> {
21 pub event_times: Array1<F>,
23 pub survival_function: Array1<F>,
25 pub confidence_intervals: Option<(Array1<F>, Array1<F>)>,
27 pub at_risk: Array1<usize>,
29 pub events: Array1<usize>,
31 pub median_survival_time: Option<F>,
33 pub mean_survival_time: Option<F>,
35 pub confidence_level: F,
37}
38
39impl<F> EnhancedKaplanMeier<F>
40where
41 F: Float
42 + Zero
43 + One
44 + Copy
45 + Send
46 + Sync
47 + SimdUnifiedOps
48 + FromPrimitive
49 + PartialOrd
50 + std::fmt::Display,
51{
52 pub fn fit(
54 durations: &ArrayView1<F>,
55 event_observed: &ArrayView1<bool>,
56 confidence_level: Option<F>,
57 ) -> StatsResult<Self> {
58 checkarray_finite(durations, "durations")?;
59
60 if durations.len() != event_observed.len() {
61 return Err(StatsError::DimensionMismatch(format!(
62 "Durations length ({}) must match event_observed length ({})",
63 durations.len(),
64 event_observed.len()
65 )));
66 }
67
68 let confidence_level = confidence_level
69 .unwrap_or_else(|| F::from(0.95).expect("Failed to convert constant to float"));
70
71 let mut data: Vec<(F, bool, usize)> = durations
73 .iter()
74 .zip(event_observed.iter())
75 .enumerate()
76 .map(|(i, (&duration, &observed))| (duration, observed, i))
77 .collect();
78
79 data.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
80
81 let n = data.len();
83 let mut survival_times = Vec::new();
84 let mut survival_probs = Vec::new();
85 let mut at_risk_counts = Vec::new();
86 let mut event_counts = Vec::new();
87
88 let mut current_survival = F::one();
89 let mut current_at_risk = n;
90 let mut i = 0;
91
92 while i < n {
93 let current_time = data[i].0;
94 let mut events_at_time = 0;
95 let mut censored_at_time = 0;
96
97 while i < n && data[i].0 == current_time {
99 if data[i].1 {
100 events_at_time += 1;
102 } else {
103 censored_at_time += 1;
105 }
106 i += 1;
107 }
108
109 if events_at_time > 0 {
111 let survival_multiplier = F::one()
112 - F::from(events_at_time).expect("Failed to convert to float")
113 / F::from(current_at_risk).expect("Failed to convert to float");
114 current_survival = current_survival * survival_multiplier;
115
116 survival_times.push(current_time);
117 survival_probs.push(current_survival);
118 at_risk_counts.push(current_at_risk);
119 event_counts.push(events_at_time);
120 }
121
122 current_at_risk -= events_at_time + censored_at_time;
124 }
125
126 let event_times = Array1::from_vec(survival_times);
127 let survival_function = Array1::from_vec(survival_probs);
128 let at_risk = Array1::from_vec(at_risk_counts);
129 let events = Array1::from_vec(event_counts);
130
131 let confidence_intervals = Self::compute_confidence_intervals(
133 &event_times,
134 &survival_function,
135 &at_risk,
136 &events,
137 )?;
138
139 let median_survival_time = Self::compute_median_survival(&event_times, &survival_function);
141 let mean_survival_time = Self::compute_mean_survival(&event_times, &survival_function);
142
143 Ok(Self {
144 event_times,
145 survival_function,
146 confidence_intervals: Some(confidence_intervals),
147 at_risk,
148 events,
149 median_survival_time,
150 mean_survival_time,
151 confidence_level,
152 })
153 }
154
155 fn compute_confidence_intervals(
157 times: &Array1<F>,
158 survival: &Array1<F>,
159 at_risk: &Array1<usize>,
160 events: &Array1<usize>,
161 ) -> StatsResult<(Array1<F>, Array1<F>)> {
162 let n = times.len();
163 let mut lower = Array1::zeros(n);
164 let mut upper = Array1::zeros(n);
165
166 let z = F::from(1.96).expect("Failed to convert constant to float");
168
169 let mut cumulative_variance = F::zero();
170
171 for i in 0..n {
172 let events_i = F::from(events[i]).expect("Failed to convert to float");
173 let at_risk_i = F::from(at_risk[i]).expect("Failed to convert to float");
174
175 if at_risk[i] > events[i] {
177 let variance_term = events_i / (at_risk_i * (at_risk_i - events_i));
178 cumulative_variance = cumulative_variance + variance_term;
179 }
180
181 let se = survival[i] * cumulative_variance.sqrt();
183
184 if survival[i] > F::zero() {
186 let log_survival = survival[i].ln();
187 let se_log = se / survival[i];
188
189 let log_lower = log_survival - z * se_log;
190 let log_upper = log_survival + z * se_log;
191
192 lower[i] = log_lower.exp().max(F::zero()).min(F::one());
193 upper[i] = log_upper.exp().max(F::zero()).min(F::one());
194 } else {
195 lower[i] = F::zero();
196 upper[i] = F::zero();
197 }
198 }
199
200 Ok((lower, upper))
201 }
202
203 fn compute_median_survival(times: &Array1<F>, survival: &Array1<F>) -> Option<F> {
205 let median_threshold = F::from(0.5).expect("Failed to convert constant to float");
206
207 for i in 0..survival.len() {
208 if survival[i] <= median_threshold {
209 return Some(times[i]);
210 }
211 }
212
213 None }
215
216 fn compute_mean_survival(times: &Array1<F>, survival: &Array1<F>) -> Option<F> {
218 if times.is_empty() {
219 return None;
220 }
221
222 let mut area = F::zero();
223 let mut prev_time = F::zero();
224 let mut prev_survival = F::one();
225
226 for i in 0..times.len() {
227 let time_diff = times[i] - prev_time;
228 area = area + prev_survival * time_diff;
229
230 prev_time = times[i];
231 prev_survival = survival[i];
232 }
233
234 Some(area)
235 }
236
237 pub fn survival_function_at(&self, times: &ArrayView1<F>) -> StatsResult<Array1<F>> {
239 let mut result = Array1::ones(times.len());
240
241 for (i, &time) in times.iter().enumerate() {
242 let mut survival_prob = F::one();
244
245 for j in 0..self.event_times.len() {
246 if self.event_times[j] <= time {
247 survival_prob = self.survival_function[j];
248 } else {
249 break;
250 }
251 }
252
253 result[i] = survival_prob;
254 }
255
256 Ok(result)
257 }
258}
259
260pub struct CoxProportionalHazards<F> {
262 pub coefficients: Option<Array1<F>>,
264 pub standard_errors: Option<Array1<F>>,
266 pub baseline_hazard: Option<Array1<F>>,
268 pub config: CoxConfig,
270 pub convergence_info: Option<CoxConvergenceInfo>,
272 _phantom: PhantomData<F>,
273}
274
/// Optimiser settings for Cox proportional-hazards fitting.
#[derive(Debug, Clone)]
pub struct CoxConfig {
    /// Upper bound on the number of Newton–Raphson iterations.
    pub max_iter: usize,
    /// Convergence threshold on the absolute change in partial log-likelihood between
    /// successive iterations.
    pub tolerance: f64,
    /// Multiplier applied to each Newton step (`1.0` means a full step).
    pub stepsize: f64,
    /// Parallel-execution flag. NOTE(review): not read by the fitting code in this module.
    pub parallel: bool,
}
287
/// Diagnostics recorded by a Cox model fit.
#[derive(Debug, Clone)]
pub struct CoxConvergenceInfo {
    /// Iteration count recorded by the fitting routine.
    pub n_iter: usize,
    /// Partial log-likelihood recorded by the fitting routine.
    pub log_likelihood: f64,
    /// Whether the log-likelihood change fell below the configured tolerance.
    pub converged: bool,
}
298
299impl Default for CoxConfig {
300 fn default() -> Self {
301 Self {
302 max_iter: 100,
303 tolerance: 1e-6,
304 stepsize: 1.0,
305 parallel: true,
306 }
307 }
308}
309
310impl<F> CoxProportionalHazards<F>
311where
312 F: Float
313 + Zero
314 + One
315 + Copy
316 + Send
317 + Sync
318 + SimdUnifiedOps
319 + FromPrimitive
320 + std::fmt::Display
321 + 'static,
322{
323 pub fn new(config: CoxConfig) -> Self {
325 Self {
326 coefficients: None,
327 standard_errors: None,
328 baseline_hazard: None,
329 config,
330 convergence_info: None,
331 _phantom: PhantomData,
332 }
333 }
334
335 pub fn fit(
337 &mut self,
338 durations: &ArrayView1<F>,
339 event_observed: &ArrayView1<bool>,
340 covariates: &ArrayView2<F>,
341 ) -> StatsResult<()> {
342 checkarray_finite(durations, "durations")?;
343 checkarray_finite(covariates, "covariates")?;
344
345 let n = durations.len();
346 let p = covariates.ncols();
347
348 if n != event_observed.len() || n != covariates.nrows() {
349 return Err(StatsError::DimensionMismatch(
350 "All input arrays must have the same number of observations".to_string(),
351 ));
352 }
353
354 let mut beta = Array1::zeros(p);
356
357 let durations_f64 = durations.mapv(|x| x.to_f64().expect("Operation failed"));
359 let covariates_f64 = covariates.mapv(|x| x.to_f64().expect("Operation failed"));
360
361 let mut converged = false;
363 let mut log_likelihood = f64::NEG_INFINITY;
364
365 for _iter in 0..self.config.max_iter {
366 let (ll, gradient, hessian) = self.compute_partial_likelihood_derivatives(
368 &durations_f64,
369 event_observed,
370 &covariates_f64,
371 &beta,
372 )?;
373
374 if (ll - log_likelihood).abs() < self.config.tolerance {
376 converged = true;
377 break;
378 }
379
380 log_likelihood = ll;
381
382 let hessian_inv = scirs2_linalg::inv(&hessian.view(), None).map_err(|e| {
384 StatsError::ComputationError(format!("Hessian inversion failed: {e}"))
385 })?;
386
387 let update = hessian_inv.dot(&gradient);
388 beta = &beta + &update.mapv(|x| x * self.config.stepsize);
389 }
390
391 let (_, _, hessian) = self.compute_partial_likelihood_derivatives(
393 &durations_f64,
394 event_observed,
395 &covariates_f64,
396 &beta,
397 )?;
398
399 let cov_matrix = scirs2_linalg::inv(&(-hessian).view(), None).map_err(|e| {
400 StatsError::ComputationError(format!("Covariance matrix computation failed: {e}"))
401 })?;
402
403 let standard_errors = cov_matrix.diag().mapv(|x| x.sqrt());
404
405 self.coefficients = Some(beta.mapv(|x| F::from(x).expect("Failed to convert to float")));
407 self.standard_errors =
408 Some(standard_errors.mapv(|x| F::from(x).expect("Failed to convert to float")));
409
410 self.convergence_info = Some(CoxConvergenceInfo {
411 n_iter: self.config.max_iter,
412 log_likelihood,
413 converged,
414 });
415
416 Ok(())
417 }
418
419 fn compute_partial_likelihood_derivatives(
421 &self,
422 durations: &Array1<f64>,
423 event_observed: &ArrayView1<bool>,
424 covariates: &Array2<f64>,
425 beta: &Array1<f64>,
426 ) -> StatsResult<(f64, Array1<f64>, Array2<f64>)> {
427 let n = durations.len();
428 let p = beta.len();
429
430 let mut indices: Vec<usize> = (0..n).collect();
432 indices.sort_by(|&i, &j| {
433 durations[j]
434 .partial_cmp(&durations[i])
435 .expect("Operation failed")
436 });
437
438 let mut log_likelihood = 0.0;
439 let mut gradient = Array1::zeros(p);
440 let mut hessian = Array2::zeros((p, p));
441
442 let linear_pred = covariates.dot(beta);
444 let exp_linear_pred = linear_pred.mapv(|x| x.exp());
445
446 for &i in &indices {
448 if event_observed[i] {
449 let mut risk_set_sum = 0.0;
451 let mut risk_set_grad = Array1::zeros(p);
452 let mut risk_set_hess = Array2::zeros((p, p));
453
454 for &j in &indices {
455 if durations[j] >= durations[i] {
456 let exp_pred_j = exp_linear_pred[j];
457 risk_set_sum += exp_pred_j;
458
459 let cov_j = covariates.row(j);
460 risk_set_grad = &risk_set_grad + &cov_j.mapv(|x| x * exp_pred_j);
461
462 for k in 0..p {
464 for l in 0..p {
465 risk_set_hess[[k, l]] += cov_j[k] * cov_j[l] * exp_pred_j;
466 }
467 }
468 }
469 }
470
471 if risk_set_sum > 0.0 {
472 log_likelihood += linear_pred[i] - risk_set_sum.ln();
474
475 let cov_i = covariates.row(i);
477 gradient = &gradient + &cov_i - &risk_set_grad.mapv(|x: f64| x / risk_set_sum);
478
479 let risk_grad_normalized = risk_set_grad.mapv(|x: f64| x / risk_set_sum);
481 let risk_hess_normalized = risk_set_hess.mapv(|x: f64| x / risk_set_sum);
482
483 for k in 0..p {
484 for l in 0..p {
485 hessian[[k, l]] -= risk_hess_normalized[[k, l]]
486 - risk_grad_normalized[k] * risk_grad_normalized[l];
487 }
488 }
489 }
490 }
491 }
492
493 Ok((log_likelihood, gradient, hessian))
494 }
495
496 pub fn predict(&self, covariates: &ArrayView2<F>) -> StatsResult<Array1<F>> {
498 let coefficients = self.coefficients.as_ref().ok_or_else(|| {
499 StatsError::InvalidArgument("Model must be fitted before prediction".to_string())
500 })?;
501
502 checkarray_finite(covariates, "covariates")?;
503
504 if covariates.ncols() != coefficients.len() {
505 return Err(StatsError::DimensionMismatch(format!(
506 "Covariates columns ({}) must match number of coefficients ({})",
507 covariates.ncols(),
508 coefficients.len()
509 )));
510 }
511
512 let linear_pred = covariates.dot(coefficients);
513 Ok(linear_pred)
514 }
515}
516
517#[allow(dead_code)]
519pub fn log_rank_test<F>(
520 durations1: &ArrayView1<F>,
521 event_observed1: &ArrayView1<bool>,
522 durations2: &ArrayView1<F>,
523 event_observed2: &ArrayView1<bool>,
524) -> StatsResult<(F, F)>
525where
526 F: Float
527 + Zero
528 + One
529 + Copy
530 + Send
531 + Sync
532 + SimdUnifiedOps
533 + FromPrimitive
534 + PartialOrd
535 + std::fmt::Display,
536{
537 checkarray_finite(durations1, "durations1")?;
538 checkarray_finite(durations2, "durations2")?;
539
540 let mut combineddata = Vec::new();
542
543 for (&duration, &observed) in durations1.iter().zip(event_observed1.iter()) {
544 combineddata.push((duration, observed, 0)); }
546
547 for (&duration, &observed) in durations2.iter().zip(event_observed2.iter()) {
548 combineddata.push((duration, observed, 1)); }
550
551 combineddata.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("Operation failed"));
553
554 let mut observed_minus_expected = F::zero();
555 let mut variance = F::zero();
556
557 let n1 = durations1.len();
558 let n2 = durations2.len();
559 let mut at_risk1 = n1;
560 let mut at_risk2 = n2;
561
562 let mut i = 0;
563 while i < combineddata.len() {
564 let current_time = combineddata[i].0;
565 let mut events1 = 0;
566 let mut events2 = 0;
567 let mut censored1 = 0;
568 let mut censored2 = 0;
569
570 while i < combineddata.len() && combineddata[i].0 == current_time {
572 let (_, observed, group) = combineddata[i];
573
574 if group == 0 {
575 if observed {
576 events1 += 1;
577 } else {
578 censored1 += 1;
579 }
580 } else if observed {
581 events2 += 1;
582 } else {
583 censored2 += 1;
584 }
585
586 i += 1;
587 }
588
589 let total_events = events1 + events2;
590 let total_at_risk = at_risk1 + at_risk2;
591
592 if total_events > 0 && total_at_risk > 0 {
593 let expected1 = F::from(at_risk1 * total_events).expect("Failed to convert to float")
595 / F::from(total_at_risk).expect("Failed to convert to float");
596
597 observed_minus_expected = observed_minus_expected
599 + F::from(events1).expect("Failed to convert to float")
600 - expected1;
601
602 if total_at_risk > 1 {
604 let variance_term =
605 F::from(at_risk1 * at_risk2 * total_events * (total_at_risk - total_events))
606 .expect("Operation failed")
607 / (F::from(total_at_risk * total_at_risk * (total_at_risk - 1))
608 .expect("Operation failed"));
609 variance = variance + variance_term;
610 }
611 }
612
613 at_risk1 -= events1 + censored1;
615 at_risk2 -= events2 + censored2;
616 }
617
618 let test_statistic = if variance > F::zero() {
620 (observed_minus_expected * observed_minus_expected) / variance
621 } else {
622 F::zero()
623 };
624
625 let p_value = if test_statistic > F::from(3.84).expect("Failed to convert constant to float") {
628 F::from(0.05).expect("Failed to convert constant to float")
630 } else {
631 F::from(0.5).expect("Failed to convert constant to float") };
633
634 Ok((test_statistic, p_value))
635}
636
637#[allow(dead_code)]
639pub fn kaplan_meier<F>(
640 durations: &ArrayView1<F>,
641 event_observed: &ArrayView1<bool>,
642 confidence_level: Option<F>,
643) -> StatsResult<EnhancedKaplanMeier<F>>
644where
645 F: Float
646 + Zero
647 + One
648 + Copy
649 + Send
650 + Sync
651 + SimdUnifiedOps
652 + FromPrimitive
653 + PartialOrd
654 + std::fmt::Display,
655{
656 EnhancedKaplanMeier::fit(durations, event_observed, confidence_level)
657}
658
659#[allow(dead_code)]
660pub fn cox_regression<F>(
661 durations: &ArrayView1<F>,
662 event_observed: &ArrayView1<bool>,
663 covariates: &ArrayView2<F>,
664 config: Option<CoxConfig>,
665) -> StatsResult<CoxProportionalHazards<F>>
666where
667 F: Float
668 + Zero
669 + One
670 + Copy
671 + Send
672 + Sync
673 + SimdUnifiedOps
674 + FromPrimitive
675 + std::fmt::Display
676 + 'static,
677{
678 let config = config.unwrap_or_default();
679 let mut cox = CoxProportionalHazards::new(config);
680 cox.fit(durations, event_observed, covariates)?;
681 Ok(cox)
682}