1use crate::error::{StatsError, StatsResult as Result};
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::validation::*;
9
10#[derive(Debug, Clone)]
14pub struct KaplanMeierEstimator {
15 pub event_times: Array1<f64>,
17 pub survival_function: Array1<f64>,
19 pub confidence_intervals: Option<(Array1<f64>, Array1<f64>)>,
21 pub at_risk: Array1<usize>,
23 pub events: Array1<usize>,
25 pub median_survival_time: Option<f64>,
27}
28
29impl KaplanMeierEstimator {
30 pub fn fit(
40 durations: ArrayView1<f64>,
41 event_observed: ArrayView1<bool>,
42 confidence_level: Option<f64>,
43 ) -> Result<Self> {
44 checkarray_finite(&durations, "durations")?;
45
46 if durations.len() != event_observed.len() {
47 return Err(StatsError::DimensionMismatch(format!(
48 "durations length ({durations_len}) must match event_observed length ({events_len})",
49 durations_len = durations.len(),
50 events_len = event_observed.len()
51 )));
52 }
53
54 if durations.is_empty() {
55 return Err(StatsError::InvalidArgument(
56 "Input arrays cannot be empty".to_string(),
57 ));
58 }
59
60 if let Some(conf) = confidence_level {
61 check_probability(conf, "confidence_level")?;
62 }
63
64 let mut time_event_pairs: Vec<(f64, bool)> = durations
66 .iter()
67 .zip(event_observed.iter())
68 .map(|(&t, &e)| (t, e))
69 .collect();
70 time_event_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
71
72 let mut unique_times = Vec::new();
74 let mut at_risk_counts = Vec::new();
75 let mut event_counts = Vec::new();
76 let mut survival_probs = Vec::new();
77
78 let n = time_event_pairs.len();
79 let mut current_survival = 1.0;
80 let mut current_at_risk = n;
81
82 let mut i = 0;
83 while i < time_event_pairs.len() {
84 let current_time = time_event_pairs[i].0;
85 let mut events_at_time = 0;
86 let mut censored_at_time = 0;
87
88 while i < time_event_pairs.len() && time_event_pairs[i].0 == current_time {
90 if time_event_pairs[i].1 {
91 events_at_time += 1;
92 } else {
93 censored_at_time += 1;
94 }
95 i += 1;
96 }
97
98 if events_at_time > 0 {
99 let survival_this_time = 1.0 - (events_at_time as f64) / (current_at_risk as f64);
101 current_survival *= survival_this_time;
102
103 unique_times.push(current_time);
104 at_risk_counts.push(current_at_risk);
105 event_counts.push(events_at_time);
106 survival_probs.push(current_survival);
107 }
108
109 current_at_risk -= events_at_time + censored_at_time;
111 }
112
113 let event_times = Array1::from_vec(unique_times);
114 let survival_function = Array1::from_vec(survival_probs);
115 let at_risk = Array1::from_vec(at_risk_counts);
116 let events = Array1::from_vec(event_counts);
117
118 let confidence_intervals = if let Some(conf_level) = confidence_level {
120 Some(Self::calculate_confidence_intervals(
121 &survival_function,
122 &at_risk,
123 &events,
124 conf_level,
125 )?)
126 } else {
127 None
128 };
129
130 let median_survival_time =
132 Self::calculate_median_survival(&event_times, &survival_function);
133
134 Ok(Self {
135 event_times,
136 survival_function,
137 confidence_intervals,
138 at_risk,
139 events,
140 median_survival_time,
141 })
142 }
143
144 fn calculate_confidence_intervals(
146 survival_function: &Array1<f64>,
147 at_risk: &Array1<usize>,
148 events: &Array1<usize>,
149 confidence_level: f64,
150 ) -> Result<(Array1<f64>, Array1<f64>)> {
151 let _alpha = 1.0 - confidence_level;
152 let z_score = 1.96; let mut lower_bounds = Array1::zeros(survival_function.len());
155 let mut upper_bounds = Array1::zeros(survival_function.len());
156
157 let mut cumulative_variance = 0.0;
159
160 for i in 0..survival_function.len() {
161 let n_i = at_risk[i] as f64;
163 let d_i = events[i] as f64;
164
165 if n_i > d_i && n_i > 0.0 {
166 cumulative_variance += d_i / (n_i * (n_i - d_i));
167 }
168
169 let s_t = survival_function[i];
170 if s_t > 0.0 {
171 let se = s_t * cumulative_variance.sqrt();
172
173 let log_log_s = (-(s_t.ln())).ln();
175 let se_log_log = se / (s_t * s_t.ln().abs());
176
177 let lower_log_log = log_log_s - z_score * se_log_log;
178 let upper_log_log = log_log_s + z_score * se_log_log;
179
180 lower_bounds[i] = (-(-lower_log_log.exp()).exp()).max(0.0);
181 upper_bounds[i] = (-(-upper_log_log.exp()).exp()).min(1.0);
182 } else {
183 lower_bounds[i] = 0.0;
184 upper_bounds[i] = 0.0;
185 }
186 }
187
188 Ok((lower_bounds, upper_bounds))
189 }
190
191 fn calculate_median_survival(
193 event_times: &Array1<f64>,
194 survival_function: &Array1<f64>,
195 ) -> Option<f64> {
196 for i in 0..survival_function.len() {
198 if survival_function[i] <= 0.5 {
199 return Some(event_times[i]);
200 }
201 }
202 None }
204
205 pub fn predict(&self, times: ArrayView1<f64>) -> Result<Array1<f64>> {
207 checkarray_finite(×, "times")?;
208
209 let mut predictions = Array1::zeros(times.len());
210
211 for (i, &t) in times.iter().enumerate() {
212 if t < 0.0 {
213 return Err(StatsError::InvalidArgument(
214 "Times must be non-negative".to_string(),
215 ));
216 }
217
218 let mut survival_prob = 1.0; for j in 0..self.event_times.len() {
222 if self.event_times[j] <= t {
223 survival_prob = self.survival_function[j];
224 } else {
225 break;
226 }
227 }
228
229 predictions[i] = survival_prob;
230 }
231
232 Ok(predictions)
233 }
234}
235
/// Namespace type for the log-rank test; it carries no state and all of its
/// functionality is exposed through associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct LogRankTest;
240
241impl LogRankTest {
242 pub fn compare_two_groups(
253 durations1: ArrayView1<f64>,
254 events1: ArrayView1<bool>,
255 durations2: ArrayView1<f64>,
256 events2: ArrayView1<bool>,
257 ) -> Result<(f64, f64)> {
258 checkarray_finite(&durations1, "durations1")?;
259 checkarray_finite(&durations2, "durations2")?;
260
261 if durations1.len() != events1.len() || durations2.len() != events2.len() {
262 return Err(StatsError::DimensionMismatch(
263 "Durations and events arrays must have same length".to_string(),
264 ));
265 }
266
267 let mut combineddata = Vec::new();
269
270 for i in 0..durations1.len() {
271 combineddata.push((durations1[i], events1[i], 0)); }
273 for i in 0..durations2.len() {
274 combineddata.push((durations2[i], events2[i], 1)); }
276
277 combineddata.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
279
280 let mut observed_group1 = 0.0;
282 let mut expected_group1 = 0.0;
283 let mut variance = 0.0;
284
285 let n1 = durations1.len() as f64;
286 let n2 = durations2.len() as f64;
287 let mut at_risk1 = n1;
288 let mut at_risk2 = n2;
289
290 let mut i = 0;
291 while i < combineddata.len() {
292 let current_time = combineddata[i].0;
293 let mut events_group1 = 0.0;
294 let mut events_group2 = 0.0;
295 let mut censored_group1 = 0.0;
296 let mut censored_group2 = 0.0;
297
298 while i < combineddata.len() && combineddata[i].0 == current_time {
300 let (_, is_event, group) = combineddata[i];
301 if group == 0 {
302 if is_event {
303 events_group1 += 1.0;
304 } else {
305 censored_group1 += 1.0;
306 }
307 } else if is_event {
308 events_group2 += 1.0;
309 } else {
310 censored_group2 += 1.0;
311 }
312 i += 1;
313 }
314
315 let total_events = events_group1 + events_group2;
316 let total_at_risk = at_risk1 + at_risk2;
317
318 if total_events > 0.0 && total_at_risk > 0.0 {
319 let expected_events1 = (at_risk1 / total_at_risk) * total_events;
321
322 let var_contrib =
324 (at_risk1 * at_risk2 * total_events * (total_at_risk - total_events))
325 / (total_at_risk.powi(2) * (total_at_risk - 1.0).max(1.0));
326
327 observed_group1 += events_group1;
328 expected_group1 += expected_events1;
329 variance += var_contrib;
330 }
331
332 at_risk1 -= events_group1 + censored_group1;
334 at_risk2 -= events_group2 + censored_group2;
335 }
336
337 if variance <= 0.0 {
339 return Ok((0.0, 1.0)); }
341
342 let test_statistic = (observed_group1 - expected_group1).powi(2) / variance;
343
344 let p_value = Self::chi_square_survival(test_statistic, 1.0);
346
347 Ok((test_statistic, p_value))
348 }
349
350 fn chi_square_survival(x: f64, df: f64) -> f64 {
352 if x <= 0.0 {
353 return 1.0;
354 }
355
356 let mean = df;
358 let var = 2.0 * df;
359 let std = var.sqrt();
360
361 if df > 30.0 {
363 let z = (x - mean) / std;
364 return 0.5 * (1.0 - Self::erf(z / std::f64::consts::SQRT_2));
365 }
366
367 (-x / mean).exp()
369 }
370
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - (a1 t + a2 t^2 + ... + a5 t^5) e^{-x^2}` with
/// `t = 1 / (1 + p |x|)` and mirrors the result for negative `x`.
fn erf(x: f64) -> f64 {
    // Coefficients ordered from a5 down to a1 for Horner evaluation.
    const COEFFS_HIGH_TO_LOW: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];
    const P: f64 = 0.3275911;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = COEFFS_HIGH_TO_LOW[1..]
        .iter()
        .fold(COEFFS_HIGH_TO_LOW[0], |acc, &coeff| acc * t + coeff);
    let unsigned = 1.0 - poly * t * (-magnitude * magnitude).exp();

    if x >= 0.0 {
        unsigned
    } else {
        -unsigned
    }
}
389}
390
391#[derive(Debug, Clone)]
395pub struct CoxPHModel {
396 pub coefficients: Array1<f64>,
398 pub covariance_matrix: Array2<f64>,
400 pub log_likelihood: f64,
402 pub baseline_cumulative_hazard: Array1<f64>,
404 pub baseline_times: Array1<f64>,
406 pub n_iter: usize,
408}
409
410impl CoxPHModel {
411 pub fn fit(
423 durations: ArrayView1<f64>,
424 events: ArrayView1<bool>,
425 covariates: ArrayView2<f64>,
426 max_iter: Option<usize>,
427 tol: Option<f64>,
428 ) -> Result<Self> {
429 checkarray_finite(&durations, "durations")?;
430 checkarray_finite(&covariates, "covariates")?;
431
432 let (n_samples_, n_features) = covariates.dim();
433 let max_iter = max_iter.unwrap_or(100);
434 let tol = tol.unwrap_or(1e-6);
435
436 if durations.len() != n_samples_ || events.len() != n_samples_ {
437 return Err(StatsError::DimensionMismatch(
438 "All input arrays must have the same number of samples".to_string(),
439 ));
440 }
441
442 let mut beta = Array1::zeros(n_features);
444 let mut prev_log_likelihood = f64::NEG_INFINITY;
445
446 for iteration in 0..max_iter {
447 let (log_likelihood, gradient, hessian) =
449 Self::partial_likelihood_derivatives(&durations, &events, &covariates, &beta)?;
450
451 if (log_likelihood - prev_log_likelihood).abs() < tol {
453 let covariance_matrix = Self::invert_hessian(&hessian)?;
454 let (baseline_times, baseline_cumulative_hazard) =
455 Self::calculatebaseline_hazard(&durations, &events, &covariates, &beta)?;
456
457 return Ok(Self {
458 coefficients: beta,
459 covariance_matrix,
460 log_likelihood,
461 baseline_cumulative_hazard,
462 baseline_times,
463 n_iter: iteration + 1,
464 });
465 }
466
467 let hessian_inv = Self::invert_hessian(&hessian)?;
469 let delta = hessian_inv.dot(&gradient);
470 beta = &beta + δ
471
472 prev_log_likelihood = log_likelihood;
473 }
474
475 Err(StatsError::ConvergenceError(format!(
476 "Cox model failed to converge after {max_iter} iterations"
477 )))
478 }
479
480 fn partial_likelihood_derivatives(
482 durations: &ArrayView1<f64>,
483 events: &ArrayView1<bool>,
484 covariates: &ArrayView2<f64>,
485 beta: &Array1<f64>,
486 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
487 let n_samples_ = durations.len();
488 let n_features = beta.len();
489
490 let mut indices: Vec<usize> = (0..n_samples_).collect();
492 indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
493
494 let mut log_likelihood = 0.0;
495 let mut gradient = Array1::zeros(n_features);
496 let mut hessian = Array2::zeros((n_features, n_features));
497
498 for &i in &indices {
499 if !events[i] {
500 continue; }
502
503 let t_i = durations[i];
504 let x_i = covariates.row(i);
505
506 let mut risk_set = Vec::new();
508 for &j in &indices {
509 if durations[j] >= t_i {
510 risk_set.push(j);
511 }
512 }
513
514 if risk_set.is_empty() {
515 continue;
516 }
517
518 let mut exp_beta_x = Array1::zeros(risk_set.len());
520 for (k, &j) in risk_set.iter().enumerate() {
521 let x_j = covariates.row(j);
522 exp_beta_x[k] = x_j.dot(beta).exp();
523 }
524
525 let sum_exp = exp_beta_x.sum();
526 if sum_exp <= 0.0 {
527 continue;
528 }
529
530 log_likelihood += x_i.dot(beta) - sum_exp.ln();
532
533 let mut weighted_x = Array1::<f64>::zeros(n_features);
535 for (k, &j) in risk_set.iter().enumerate() {
536 let x_j = covariates.row(j);
537 let weight = exp_beta_x[k] / sum_exp;
538 weighted_x = &weighted_x + &(weight * &x_j.to_owned());
539 }
540 gradient = &gradient + &(&x_i.to_owned() - &weighted_x);
541
542 for p in 0..n_features {
544 for q in 0..n_features {
545 let mut weighted_sum = 0.0;
546 for (k, &j) in risk_set.iter().enumerate() {
547 let x_j = covariates.row(j);
548 let weight = exp_beta_x[k] / sum_exp;
549 weighted_sum += weight * x_j[p] * x_j[q];
550 }
551 hessian[[p, q]] -= weighted_sum - (weighted_x[p] * weighted_x[q]);
552 }
553 }
554 }
555
556 Ok((log_likelihood, gradient, hessian))
557 }
558
559 fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
561 let neg_hessian = -hessian;
562 scirs2_linalg::inv(&neg_hessian.view(), None)
563 .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
564 }
565
566 fn calculatebaseline_hazard(
568 durations: &ArrayView1<f64>,
569 events: &ArrayView1<bool>,
570 covariates: &ArrayView2<f64>,
571 beta: &Array1<f64>,
572 ) -> Result<(Array1<f64>, Array1<f64>)> {
573 let n_samples_ = durations.len();
574
575 let mut indices: Vec<usize> = (0..n_samples_).collect();
577 indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
578
579 let mut times = Vec::new();
580 let mut cumulative_hazard = Vec::new();
581 let mut current_cumhaz = 0.0;
582
583 for &i in &indices {
584 if !events[i] {
585 continue;
586 }
587
588 let t_i = durations[i];
589
590 let mut risk_sum = 0.0;
592 for &j in &indices {
593 if durations[j] >= t_i {
594 let x_j = covariates.row(j);
595 risk_sum += x_j.dot(beta).exp();
596 }
597 }
598
599 if risk_sum > 0.0 {
600 current_cumhaz += 1.0 / risk_sum; times.push(t_i);
602 cumulative_hazard.push(current_cumhaz);
603 }
604 }
605
606 Ok((Array1::from_vec(times), Array1::from_vec(cumulative_hazard)))
607 }
608
609 pub fn predict_hazard_ratio(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
611 checkarray_finite(&covariates, "covariates")?;
612
613 if covariates.ncols() != self.coefficients.len() {
614 return Err(StatsError::DimensionMismatch(format!(
615 "covariates has {features} features, expected {expected}",
616 features = covariates.ncols(),
617 expected = self.coefficients.len()
618 )));
619 }
620
621 let mut hazard_ratios = Array1::zeros(covariates.nrows());
622
623 for i in 0..covariates.nrows() {
624 let x_i = covariates.row(i);
625 hazard_ratios[i] = x_i.dot(&self.coefficients).exp();
626 }
627
628 Ok(hazard_ratios)
629 }
630}
631
632#[derive(Debug, Clone)]
636pub struct AFTModel {
637 pub coefficients: Array1<f64>,
639 pub scale: f64,
641 pub distribution: AFTDistribution,
643}
644
/// Distribution family selectable for an accelerated-failure-time model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AFTDistribution {
    /// Weibull family.
    Weibull,
    /// Log-normal family.
    Lognormal,
    /// Exponential family.
    Exponential,
}
655
656impl AFTModel {
657 pub fn fit(
659 durations: ArrayView1<f64>,
660 events: ArrayView1<bool>,
661 covariates: ArrayView2<f64>,
662 distribution: AFTDistribution,
663 ) -> Result<Self> {
664 checkarray_finite(&durations, "durations")?;
665 checkarray_finite(&covariates, "covariates")?;
666
667 let (n_samples_, n_features) = covariates.dim();
668
669 if durations.len() != n_samples_ || events.len() != n_samples_ {
670 return Err(StatsError::DimensionMismatch(
671 "All input arrays must have the same number of samples".to_string(),
672 ));
673 }
674
675 let mut y = Array1::zeros(n_samples_);
679 let mut weights = Array1::zeros(n_samples_);
680
681 for i in 0..n_samples_ {
682 y[i] = durations[i].ln();
683 weights[i] = if events[i] { 1.0 } else { 0.5 }; }
685
686 let mut xtx = Array2::zeros((n_features, n_features));
688 let mut xty = Array1::zeros(n_features);
689
690 for i in 0..n_samples_ {
691 let x_i = covariates.row(i);
692 let w = weights[i];
693
694 for j in 0..n_features {
695 xty[j] += w * x_i[j] * y[i];
696 for k in 0..n_features {
697 xtx[[j, k]] += w * x_i[j] * x_i[k];
698 }
699 }
700 }
701
702 let coefficients = scirs2_linalg::solve(&xtx.view(), &xty.view(), None).map_err(|e| {
703 StatsError::ComputationError(format!("Failed to solve regression: {e}"))
704 })?;
705
706 let mut residual_sum = 0.0;
708 let mut count = 0;
709
710 for i in 0..n_samples_ {
711 if events[i] {
712 let x_i = covariates.row(i);
713 let predicted = x_i.dot(&coefficients);
714 let residual = y[i] - predicted;
715 residual_sum += residual * residual;
716 count += 1;
717 }
718 }
719
720 let scale = if count > 0 {
721 (residual_sum / count as f64).sqrt()
722 } else {
723 1.0
724 };
725
726 Ok(Self {
727 coefficients,
728 scale,
729 distribution,
730 })
731 }
732
733 pub fn predict(&self, covariates: ArrayView2<f64>) -> Result<Array1<f64>> {
735 checkarray_finite(&covariates, "covariates")?;
736
737 if covariates.ncols() != self.coefficients.len() {
738 return Err(StatsError::DimensionMismatch(format!(
739 "covariates has {features} features, expected {expected}",
740 features = covariates.ncols(),
741 expected = self.coefficients.len()
742 )));
743 }
744
745 let mut predictions = Array1::zeros(covariates.nrows());
746
747 for i in 0..covariates.nrows() {
748 let x_i = covariates.row(i);
749 let log_time = x_i.dot(&self.coefficients);
750 predictions[i] = log_time.exp();
751 }
752
753 Ok(predictions)
754 }
755}
756
757#[derive(Debug, Clone)]
761pub struct ExtendedCoxModel {
762 pub coefficients: Array1<f64>,
764 pub covariance_matrix: Array2<f64>,
766 pub log_likelihood: f64,
768 pub stratumbaseline_hazards: Vec<(Array1<f64>, Array1<f64>)>, pub strata: Option<Array1<usize>>,
772 pub n_strata: usize,
774 pub time_varying_indices: Vec<usize>,
776 pub n_iter: usize,
778}
779
780impl ExtendedCoxModel {
781 pub fn fit_stratified(
792 durations: ArrayView1<f64>,
793 events: ArrayView1<bool>,
794 covariates: ArrayView2<f64>,
795 strata: Option<ArrayView1<usize>>,
796 time_varying_indices: Option<Vec<usize>>,
797 max_iter: Option<usize>,
798 tol: Option<f64>,
799 ) -> Result<Self> {
800 checkarray_finite(&durations, "durations")?;
801 checkarray_finite(&covariates, "covariates")?;
802
803 let (n_samples_, n_features) = covariates.dim();
804 let max_iter = max_iter.unwrap_or(100);
805 let tol = tol.unwrap_or(1e-6);
806
807 if durations.len() != n_samples_ || events.len() != n_samples_ {
808 return Err(StatsError::DimensionMismatch(
809 "All input arrays must have the same number of samples".to_string(),
810 ));
811 }
812
813 let (strata_array, n_strata) = if let Some(strata_input) = strata {
815 if strata_input.len() != n_samples_ {
816 return Err(StatsError::DimensionMismatch(
817 "Strata length must match number of samples".to_string(),
818 ));
819 }
820 let max_stratum = strata_input.iter().cloned().max().unwrap_or(0);
821 (Some(strata_input.to_owned()), max_stratum + 1)
822 } else {
823 (None, 1)
824 };
825
826 let time_varying_indices = time_varying_indices.unwrap_or_default();
827
828 let mut beta = Array1::zeros(n_features);
830 let mut prev_log_likelihood = f64::NEG_INFINITY;
831
832 for iteration in 0..max_iter {
833 let (log_likelihood, gradient, hessian) =
835 Self::stratified_partial_likelihood_derivatives(
836 &durations,
837 &events,
838 &covariates,
839 &beta,
840 &strata_array,
841 n_strata,
842 )?;
843
844 if (log_likelihood - prev_log_likelihood).abs() < tol {
846 let covariance_matrix = Self::invert_hessian(&hessian)?;
847 let baseline_hazards = Self::calculate_stratifiedbaseline_hazards(
848 &durations,
849 &events,
850 &covariates,
851 &beta,
852 &strata_array,
853 n_strata,
854 )?;
855
856 return Ok(Self {
857 coefficients: beta,
858 covariance_matrix,
859 log_likelihood,
860 stratumbaseline_hazards: baseline_hazards,
861 strata: strata_array,
862 n_strata,
863 time_varying_indices,
864 n_iter: iteration + 1,
865 });
866 }
867
868 let hessian_inv = Self::invert_hessian(&hessian)?;
870 let delta = hessian_inv.dot(&gradient);
871 beta = &beta + δ
872
873 prev_log_likelihood = log_likelihood;
874 }
875
876 Err(StatsError::ConvergenceError(format!(
877 "Extended Cox model failed to converge after {max_iter} iterations"
878 )))
879 }
880
881 fn stratified_partial_likelihood_derivatives(
883 durations: &ArrayView1<f64>,
884 events: &ArrayView1<bool>,
885 covariates: &ArrayView2<f64>,
886 beta: &Array1<f64>,
887 strata: &Option<Array1<usize>>,
888 n_strata: usize,
889 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
890 let n_samples_ = durations.len();
891 let n_features = beta.len();
892
893 let mut total_log_likelihood = 0.0;
894 let mut total_gradient = Array1::zeros(n_features);
895 let mut total_hessian = Array2::zeros((n_features, n_features));
896
897 for stratum in 0..n_strata {
899 let stratum_indices: Vec<usize> = if let Some(ref strata_array) = strata {
901 (0..n_samples_)
902 .filter(|&i| strata_array[i] == stratum)
903 .collect()
904 } else {
905 (0..n_samples_).collect()
906 };
907
908 if stratum_indices.is_empty() {
909 continue;
910 }
911
912 let mut sorted_indices = stratum_indices.clone();
914 sorted_indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
915
916 let (stratum_ll, stratum_grad, stratum_hess) = Self::stratum_partial_likelihood(
918 durations,
919 events,
920 covariates,
921 beta,
922 &sorted_indices,
923 )?;
924
925 total_log_likelihood += stratum_ll;
926 total_gradient = &total_gradient + &stratum_grad;
927 total_hessian = &total_hessian + &stratum_hess;
928 }
929
930 Ok((total_log_likelihood, total_gradient, total_hessian))
931 }
932
933 fn stratum_partial_likelihood(
935 durations: &ArrayView1<f64>,
936 events: &ArrayView1<bool>,
937 covariates: &ArrayView2<f64>,
938 beta: &Array1<f64>,
939 sorted_indices: &[usize],
940 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
941 let n_features = beta.len();
942
943 let mut log_likelihood = 0.0;
944 let mut gradient = Array1::zeros(n_features);
945 let mut hessian = Array2::zeros((n_features, n_features));
946
947 for &i in sorted_indices {
948 if !events[i] {
949 continue; }
951
952 let t_i = durations[i];
953 let x_i = covariates.row(i);
954
955 let mut risk_set = Vec::new();
957 for &j in sorted_indices {
958 if durations[j] >= t_i {
959 risk_set.push(j);
960 }
961 }
962
963 if risk_set.is_empty() {
964 continue;
965 }
966
967 let mut exp_beta_x = Array1::zeros(risk_set.len());
969 for (k, &j) in risk_set.iter().enumerate() {
970 let x_j = covariates.row(j);
971 exp_beta_x[k] = x_j.dot(beta).exp();
972 }
973
974 let sum_exp = exp_beta_x.sum();
975 if sum_exp <= 0.0 {
976 continue;
977 }
978
979 log_likelihood += x_i.dot(beta) - sum_exp.ln();
981
982 let mut weighted_x = Array1::<f64>::zeros(n_features);
984 for (k, &j) in risk_set.iter().enumerate() {
985 let x_j = covariates.row(j);
986 let weight = exp_beta_x[k] / sum_exp;
987 weighted_x = &weighted_x + &(weight * &x_j.to_owned());
988 }
989 gradient = &gradient + &(&x_i.to_owned() - &weighted_x);
990
991 for p in 0..n_features {
993 for q in 0..n_features {
994 let mut weighted_sum = 0.0;
995 for (k, &j) in risk_set.iter().enumerate() {
996 let x_j = covariates.row(j);
997 let weight = exp_beta_x[k] / sum_exp;
998 weighted_sum += weight * x_j[p] * x_j[q];
999 }
1000 hessian[[p, q]] -= weighted_sum - (weighted_x[p] * weighted_x[q]);
1001 }
1002 }
1003 }
1004
1005 Ok((log_likelihood, gradient, hessian))
1006 }
1007
1008 fn calculate_stratifiedbaseline_hazards(
1010 durations: &ArrayView1<f64>,
1011 events: &ArrayView1<bool>,
1012 covariates: &ArrayView2<f64>,
1013 beta: &Array1<f64>,
1014 strata: &Option<Array1<usize>>,
1015 n_strata: usize,
1016 ) -> Result<Vec<(Array1<f64>, Array1<f64>)>> {
1017 let n_samples_ = durations.len();
1018 let mut baseline_hazards = Vec::new();
1019
1020 for stratum in 0..n_strata {
1021 let stratum_indices: Vec<usize> = if let Some(ref strata_array) = strata {
1023 (0..n_samples_)
1024 .filter(|&i| strata_array[i] == stratum)
1025 .collect()
1026 } else {
1027 (0..n_samples_).collect()
1028 };
1029
1030 if stratum_indices.is_empty() {
1031 baseline_hazards.push((Array1::zeros(0), Array1::zeros(0)));
1032 continue;
1033 }
1034
1035 let mut sorted_indices = stratum_indices.clone();
1037 sorted_indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
1038
1039 let mut times = Vec::new();
1040 let mut cumulative_hazard = Vec::new();
1041 let mut current_cumhaz = 0.0;
1042
1043 for &i in &sorted_indices {
1044 if !events[i] {
1045 continue;
1046 }
1047
1048 let t_i = durations[i];
1049
1050 let mut risk_sum = 0.0;
1052 for &j in &sorted_indices {
1053 if durations[j] >= t_i {
1054 let x_j = covariates.row(j);
1055 risk_sum += x_j.dot(beta).exp();
1056 }
1057 }
1058
1059 if risk_sum > 0.0 {
1060 current_cumhaz += 1.0 / risk_sum; times.push(t_i);
1062 cumulative_hazard.push(current_cumhaz);
1063 }
1064 }
1065
1066 baseline_hazards.push((Array1::from_vec(times), Array1::from_vec(cumulative_hazard)));
1067 }
1068
1069 Ok(baseline_hazards)
1070 }
1071
1072 fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
1074 let neg_hessian = -hessian;
1075 scirs2_linalg::inv(&neg_hessian.view(), None)
1076 .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
1077 }
1078
1079 pub fn predict_hazard_ratio_stratified(
1081 &self,
1082 covariates: ArrayView2<f64>,
1083 strata: Option<ArrayView1<usize>>,
1084 ) -> Result<Array1<f64>> {
1085 checkarray_finite(&covariates, "covariates")?;
1086
1087 if covariates.ncols() != self.coefficients.len() {
1088 return Err(StatsError::DimensionMismatch(format!(
1089 "covariates has {features} features, expected {expected}",
1090 features = covariates.ncols(),
1091 expected = self.coefficients.len()
1092 )));
1093 }
1094
1095 if let Some(ref strata_input) = strata {
1096 if strata_input.len() != covariates.nrows() {
1097 return Err(StatsError::DimensionMismatch(
1098 "Strata length must match number of predictions".to_string(),
1099 ));
1100 }
1101 }
1102
1103 let mut hazard_ratios = Array1::zeros(covariates.nrows());
1104
1105 for i in 0..covariates.nrows() {
1106 let x_i = covariates.row(i);
1107 hazard_ratios[i] = x_i.dot(&self.coefficients).exp();
1108 }
1109
1110 Ok(hazard_ratios)
1111 }
1112
1113 pub fn coefficient_confidence_intervals(&self, confidencelevel: f64) -> Result<Array2<f64>> {
1115 check_probability(confidencelevel, "confidence_level")?;
1116
1117 let n_features = self.coefficients.len();
1118 let mut intervals = Array2::zeros((n_features, 2));
1119 let _alpha = (1.0 - confidencelevel) / 2.0;
1120 let z_critical = 1.96; for i in 0..n_features {
1123 let coeff = self.coefficients[i];
1124 let se = self.covariance_matrix[[i, i]].sqrt();
1125
1126 intervals[[i, 0]] = coeff - z_critical * se; intervals[[i, 1]] = coeff + z_critical * se; }
1129
1130 Ok(intervals)
1131 }
1132}
1133
1134#[derive(Debug, Clone)]
1139pub struct CompetingRisksModel {
1140 pub coefficients: Vec<Array1<f64>>,
1142 pub covariance_matrices: Vec<Array2<f64>>,
1144 pub baseline_cifs: Vec<(Array1<f64>, Array1<f64>)>, pub n_risks: usize,
1148 pub log_likelihood: f64,
1150}
1151
1152impl CompetingRisksModel {
1153 pub fn fit(
1162 durations: ArrayView1<f64>,
1163 events: ArrayView1<usize>,
1164 covariates: ArrayView2<f64>,
1165 n_risks: usize,
1166 target_risk: usize,
1167 max_iter: Option<usize>,
1168 tol: Option<f64>,
1169 ) -> Result<Self> {
1170 checkarray_finite(&durations, "durations")?;
1171 checkarray_finite(&covariates, "covariates")?;
1172 check_positive(n_risks, "n_risks")?;
1173
1174 let (n_samples_, n_features) = covariates.dim();
1175 let max_iter = max_iter.unwrap_or(100);
1176 let tol = tol.unwrap_or(1e-6);
1177
1178 if durations.len() != n_samples_ || events.len() != n_samples_ {
1179 return Err(StatsError::DimensionMismatch(
1180 "All input arrays must have the same number of samples".to_string(),
1181 ));
1182 }
1183
1184 if target_risk == 0 || target_risk > n_risks {
1185 return Err(StatsError::InvalidArgument(
1186 "target_risk must be between 1 and n_risks".to_string(),
1187 ));
1188 }
1189
1190 let mut coefficients = vec![Array1::zeros(n_features); n_risks];
1193 let mut covariance_matrices = vec![Array2::zeros((n_features, n_features)); n_risks];
1194 let mut baseline_cifs = vec![(Array1::zeros(0), Array1::zeros(0)); n_risks];
1195
1196 let (modified_durations, modified_events, modified_weights) =
1198 Self::prepare_fine_gray_data(&durations, &events, target_risk)?;
1199
1200 let mut beta = Array1::zeros(n_features);
1202 let mut prev_log_likelihood = f64::NEG_INFINITY;
1203
1204 for _iteration in 0..max_iter {
1205 let (log_likelihood, gradient, hessian) = Self::subdistribution_partial_likelihood(
1207 &modified_durations,
1208 &modified_events,
1209 &covariates,
1210 &modified_weights,
1211 &beta,
1212 )?;
1213
1214 if (log_likelihood - prev_log_likelihood).abs() < tol {
1216 coefficients[target_risk - 1] = beta.clone();
1217 covariance_matrices[target_risk - 1] = Self::invert_hessian(&hessian)?;
1218
1219 let (times, cif) = Self::calculatebaseline_cif(
1221 &modified_durations,
1222 &modified_events,
1223 &covariates,
1224 &modified_weights,
1225 &beta,
1226 )?;
1227 baseline_cifs[target_risk - 1] = (times, cif);
1228
1229 return Ok(Self {
1230 coefficients,
1231 covariance_matrices,
1232 baseline_cifs,
1233 n_risks,
1234 log_likelihood,
1235 });
1236 }
1237
1238 let hessian_inv = Self::invert_hessian(&hessian)?;
1240 let delta = hessian_inv.dot(&gradient);
1241 beta = &beta + δ
1242
1243 prev_log_likelihood = log_likelihood;
1244 }
1245
1246 Err(StatsError::ConvergenceError(format!(
1247 "Competing _risks model failed to converge after {max_iter} iterations"
1248 )))
1249 }
1250
1251 fn prepare_fine_gray_data(
1253 durations: &ArrayView1<f64>,
1254 events: &ArrayView1<usize>,
1255 target_risk: usize,
1256 ) -> Result<(Array1<f64>, Array1<bool>, Array1<f64>)> {
1257 let n_samples_ = durations.len();
1258 let modified_durations = durations.to_owned();
1259 let mut modified_events = Array1::from_elem(n_samples_, false);
1260 let mut weights = Array1::ones(n_samples_);
1261
1262 let censoring_km = Self::kaplan_meier_censoring(durations, events)?;
1264
1265 for i in 0..n_samples_ {
1266 if events[i] == target_risk {
1267 modified_events[i] = true;
1269 weights[i] = 1.0;
1270 } else if events[i] == 0 {
1271 modified_events[i] = false;
1273 weights[i] = 1.0;
1274 } else {
1275 modified_events[i] = false;
1277
1278 let km_prob = Self::interpolate_km_probability(
1280 &censoring_km.0,
1281 &censoring_km.1,
1282 durations[i],
1283 );
1284 weights[i] = if km_prob > 0.0 { 1.0 / km_prob } else { 0.0 };
1285 }
1286 }
1287
1288 Ok((modified_durations, modified_events, weights))
1289 }
1290
1291 fn kaplan_meier_censoring(
1293 durations: &ArrayView1<f64>,
1294 events: &ArrayView1<usize>,
1295 ) -> Result<(Array1<f64>, Array1<f64>)> {
1296 let censoring_events: Array1<bool> = events.mapv(|e| e == 0);
1298
1299 let mut time_event_pairs: Vec<(f64, bool)> = durations
1301 .iter()
1302 .zip(censoring_events.iter())
1303 .map(|(&t, &e)| (t, e))
1304 .collect();
1305 time_event_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
1306
1307 let mut times = Vec::new();
1308 let mut survival_probs = Vec::new();
1309 let mut current_survival = 1.0;
1310 let mut current_at_risk = time_event_pairs.len();
1311
1312 let mut i = 0;
1313 while i < time_event_pairs.len() {
1314 let current_time = time_event_pairs[i].0;
1315 let mut events_at_time = 0;
1316 let mut total_at_time = 0;
1317
1318 while i < time_event_pairs.len() && time_event_pairs[i].0 == current_time {
1319 if time_event_pairs[i].1 {
1320 events_at_time += 1;
1321 }
1322 total_at_time += 1;
1323 i += 1;
1324 }
1325
1326 if events_at_time > 0 {
1327 let survival_this_time = 1.0 - (events_at_time as f64) / (current_at_risk as f64);
1328 current_survival *= survival_this_time;
1329
1330 times.push(current_time);
1331 survival_probs.push(current_survival);
1332 }
1333
1334 current_at_risk -= total_at_time;
1335 }
1336
1337 Ok((Array1::from_vec(times), Array1::from_vec(survival_probs)))
1338 }
1339
1340 fn interpolate_km_probability(times: &Array1<f64>, probs: &Array1<f64>, t: f64) -> f64 {
1342 if times.is_empty() {
1343 return 1.0;
1344 }
1345
1346 if t <= times[0] {
1347 return 1.0;
1348 }
1349
1350 for i in 0..times.len() {
1351 if times[i] >= t {
1352 return probs[i];
1353 }
1354 }
1355
1356 probs[probs.len() - 1]
1358 }
1359
1360 fn subdistribution_partial_likelihood(
1362 durations: &Array1<f64>,
1363 events: &Array1<bool>,
1364 covariates: &ArrayView2<f64>,
1365 weights: &Array1<f64>,
1366 beta: &Array1<f64>,
1367 ) -> Result<(f64, Array1<f64>, Array2<f64>)> {
1368 let n_samples_ = durations.len();
1369 let n_features = beta.len();
1370
1371 let mut indices: Vec<usize> = (0..n_samples_).collect();
1373 indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
1374
1375 let mut log_likelihood = 0.0;
1376 let mut gradient = Array1::zeros(n_features);
1377 let mut hessian = Array2::zeros((n_features, n_features));
1378
1379 for &i in &indices {
1380 if !events[i] {
1381 continue; }
1383
1384 let t_i = durations[i];
1385 let x_i = covariates.row(i);
1386 let w_i = weights[i];
1387
1388 let mut weighted_exp_sum = 0.0;
1390 let mut weighted_x_sum = Array1::zeros(n_features);
1391 let mut weighted_xx_sum = Array2::zeros((n_features, n_features));
1392
1393 for &j in &indices {
1394 if durations[j] >= t_i {
1395 let x_j = covariates.row(j);
1396 let w_j = weights[j];
1397 let exp_beta_x = x_j.dot(beta).exp();
1398 let weighted_exp = w_j * exp_beta_x;
1399
1400 weighted_exp_sum += weighted_exp;
1401 weighted_x_sum = &weighted_x_sum + &(weighted_exp * &x_j.to_owned());
1402
1403 for p in 0..n_features {
1404 for q in 0..n_features {
1405 weighted_xx_sum[[p, q]] += weighted_exp * x_j[p] * x_j[q];
1406 }
1407 }
1408 }
1409 }
1410
1411 if weighted_exp_sum <= 0.0 {
1412 continue;
1413 }
1414
1415 let weighted_mean_x = &weighted_x_sum / weighted_exp_sum;
1417
1418 log_likelihood += w_i * (x_i.dot(beta) - weighted_exp_sum.ln());
1419 gradient = &gradient + &(w_i * (&x_i.to_owned() - &weighted_mean_x));
1420
1421 let weighted_mean_xx = &weighted_xx_sum / weighted_exp_sum;
1423 let outer_product = outer_product_array(&weighted_mean_x);
1424 hessian = &hessian - &(w_i * (&weighted_mean_xx - &outer_product));
1425 }
1426
1427 Ok((log_likelihood, gradient, hessian))
1428 }
1429
1430 fn calculatebaseline_cif(
1432 durations: &Array1<f64>,
1433 events: &Array1<bool>,
1434 covariates: &ArrayView2<f64>,
1435 weights: &Array1<f64>,
1436 beta: &Array1<f64>,
1437 ) -> Result<(Array1<f64>, Array1<f64>)> {
1438 let n_samples_ = durations.len();
1439
1440 let mut indices: Vec<usize> = (0..n_samples_).collect();
1442 indices.sort_by(|&i, &j| durations[i].partial_cmp(&durations[j]).unwrap());
1443
1444 let mut times = Vec::new();
1445 let mut cumulative_incidence = Vec::new();
1446 let mut current_cif = 0.0;
1447
1448 for &i in &indices {
1449 if !events[i] {
1450 continue;
1451 }
1452
1453 let t_i = durations[i];
1454 let w_i = weights[i];
1455
1456 let mut weighted_risk_sum = 0.0;
1458 for &j in &indices {
1459 if durations[j] >= t_i {
1460 let x_j = covariates.row(j);
1461 let w_j = weights[j];
1462 weighted_risk_sum += w_j * x_j.dot(beta).exp();
1463 }
1464 }
1465
1466 if weighted_risk_sum > 0.0 {
1467 current_cif += w_i / weighted_risk_sum;
1468 times.push(t_i);
1469 cumulative_incidence.push(current_cif);
1470 }
1471 }
1472
1473 Ok((
1474 Array1::from_vec(times),
1475 Array1::from_vec(cumulative_incidence),
1476 ))
1477 }
1478
1479 fn invert_hessian(hessian: &Array2<f64>) -> Result<Array2<f64>> {
1481 let neg_hessian = -hessian;
1482 scirs2_linalg::inv(&neg_hessian.view(), None)
1483 .map_err(|e| StatsError::ComputationError(format!("Failed to invert Hessian: {e}")))
1484 }
1485
1486 pub fn predict_cumulative_incidence(
1488 &self,
1489 covariates: ArrayView2<f64>,
1490 target_risk: usize,
1491 times: ArrayView1<f64>,
1492 ) -> Result<Array2<f64>> {
1493 checkarray_finite(&covariates, "covariates")?;
1494 checkarray_finite(×, "times")?;
1495
1496 if target_risk == 0 || target_risk > self.n_risks {
1497 return Err(StatsError::InvalidArgument(
1498 "target_risk must be between 1 and n_risks".to_string(),
1499 ));
1500 }
1501
1502 let risk_idx = target_risk - 1;
1503 let n_samples_ = covariates.nrows();
1504 let n_times = times.len();
1505 let mut predictions = Array2::zeros((n_samples_, n_times));
1506
1507 let beta = &self.coefficients[risk_idx];
1508 let (baseline_times, baseline_cif) = &self.baseline_cifs[risk_idx];
1509
1510 for i in 0..n_samples_ {
1511 let x_i = covariates.row(i);
1512 let hazard_ratio = x_i.dot(beta).exp();
1513
1514 for (j, &t) in times.iter().enumerate() {
1515 let baseline_value = Self::interpolatebaseline_cif(baseline_times, baseline_cif, t);
1517
1518 predictions[[i, j]] = 1.0 - (1.0 - baseline_value).powf(hazard_ratio);
1521 }
1522 }
1523
1524 Ok(predictions)
1525 }
1526
1527 fn interpolatebaseline_cif(times: &Array1<f64>, cif: &Array1<f64>, t: f64) -> f64 {
1529 if times.is_empty() {
1530 return 0.0;
1531 }
1532
1533 if t <= times[0] {
1534 return 0.0;
1535 }
1536
1537 for i in 0..times.len() {
1538 if times[i] >= t {
1539 return cif[i];
1540 }
1541 }
1542
1543 cif[cif.len() - 1]
1545 }
1546}
1547
1548#[allow(dead_code)]
1550fn outer_product_array(v: &Array1<f64>) -> Array2<f64> {
1551 let n = v.len();
1552 let mut result = Array2::zeros((n, n));
1553 for i in 0..n {
1554 for j in 0..n {
1555 result[[i, j]] = v[i] * v[j];
1556 }
1557 }
1558 result
1559}