1#![allow(
2 unused_variables,
3 unused_imports,
4 clippy::too_many_arguments,
5 clippy::needless_range_loop
6)]
7
8use pyo3::prelude::*;
9use rayon::prelude::*;
10
/// Dependence structure assumed between the competing causes.
///
/// Within this file the variant only selects which dependence summaries
/// `joint_competing_risks` attaches to its result.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum CorrelationType {
    /// No correlation matrix and no frailty variance are reported.
    Independent,
    /// An identity correlation matrix and the configured frailty variance are reported.
    SharedFrailty,
    /// An identity correlation matrix is reported; no frailty variance is reported.
    CopulaBased,
}
18
19#[pymethods]
20impl CorrelationType {
21 #[new]
22 fn new(name: &str) -> PyResult<Self> {
23 match name.to_lowercase().as_str() {
24 "independent" => Ok(CorrelationType::Independent),
25 "shared_frailty" | "sharedfrailty" | "frailty" => Ok(CorrelationType::SharedFrailty),
26 "copula_based" | "copulabased" | "copula" => Ok(CorrelationType::CopulaBased),
27 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28 "Unknown correlation type. Use 'independent', 'shared_frailty', or 'copula_based'",
29 )),
30 }
31 }
32}
33
/// Settings for `joint_competing_risks`.
///
/// Every field is readable and writable from Python. The range checks live in the
/// constructor only, so values assigned afterwards are presumably not re-validated
/// (NOTE(review): confirm against the PyO3 setter semantics).
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointCompetingRisksConfig {
    /// Number of competing causes; causes are coded `1..=num_causes` in the `cause` input.
    #[pyo3(get, set)]
    pub num_causes: usize,
    /// Selects which dependence summaries are attached to the result.
    #[pyo3(get, set)]
    pub correlation_structure: CorrelationType,
    /// Frailty variance echoed back in the result when the structure is `SharedFrailty`.
    #[pyo3(get, set)]
    pub frailty_variance: f64,
    /// Upper bound on the number of iterations of each per-cause fit.
    #[pyo3(get, set)]
    pub max_iter: usize,
    /// A fit is declared converged once the largest absolute coefficient change is below this.
    #[pyo3(get, set)]
    pub tol: f64,
    /// Stored on the config; not read by any code in this file.
    #[pyo3(get, set)]
    pub estimate_correlation: bool,
}
50
51#[pymethods]
52impl JointCompetingRisksConfig {
53 #[new]
54 #[pyo3(signature = (
55 num_causes=2,
56 correlation_structure=CorrelationType::Independent,
57 frailty_variance=1.0,
58 max_iter=100,
59 tol=1e-6,
60 estimate_correlation=true
61 ))]
62 pub fn new(
63 num_causes: usize,
64 correlation_structure: CorrelationType,
65 frailty_variance: f64,
66 max_iter: usize,
67 tol: f64,
68 estimate_correlation: bool,
69 ) -> PyResult<Self> {
70 if num_causes < 2 {
71 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
72 "num_causes must be at least 2",
73 ));
74 }
75 if frailty_variance <= 0.0 {
76 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
77 "frailty_variance must be positive",
78 ));
79 }
80 if max_iter == 0 {
81 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82 "max_iter must be positive",
83 ));
84 }
85
86 Ok(JointCompetingRisksConfig {
87 num_causes,
88 correlation_structure,
89 frailty_variance,
90 max_iter,
91 tol,
92 estimate_correlation,
93 })
94 }
95}
96
/// Fitted proportional-hazards model for a single cause.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CauseResult {
    /// Cause code this fit refers to (codes start at 1).
    #[pyo3(get)]
    pub cause: usize,
    /// Estimated regression coefficients, one per covariate.
    #[pyo3(get)]
    pub coefficients: Vec<f64>,
    /// Standard errors of the coefficients.
    #[pyo3(get)]
    pub std_errors: Vec<f64>,
    /// `exp(coefficient)` for every covariate.
    #[pyo3(get)]
    pub hazard_ratios: Vec<f64>,
    /// Times at which the baseline hazard jumps (event times of this cause), ascending.
    #[pyo3(get)]
    pub baseline_hazard_times: Vec<f64>,
    /// Baseline hazard increment at each entry of `baseline_hazard_times`.
    #[pyo3(get)]
    pub baseline_hazard: Vec<f64>,
    /// Running sum of `baseline_hazard`.
    #[pyo3(get)]
    pub cumulative_baseline_hazard: Vec<f64>,
}
115
/// Output of `joint_competing_risks`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointCompetingRisksResult {
    /// One cause-specific fit per cause, ordered by cause code.
    #[pyo3(get)]
    pub cause_specific_results: Vec<CauseResult>,
    /// Filled with a copy of `cause_specific_results` by `joint_competing_risks`.
    #[pyo3(get)]
    pub subdistribution_results: Vec<CauseResult>,
    /// `None` for independent causes, otherwise an identity matrix of size `num_causes`.
    #[pyo3(get)]
    pub correlation_matrix: Option<Vec<Vec<f64>>>,
    /// The configured frailty variance for the shared-frailty structure, otherwise `None`.
    #[pyo3(get)]
    pub frailty_variance: Option<f64>,
    /// Sum of the per-cause log partial likelihoods.
    #[pyo3(get)]
    pub log_likelihood: f64,
    /// `-2 * log_likelihood + 2 * k`, with `k = n_vars * num_causes`.
    #[pyo3(get)]
    pub aic: f64,
    /// `-2 * log_likelihood + k * ln(n_obs)`, with `k = n_vars * num_causes`.
    #[pyo3(get)]
    pub bic: f64,
    /// Number of observations carrying each cause code `1..=num_causes`.
    #[pyo3(get)]
    pub n_events_by_cause: Vec<usize>,
    /// Number of observations used for fitting.
    #[pyo3(get)]
    pub n_obs: usize,
    /// Largest iteration count over the per-cause fits.
    #[pyo3(get)]
    pub n_iter: usize,
    /// `true` only if every per-cause fit converged.
    #[pyo3(get)]
    pub converged: bool,
}
142
143#[pymethods]
144impl JointCompetingRisksResult {
145 fn __repr__(&self) -> String {
146 format!(
147 "JointCompetingRisksResult(n_causes={}, n_obs={}, converged={})",
148 self.cause_specific_results.len(),
149 self.n_obs,
150 self.converged
151 )
152 }
153
154 fn predict_cif(&self, x: Vec<f64>, n_obs: usize, cause_idx: usize) -> PyResult<Vec<Vec<f64>>> {
155 if cause_idx >= self.cause_specific_results.len() {
156 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
157 "cause_idx out of range",
158 ));
159 }
160
161 let cs = &self.cause_specific_results[cause_idx];
162 let n_vars = cs.coefficients.len();
163 let n_times = cs.baseline_hazard_times.len();
164
165 let all_cum_hazards: Vec<Vec<Vec<f64>>> = self
166 .cause_specific_results
167 .iter()
168 .map(|cr| {
169 (0..n_obs)
170 .map(|i| {
171 let mut lp = 0.0;
172 for j in 0..cr.coefficients.len().min(n_vars) {
173 lp += x[i * n_vars + j] * cr.coefficients[j];
174 }
175 let exp_lp = lp.exp();
176 cr.cumulative_baseline_hazard
177 .iter()
178 .map(|&h0| h0 * exp_lp)
179 .collect()
180 })
181 .collect()
182 })
183 .collect();
184
185 let cif: Vec<Vec<f64>> = (0..n_obs)
186 .into_par_iter()
187 .map(|i| {
188 let cs = &self.cause_specific_results[cause_idx];
189 let mut cif_vec = Vec::with_capacity(n_times);
190 let mut cum_inc = 0.0;
191 let mut prev_surv = 1.0;
192
193 for t in 0..n_times {
194 let mut total_hazard = 0.0;
195 for k in 0..self.cause_specific_results.len() {
196 if t < all_cum_hazards[k][i].len() {
197 let h_t = if t == 0 {
198 all_cum_hazards[k][i][t]
199 } else {
200 all_cum_hazards[k][i][t] - all_cum_hazards[k][i][t - 1]
201 };
202 total_hazard += h_t.max(0.0);
203 }
204 }
205
206 let h_cause_t = if t == 0 {
207 all_cum_hazards[cause_idx][i][t]
208 } else {
209 all_cum_hazards[cause_idx][i][t] - all_cum_hazards[cause_idx][i][t - 1]
210 };
211
212 cum_inc += prev_surv * h_cause_t.max(0.0);
213 prev_surv *= (-total_hazard).exp();
214 cif_vec.push(cum_inc.min(1.0));
215 }
216
217 cif_vec
218 })
219 .collect();
220
221 Ok(cif)
222 }
223
224 fn predict_overall_survival(&self, x: Vec<f64>, n_obs: usize) -> Vec<Vec<f64>> {
225 let n_times = self.cause_specific_results[0].baseline_hazard_times.len();
226 let n_vars = self.cause_specific_results[0].coefficients.len();
227
228 (0..n_obs)
229 .into_par_iter()
230 .map(|i| {
231 let mut surv_vec = Vec::with_capacity(n_times);
232 let mut cum_surv = 1.0;
233
234 for t in 0..n_times {
235 let mut total_hazard = 0.0;
236
237 for cs in &self.cause_specific_results {
238 let mut lp = 0.0;
239 for j in 0..cs.coefficients.len().min(n_vars) {
240 lp += x[i * n_vars + j] * cs.coefficients[j];
241 }
242 let exp_lp = lp.exp();
243
244 let h_t = if t == 0 {
245 cs.cumulative_baseline_hazard
246 .first()
247 .copied()
248 .unwrap_or(0.0)
249 } else {
250 cs.cumulative_baseline_hazard.get(t).copied().unwrap_or(0.0)
251 - cs.cumulative_baseline_hazard
252 .get(t - 1)
253 .copied()
254 .unwrap_or(0.0)
255 };
256
257 total_hazard += h_t * exp_lp;
258 }
259
260 cum_surv *= (-total_hazard).exp();
261 surv_vec.push(cum_surv.clamp(0.0, 1.0));
262 }
263
264 surv_vec
265 })
266 .collect()
267 }
268}
269
270fn fit_cause_specific_cox(
271 x: &[f64],
272 n: usize,
273 p: usize,
274 time: &[f64],
275 cause: &[i32],
276 weights: &[f64],
277 cause_of_interest: i32,
278 max_iter: usize,
279 tol: f64,
280) -> (Vec<f64>, Vec<f64>, f64, bool, usize) {
281 let mut beta = vec![0.0; p];
282 let mut converged = false;
283 let mut n_iter = 0;
284 let mut loglik = 0.0;
285
286 for iter in 0..max_iter {
287 n_iter = iter + 1;
288
289 let (gradient, hessian, ll) =
290 compute_gradient_hessian(x, n, p, time, cause, weights, &beta, cause_of_interest);
291 loglik = ll;
292
293 let delta = match solve_system(&hessian, &gradient) {
294 Some(d) => d,
295 None => break,
296 };
297
298 let max_change: f64 = delta.iter().map(|d| d.abs()).fold(0.0, f64::max);
299
300 for j in 0..p {
301 beta[j] += delta[j];
302 }
303
304 if max_change < tol {
305 converged = true;
306 break;
307 }
308 }
309
310 let (_, final_hessian, _) =
311 compute_gradient_hessian(x, n, p, time, cause, weights, &beta, cause_of_interest);
312
313 let var_cov = invert_matrix(&final_hessian).unwrap_or_else(|| vec![vec![0.0; p]; p]);
314 let std_errors: Vec<f64> = (0..p)
315 .map(|j| var_cov[j][j].abs().sqrt().max(1e-10))
316 .collect();
317
318 (beta, std_errors, loglik, converged, n_iter)
319}
320
/// Computes the score vector, Hessian and value of the weighted log partial likelihood
/// of a cause-specific Cox model at `beta`.
///
/// `x` is a row-major `n x p` covariate matrix. Observations with
/// `cause == cause_of_interest` are events; all others only contribute to risk sets.
/// Linear predictors are clamped to `[-700, 700]` before exponentiation.
///
/// Observations are visited from the latest to the earliest time while the risk-set
/// sums are accumulated. Observations sharing the same time are added to the risk set
/// together *before* any event of that time is scored (Breslow handling of ties), so an
/// event's risk set always contains every observation with `time >= event time`.
///
/// Returns `(gradient, hessian, log_likelihood)`; the Hessian is the second derivative
/// of the log-likelihood and therefore negative semi-definite.
fn compute_gradient_hessian(
    x: &[f64],
    n: usize,
    p: usize,
    time: &[f64],
    cause: &[i32],
    weights: &[f64],
    beta: &[f64],
    cause_of_interest: i32,
) -> (Vec<f64>, Vec<Vec<f64>>, f64) {
    let eta: Vec<f64> = (0..n)
        .map(|i| {
            let mut e = 0.0_f64;
            for j in 0..p {
                e += x[i * p + j] * beta[j];
            }
            e.clamp(-700.0, 700.0)
        })
        .collect();

    let exp_eta: Vec<f64> = eta.iter().map(|&e| e.exp()).collect();

    // Latest times first, so the running sums always describe the current risk set.
    let mut sorted_indices: Vec<usize> = (0..n).collect();
    sorted_indices.sort_by(|&a, &b| {
        time[b]
            .partial_cmp(&time[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut gradient = vec![0.0_f64; p];
    let mut hessian = vec![vec![0.0_f64; p]; p];
    let mut loglik = 0.0_f64;

    let mut risk_sum = 0.0_f64;
    let mut weighted_x = vec![0.0_f64; p];
    let mut weighted_x_outer = vec![vec![0.0_f64; p]; p];

    let mut start = 0;
    while start < n {
        // Delimit the run of observations sharing this time.
        let group_time = time[sorted_indices[start]];
        let mut end = start + 1;
        while end < n && time[sorted_indices[end]] == group_time {
            end += 1;
        }
        let tied = &sorted_indices[start..end];

        // Step 1: the whole tie group enters the risk set.
        for &idx in tied {
            let w = weights[idx] * exp_eta[idx];
            risk_sum += w;

            for j in 0..p {
                let xij = x[idx * p + j];
                weighted_x[j] += w * xij;

                for k in 0..p {
                    let xik = x[idx * p + k];
                    weighted_x_outer[j][k] += w * xij * xik;
                }
            }
        }

        // Step 2: every event of the group is scored against that complete risk set.
        for &idx in tied {
            if cause[idx] != cause_of_interest || !(risk_sum > 0.0) {
                continue;
            }

            loglik += weights[idx] * (eta[idx] - risk_sum.ln());

            for j in 0..p {
                let xij = x[idx * p + j];
                let x_bar = weighted_x[j] / risk_sum;
                gradient[j] += weights[idx] * (xij - x_bar);

                for k in 0..p {
                    let x_bar_k = weighted_x[k] / risk_sum;
                    let x_outer_bar = weighted_x_outer[j][k] / risk_sum;
                    hessian[j][k] -= weights[idx] * (x_outer_bar - x_bar * x_bar_k);
                }
            }
        }

        start = end;
    }

    (gradient, hessian, loglik)
}
391
/// Solves the dense linear system `a * x = b` by Gaussian elimination with partial
/// pivoting followed by back substitution.
///
/// The dimension is taken from `b`. Returns `None` when a pivot is smaller than
/// `1e-12` in absolute value, i.e. the matrix is treated as singular.
fn solve_system(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let dim = b.len();
    let mut m: Vec<Vec<f64>> = a.to_vec();
    let mut v: Vec<f64> = b.to_vec();

    for col in 0..dim {
        // First row (at or below the diagonal) holding the largest entry of this column.
        let pivot_row = ((col + 1)..dim).fold(col, |best, r| {
            if m[r][col].abs() > m[best][col].abs() {
                r
            } else {
                best
            }
        });
        m.swap(col, pivot_row);
        v.swap(col, pivot_row);

        let pivot = m[col][col];
        if pivot.abs() < 1e-12 {
            return None;
        }

        // Eliminate the column from every row below the pivot row.
        let pivot_rhs = v[col];
        let (head, tail) = m.split_at_mut(col + 1);
        let pivot_coeffs = &head[col][col..dim];
        for (row, rhs) in tail.iter_mut().zip(v[col + 1..].iter_mut()) {
            let factor = row[col] / pivot;
            *rhs -= factor * pivot_rhs;
            for (cell, &pc) in row[col..dim].iter_mut().zip(pivot_coeffs) {
                *cell -= factor * pc;
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut solution = vec![0.0; dim];
    for r in (0..dim).rev() {
        let residual = ((r + 1)..dim).fold(v[r], |acc, c| acc - m[r][c] * solution[c]);
        solution[r] = residual / m[r][r];
    }

    Some(solution)
}
431
/// Inverts a square matrix by Gauss–Jordan elimination with partial pivoting on the
/// augmented matrix `[mat | I]`.
///
/// Returns `None` for an empty matrix or when a pivot is smaller than `1e-12` in
/// absolute value.
fn invert_matrix(mat: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let dim = mat.len();
    if dim == 0 {
        return None;
    }
    let width = 2 * dim;

    // Build [mat | I] row by row.
    let mut work: Vec<Vec<f64>> = Vec::with_capacity(dim);
    for (r, source) in mat.iter().enumerate() {
        let mut row = Vec::with_capacity(source.len() + dim);
        row.extend_from_slice(source);
        row.resize(source.len() + dim, 0.0);
        row[dim + r] = 1.0;
        work.push(row);
    }

    for col in 0..dim {
        let pivot_row = ((col + 1)..dim).fold(col, |best, r| {
            if work[r][col].abs() > work[best][col].abs() {
                r
            } else {
                best
            }
        });
        work.swap(col, pivot_row);

        let pivot = work[col][col];
        if pivot.abs() < 1e-12 {
            return None;
        }

        // Normalise the pivot row, then clear the column in every other row.
        work[col].iter_mut().for_each(|v| *v /= pivot);
        let normalised = work[col].clone();

        for (r, row) in work.iter_mut().enumerate() {
            if r == col {
                continue;
            }
            let factor = row[col];
            for (cell, &pv) in row[..width].iter_mut().zip(&normalised[..width]) {
                *cell -= factor * pv;
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    Some(work.into_iter().map(|mut row| row.split_off(dim)).collect())
}
479
/// Breslow-type estimate of the baseline hazard of one cause.
///
/// `time`, `cause`, `weights` and `exp_eta` hold one entry per observation (`n` in
/// total); `exp_eta` is the relative risk `exp(x · beta)` of each observation.
///
/// For every distinct event time of `cause_of_interest` (times closer than `1e-9` are
/// merged) the hazard increment is the weighted number of events divided by the sum
/// of `weights * exp_eta` over all observations with `time >= event time`. Time points
/// with an empty risk set or zero event weight are skipped.
///
/// Returns `(event times, hazard increments, cumulative hazard)`, ordered by time.
fn compute_baseline_hazard(
    n: usize,
    time: &[f64],
    cause: &[i32],
    weights: &[f64],
    exp_eta: &[f64],
    cause_of_interest: i32,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let mut sorted_indices: Vec<usize> = (0..n).collect();
    sorted_indices.sort_by(|&a, &b| {
        time[a]
            .partial_cmp(&time[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut unique_times = Vec::new();
    let mut baseline = Vec::new();
    let mut cumulative = Vec::new();
    let mut cum_h0 = 0.0_f64;

    let mut i = 0;
    while i < n {
        let idx = sorted_indices[i];
        if cause[idx] != cause_of_interest {
            i += 1;
            continue;
        }

        let current_time = time[idx];
        let mut n_events = 0.0_f64;

        // Collect the weighted number of events in the tie group starting here.
        let group_start = i;
        while i < n && (time[sorted_indices[i]] - current_time).abs() < 1e-9 {
            if cause[sorted_indices[i]] == cause_of_interest {
                n_events += weights[sorted_indices[i]];
            }
            i += 1;
        }
        if i == group_start {
            // A NaN or infinite time makes `time - time` NaN, so the scan above does
            // not even accept the leading observation. Consume it explicitly;
            // otherwise the outer loop would never advance.
            n_events += weights[idx];
            i += 1;
        }

        let mut risk_sum = 0.0_f64;
        for &j in &sorted_indices {
            if time[j] >= current_time {
                risk_sum += weights[j] * exp_eta[j];
            }
        }

        if risk_sum > 0.0 && n_events > 0.0 {
            let h0 = n_events / risk_sum;
            cum_h0 += h0;

            unique_times.push(current_time);
            baseline.push(h0);
            cumulative.push(cum_h0);
        }
    }

    (unique_times, baseline, cumulative)
}
537
538#[pyfunction]
539#[pyo3(signature = (x, n_obs, n_vars, time, cause, config, weights=None))]
540pub fn joint_competing_risks(
541 x: Vec<f64>,
542 n_obs: usize,
543 n_vars: usize,
544 time: Vec<f64>,
545 cause: Vec<i32>,
546 config: &JointCompetingRisksConfig,
547 weights: Option<Vec<f64>>,
548) -> PyResult<JointCompetingRisksResult> {
549 if x.len() != n_obs * n_vars {
550 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
551 "x length must equal n_obs * n_vars",
552 ));
553 }
554 if time.len() != n_obs || cause.len() != n_obs {
555 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
556 "time and cause must have length n_obs",
557 ));
558 }
559
560 let wt = weights.unwrap_or_else(|| vec![1.0; n_obs]);
561
562 let n_events_by_cause: Vec<usize> = (1..=config.num_causes as i32)
563 .map(|c| cause.iter().filter(|&&cc| cc == c).count())
564 .collect();
565
566 let mut total_loglik = 0.0;
567 let mut total_n_iter = 0;
568 let mut all_converged = true;
569
570 let mut cause_specific_results = Vec::with_capacity(config.num_causes);
571
572 for c in 1..=config.num_causes as i32 {
573 let (beta, std_errors, loglik, converged, n_iter) = fit_cause_specific_cox(
574 &x,
575 n_obs,
576 n_vars,
577 &time,
578 &cause,
579 &wt,
580 c,
581 config.max_iter,
582 config.tol,
583 );
584
585 total_loglik += loglik;
586 total_n_iter = total_n_iter.max(n_iter);
587 all_converged = all_converged && converged;
588
589 let exp_eta: Vec<f64> = (0..n_obs)
590 .map(|i| {
591 let mut e = 0.0;
592 for j in 0..n_vars {
593 e += x[i * n_vars + j] * beta[j];
594 }
595 e.clamp(-700.0, 700.0).exp()
596 })
597 .collect();
598
599 let (times, baseline, cumulative) =
600 compute_baseline_hazard(n_obs, &time, &cause, &wt, &exp_eta, c);
601
602 let hazard_ratios: Vec<f64> = beta.iter().map(|&b| b.exp()).collect();
603
604 cause_specific_results.push(CauseResult {
605 cause: c as usize,
606 coefficients: beta,
607 std_errors,
608 hazard_ratios,
609 baseline_hazard_times: times,
610 baseline_hazard: baseline,
611 cumulative_baseline_hazard: cumulative,
612 });
613 }
614
615 let subdistribution_results = cause_specific_results.clone();
616
617 let correlation_matrix = match config.correlation_structure {
618 CorrelationType::Independent => None,
619 CorrelationType::SharedFrailty | CorrelationType::CopulaBased => {
620 let mut corr = vec![vec![0.0; config.num_causes]; config.num_causes];
621 for i in 0..config.num_causes {
622 corr[i][i] = 1.0;
623 }
624 Some(corr)
625 }
626 };
627
628 let frailty_variance = match config.correlation_structure {
629 CorrelationType::SharedFrailty => Some(config.frailty_variance),
630 _ => None,
631 };
632
633 let n_params = n_vars * config.num_causes;
634 let aic = -2.0 * total_loglik + 2.0 * n_params as f64;
635 let bic = -2.0 * total_loglik + (n_params as f64) * (n_obs as f64).ln();
636
637 Ok(JointCompetingRisksResult {
638 cause_specific_results,
639 subdistribution_results,
640 correlation_matrix,
641 frailty_variance,
642 log_likelihood: total_loglik,
643 aic,
644 bic,
645 n_events_by_cause,
646 n_obs,
647 n_iter: total_n_iter,
648 converged: all_converged,
649 })
650}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an independent-causes configuration with 100 iterations and
    /// `estimate_correlation = true`, varying only the parameters the tests care about.
    fn independent_config(
        num_causes: usize,
        frailty_variance: f64,
        tol: f64,
    ) -> PyResult<JointCompetingRisksConfig> {
        JointCompetingRisksConfig::new(
            num_causes,
            CorrelationType::Independent,
            frailty_variance,
            100,
            tol,
            true,
        )
    }

    /// A valid configuration keeps the requested number of causes.
    #[test]
    fn test_config() {
        let cfg = independent_config(2, 1.0, 1e-6).unwrap();
        assert_eq!(cfg.num_causes, 2);
    }

    /// Fewer than two causes and a negative frailty variance are both rejected.
    #[test]
    fn test_config_validation() {
        let too_few_causes = independent_config(1, 1.0, 1e-6);
        assert!(too_few_causes.is_err());

        let negative_variance = independent_config(2, -1.0, 1e-6);
        assert!(negative_variance.is_err());
    }

    /// Smoke test: five observations, two covariates, two causes plus censoring.
    #[test]
    fn test_joint_competing_risks_basic() {
        let covariates = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.5, 0.5];
        let times = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let causes = vec![1, 2, 0, 1, 2];

        let cfg = independent_config(2, 1.0, 1e-5).unwrap();

        let fit = joint_competing_risks(covariates, 5, 2, times, causes, &cfg, None).unwrap();

        assert_eq!(fit.cause_specific_results.len(), 2);
        assert_eq!(fit.n_events_by_cause.len(), 2);
        assert_eq!(fit.n_obs, 5);
    }
}