// survival/regression/joint_competing.rs

#![allow(
    unused_variables,
    unused_imports,
    clippy::too_many_arguments,
    clippy::needless_range_loop
)]

use pyo3::prelude::*;
use rayon::prelude::*;
/// Correlation structure linking the cause-specific hazards in a joint
/// competing-risks model.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum CorrelationType {
    /// Causes are modeled independently (no shared dependence).
    Independent,
    /// Dependence induced by a shared frailty (random-effect) term.
    SharedFrailty,
    /// Dependence modeled through a copula.
    CopulaBased,
}
18
19#[pymethods]
20impl CorrelationType {
21    #[new]
22    fn new(name: &str) -> PyResult<Self> {
23        match name.to_lowercase().as_str() {
24            "independent" => Ok(CorrelationType::Independent),
25            "shared_frailty" | "sharedfrailty" | "frailty" => Ok(CorrelationType::SharedFrailty),
26            "copula_based" | "copulabased" | "copula" => Ok(CorrelationType::CopulaBased),
27            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28                "Unknown correlation type. Use 'independent', 'shared_frailty', or 'copula_based'",
29            )),
30        }
31    }
32}
33
/// Configuration for the joint competing-risks fitter.
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointCompetingRisksConfig {
    /// Number of distinct event causes (validated to be at least 2).
    #[pyo3(get, set)]
    pub num_causes: usize,
    /// How dependence between causes is modeled.
    #[pyo3(get, set)]
    pub correlation_structure: CorrelationType,
    /// Variance of the shared frailty term; validated to be positive.
    #[pyo3(get, set)]
    pub frailty_variance: f64,
    /// Maximum Newton-Raphson iterations per cause-specific fit.
    #[pyo3(get, set)]
    pub max_iter: usize,
    /// Convergence tolerance on the largest coefficient update.
    #[pyo3(get, set)]
    pub tol: f64,
    /// Whether to estimate the between-cause correlation.
    /// NOTE(review): not consulted by the current fitter — confirm intent.
    #[pyo3(get, set)]
    pub estimate_correlation: bool,
}
50
51#[pymethods]
52impl JointCompetingRisksConfig {
53    #[new]
54    #[pyo3(signature = (
55        num_causes=2,
56        correlation_structure=CorrelationType::Independent,
57        frailty_variance=1.0,
58        max_iter=100,
59        tol=1e-6,
60        estimate_correlation=true
61    ))]
62    pub fn new(
63        num_causes: usize,
64        correlation_structure: CorrelationType,
65        frailty_variance: f64,
66        max_iter: usize,
67        tol: f64,
68        estimate_correlation: bool,
69    ) -> PyResult<Self> {
70        if num_causes < 2 {
71            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
72                "num_causes must be at least 2",
73            ));
74        }
75        if frailty_variance <= 0.0 {
76            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
77                "frailty_variance must be positive",
78            ));
79        }
80        if max_iter == 0 {
81            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
82                "max_iter must be positive",
83            ));
84        }
85
86        Ok(JointCompetingRisksConfig {
87            num_causes,
88            correlation_structure,
89            frailty_variance,
90            max_iter,
91            tol,
92            estimate_correlation,
93        })
94    }
95}
96
/// Fit output for a single cause-specific Cox model.
#[derive(Debug, Clone)]
#[pyclass]
pub struct CauseResult {
    /// Cause code this result belongs to (1-based).
    #[pyo3(get)]
    pub cause: usize,
    /// Estimated regression coefficients (log hazard ratios).
    #[pyo3(get)]
    pub coefficients: Vec<f64>,
    /// Standard errors of the coefficients.
    #[pyo3(get)]
    pub std_errors: Vec<f64>,
    /// exp(coefficient) for each covariate.
    #[pyo3(get)]
    pub hazard_ratios: Vec<f64>,
    /// Distinct event times at which the baseline hazard was estimated.
    #[pyo3(get)]
    pub baseline_hazard_times: Vec<f64>,
    /// Baseline hazard increment at each event time (Breslow estimator).
    #[pyo3(get)]
    pub baseline_hazard: Vec<f64>,
    /// Running sum of `baseline_hazard` (cumulative baseline hazard).
    #[pyo3(get)]
    pub cumulative_baseline_hazard: Vec<f64>,
}
115
/// Combined output of `joint_competing_risks`.
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointCompetingRisksResult {
    /// One cause-specific Cox fit per cause (indexed cause-1).
    #[pyo3(get)]
    pub cause_specific_results: Vec<CauseResult>,
    /// Subdistribution results.
    /// NOTE(review): currently a clone of the cause-specific results — a true
    /// Fine-Gray fit is not implemented; confirm downstream expectations.
    #[pyo3(get)]
    pub subdistribution_results: Vec<CauseResult>,
    /// Between-cause correlation matrix (identity placeholder when a
    /// dependent structure is configured; `None` for `Independent`).
    #[pyo3(get)]
    pub correlation_matrix: Option<Vec<Vec<f64>>>,
    /// Frailty variance echoed from the config for `SharedFrailty`.
    #[pyo3(get)]
    pub frailty_variance: Option<f64>,
    /// Sum of the cause-specific log partial likelihoods.
    #[pyo3(get)]
    pub log_likelihood: f64,
    /// Akaike information criterion.
    #[pyo3(get)]
    pub aic: f64,
    /// Bayesian information criterion.
    #[pyo3(get)]
    pub bic: f64,
    /// Event counts per cause (cause codes 1..=num_causes).
    #[pyo3(get)]
    pub n_events_by_cause: Vec<usize>,
    /// Number of observations fitted.
    #[pyo3(get)]
    pub n_obs: usize,
    /// Largest iteration count over the cause-specific fits.
    #[pyo3(get)]
    pub n_iter: usize,
    /// True only if every cause-specific fit converged.
    #[pyo3(get)]
    pub converged: bool,
}
142
143#[pymethods]
144impl JointCompetingRisksResult {
145    fn __repr__(&self) -> String {
146        format!(
147            "JointCompetingRisksResult(n_causes={}, n_obs={}, converged={})",
148            self.cause_specific_results.len(),
149            self.n_obs,
150            self.converged
151        )
152    }
153
154    fn predict_cif(&self, x: Vec<f64>, n_obs: usize, cause_idx: usize) -> PyResult<Vec<Vec<f64>>> {
155        if cause_idx >= self.cause_specific_results.len() {
156            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
157                "cause_idx out of range",
158            ));
159        }
160
161        let cs = &self.cause_specific_results[cause_idx];
162        let n_vars = cs.coefficients.len();
163        let n_times = cs.baseline_hazard_times.len();
164
165        let all_cum_hazards: Vec<Vec<Vec<f64>>> = self
166            .cause_specific_results
167            .iter()
168            .map(|cr| {
169                (0..n_obs)
170                    .map(|i| {
171                        let mut lp = 0.0;
172                        for j in 0..cr.coefficients.len().min(n_vars) {
173                            lp += x[i * n_vars + j] * cr.coefficients[j];
174                        }
175                        let exp_lp = lp.exp();
176                        cr.cumulative_baseline_hazard
177                            .iter()
178                            .map(|&h0| h0 * exp_lp)
179                            .collect()
180                    })
181                    .collect()
182            })
183            .collect();
184
185        let cif: Vec<Vec<f64>> = (0..n_obs)
186            .into_par_iter()
187            .map(|i| {
188                let cs = &self.cause_specific_results[cause_idx];
189                let mut cif_vec = Vec::with_capacity(n_times);
190                let mut cum_inc = 0.0;
191                let mut prev_surv = 1.0;
192
193                for t in 0..n_times {
194                    let mut total_hazard = 0.0;
195                    for k in 0..self.cause_specific_results.len() {
196                        if t < all_cum_hazards[k][i].len() {
197                            let h_t = if t == 0 {
198                                all_cum_hazards[k][i][t]
199                            } else {
200                                all_cum_hazards[k][i][t] - all_cum_hazards[k][i][t - 1]
201                            };
202                            total_hazard += h_t.max(0.0);
203                        }
204                    }
205
206                    let h_cause_t = if t == 0 {
207                        all_cum_hazards[cause_idx][i][t]
208                    } else {
209                        all_cum_hazards[cause_idx][i][t] - all_cum_hazards[cause_idx][i][t - 1]
210                    };
211
212                    cum_inc += prev_surv * h_cause_t.max(0.0);
213                    prev_surv *= (-total_hazard).exp();
214                    cif_vec.push(cum_inc.min(1.0));
215                }
216
217                cif_vec
218            })
219            .collect();
220
221        Ok(cif)
222    }
223
224    fn predict_overall_survival(&self, x: Vec<f64>, n_obs: usize) -> Vec<Vec<f64>> {
225        let n_times = self.cause_specific_results[0].baseline_hazard_times.len();
226        let n_vars = self.cause_specific_results[0].coefficients.len();
227
228        (0..n_obs)
229            .into_par_iter()
230            .map(|i| {
231                let mut surv_vec = Vec::with_capacity(n_times);
232                let mut cum_surv = 1.0;
233
234                for t in 0..n_times {
235                    let mut total_hazard = 0.0;
236
237                    for cs in &self.cause_specific_results {
238                        let mut lp = 0.0;
239                        for j in 0..cs.coefficients.len().min(n_vars) {
240                            lp += x[i * n_vars + j] * cs.coefficients[j];
241                        }
242                        let exp_lp = lp.exp();
243
244                        let h_t = if t == 0 {
245                            cs.cumulative_baseline_hazard
246                                .first()
247                                .copied()
248                                .unwrap_or(0.0)
249                        } else {
250                            cs.cumulative_baseline_hazard.get(t).copied().unwrap_or(0.0)
251                                - cs.cumulative_baseline_hazard
252                                    .get(t - 1)
253                                    .copied()
254                                    .unwrap_or(0.0)
255                        };
256
257                        total_hazard += h_t * exp_lp;
258                    }
259
260                    cum_surv *= (-total_hazard).exp();
261                    surv_vec.push(cum_surv.clamp(0.0, 1.0));
262                }
263
264                surv_vec
265            })
266            .collect()
267    }
268}
269
270fn fit_cause_specific_cox(
271    x: &[f64],
272    n: usize,
273    p: usize,
274    time: &[f64],
275    cause: &[i32],
276    weights: &[f64],
277    cause_of_interest: i32,
278    max_iter: usize,
279    tol: f64,
280) -> (Vec<f64>, Vec<f64>, f64, bool, usize) {
281    let mut beta = vec![0.0; p];
282    let mut converged = false;
283    let mut n_iter = 0;
284    let mut loglik = 0.0;
285
286    for iter in 0..max_iter {
287        n_iter = iter + 1;
288
289        let (gradient, hessian, ll) =
290            compute_gradient_hessian(x, n, p, time, cause, weights, &beta, cause_of_interest);
291        loglik = ll;
292
293        let delta = match solve_system(&hessian, &gradient) {
294            Some(d) => d,
295            None => break,
296        };
297
298        let max_change: f64 = delta.iter().map(|d| d.abs()).fold(0.0, f64::max);
299
300        for j in 0..p {
301            beta[j] += delta[j];
302        }
303
304        if max_change < tol {
305            converged = true;
306            break;
307        }
308    }
309
310    let (_, final_hessian, _) =
311        compute_gradient_hessian(x, n, p, time, cause, weights, &beta, cause_of_interest);
312
313    let var_cov = invert_matrix(&final_hessian).unwrap_or_else(|| vec![vec![0.0; p]; p]);
314    let std_errors: Vec<f64> = (0..p)
315        .map(|j| var_cov[j][j].abs().sqrt().max(1e-10))
316        .collect();
317
318    (beta, std_errors, loglik, converged, n_iter)
319}
320
/// Gradient, Hessian, and log partial likelihood of a cause-specific Cox
/// model evaluated at `beta` (Breslow-style risk-set accumulation).
///
/// Observations are scanned in decreasing time order so the risk-set sums can
/// be built incrementally; an observation contributes an event term only when
/// its cause equals `cause_of_interest`.
fn compute_gradient_hessian(
    x: &[f64],
    n: usize,
    p: usize,
    time: &[f64],
    cause: &[i32],
    weights: &[f64],
    beta: &[f64],
    cause_of_interest: i32,
) -> (Vec<f64>, Vec<Vec<f64>>, f64) {
    // Linear predictors, clamped so exp() cannot overflow to infinity.
    let eta: Vec<f64> = (0..n)
        .map(|row| {
            let lp: f64 = (0..p).map(|col| x[row * p + col] * beta[col]).sum();
            lp.clamp(-700.0, 700.0)
        })
        .collect();
    let exp_eta: Vec<f64> = eta.iter().map(|&v| v.exp()).collect();

    // Indices ordered by decreasing event time (stable sort keeps tie order).
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        time[b]
            .partial_cmp(&time[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut grad = vec![0.0; p];
    let mut hess = vec![vec![0.0; p]; p];
    let mut loglik = 0.0;

    // Running risk-set sums: s0 = Σ w·e^η, s1 = Σ w·e^η·x, s2 = Σ w·e^η·x xᵀ.
    let mut s0 = 0.0;
    let mut s1 = vec![0.0; p];
    let mut s2 = vec![vec![0.0; p]; p];

    for &obs in &order {
        let w = weights[obs] * exp_eta[obs];
        s0 += w;

        for j in 0..p {
            let xj = x[obs * p + j];
            s1[j] += w * xj;
            for k in 0..p {
                s2[j][k] += w * xj * x[obs * p + k];
            }
        }

        if cause[obs] == cause_of_interest && s0 > 0.0 {
            loglik += weights[obs] * (eta[obs] - s0.ln());

            for j in 0..p {
                let mean_j = s1[j] / s0;
                grad[j] += weights[obs] * (x[obs * p + j] - mean_j);

                for k in 0..p {
                    let mean_k = s1[k] / s0;
                    // Weighted covariance term; subtracting makes the returned
                    // matrix the true (negative-definite) Hessian.
                    hess[j][k] -= weights[obs] * (s2[j][k] / s0 - mean_j * mean_k);
                }
            }
        }
    }

    (grad, hess, loglik)
}
391
/// Solve the dense linear system `a * x = b` by Gaussian elimination with
/// partial pivoting. Returns `None` when a pivot is (near-)zero, i.e. the
/// matrix is singular to working precision.
fn solve_system(a: &[Vec<f64>], b: &[f64]) -> Option<Vec<f64>> {
    let n = b.len();
    let mut m = a.to_vec();
    let mut y = b.to_vec();

    // Forward elimination with row pivoting.
    for col in 0..n {
        // Pick the first row at or below `col` with the largest magnitude.
        let mut pivot_row = col;
        for r in (col + 1)..n {
            if m[r][col].abs() > m[pivot_row][col].abs() {
                pivot_row = r;
            }
        }
        m.swap(col, pivot_row);
        y.swap(col, pivot_row);

        if m[col][col].abs() < 1e-12 {
            return None;
        }

        for row in (col + 1)..n {
            let factor = m[row][col] / m[col][col];
            y[row] -= factor * y[col];
            for j in col..n {
                m[row][j] -= factor * m[col][j];
            }
        }
    }

    // Back substitution on the upper-triangular system.
    let mut x = vec![0.0; n];
    for row in (0..n).rev() {
        let mut acc = y[row];
        for j in (row + 1)..n {
            acc -= m[row][j] * x[j];
        }
        x[row] = acc / m[row][row];
    }

    Some(x)
}
431
/// Invert a square matrix via Gauss-Jordan elimination on the augmented
/// matrix [A | I] with partial pivoting. Returns `None` for an empty or
/// numerically singular matrix.
fn invert_matrix(mat: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = mat.len();
    if n == 0 {
        return None;
    }

    // Build the augmented matrix [A | I].
    let mut aug: Vec<Vec<f64>> = Vec::with_capacity(n);
    for (i, row) in mat.iter().enumerate() {
        let mut r = row.clone();
        r.extend(std::iter::repeat(0.0).take(n));
        r[n + i] = 1.0;
        aug.push(r);
    }

    for col in 0..n {
        // Partial pivot: first row at or below `col` with maximal magnitude.
        let mut pivot_row = col;
        for r in (col + 1)..n {
            if aug[r][col].abs() > aug[pivot_row][col].abs() {
                pivot_row = r;
            }
        }
        aug.swap(col, pivot_row);

        if aug[col][col].abs() < 1e-12 {
            return None;
        }

        // Normalize the pivot row, then clear the column everywhere else.
        let pivot = aug[col][col];
        for v in aug[col].iter_mut() {
            *v /= pivot;
        }
        for r in 0..n {
            if r == col {
                continue;
            }
            let factor = aug[r][col];
            for j in 0..(2 * n) {
                aug[r][j] -= factor * aug[col][j];
            }
        }
    }

    // The right half of the augmented matrix is now A^{-1}.
    Some(aug.into_iter().map(|r| r[n..].to_vec()).collect())
}
479
/// Breslow estimator of the baseline hazard for one cause.
///
/// Scans event times in increasing order; times tied within 1e-9 are pooled
/// into a single event time. Returns `(event_times, hazard_increments,
/// cumulative_hazard)`, skipping times with an empty risk set.
fn compute_baseline_hazard(
    n: usize,
    time: &[f64],
    cause: &[i32],
    weights: &[f64],
    exp_eta: &[f64],
    cause_of_interest: i32,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    // Indices ordered by increasing time (stable sort keeps tie order).
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| {
        time[a]
            .partial_cmp(&time[b])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut event_times = Vec::new();
    let mut increments = Vec::new();
    let mut cumulative = Vec::new();
    let mut running_total = 0.0;

    let mut pos = 0;
    while pos < n {
        let first = order[pos];
        // Skip censored observations and events of other causes.
        if cause[first] != cause_of_interest {
            pos += 1;
            continue;
        }

        let t = time[first];

        // Pool the weighted event count over all observations tied at `t`.
        let mut event_weight = 0.0;
        while pos < n && (time[order[pos]] - t).abs() < 1e-9 {
            let obs = order[pos];
            if cause[obs] == cause_of_interest {
                event_weight += weights[obs];
            }
            pos += 1;
        }

        // Breslow denominator: everyone still at risk at time `t`.
        let mut at_risk = 0.0;
        for &obs in &order {
            if time[obs] >= t {
                at_risk += weights[obs] * exp_eta[obs];
            }
        }

        if at_risk > 0.0 && event_weight > 0.0 {
            let increment = event_weight / at_risk;
            running_total += increment;

            event_times.push(t);
            increments.push(increment);
            cumulative.push(running_total);
        }
    }

    (event_times, increments, cumulative)
}
537
538#[pyfunction]
539#[pyo3(signature = (x, n_obs, n_vars, time, cause, config, weights=None))]
540pub fn joint_competing_risks(
541    x: Vec<f64>,
542    n_obs: usize,
543    n_vars: usize,
544    time: Vec<f64>,
545    cause: Vec<i32>,
546    config: &JointCompetingRisksConfig,
547    weights: Option<Vec<f64>>,
548) -> PyResult<JointCompetingRisksResult> {
549    if x.len() != n_obs * n_vars {
550        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
551            "x length must equal n_obs * n_vars",
552        ));
553    }
554    if time.len() != n_obs || cause.len() != n_obs {
555        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
556            "time and cause must have length n_obs",
557        ));
558    }
559
560    let wt = weights.unwrap_or_else(|| vec![1.0; n_obs]);
561
562    let n_events_by_cause: Vec<usize> = (1..=config.num_causes as i32)
563        .map(|c| cause.iter().filter(|&&cc| cc == c).count())
564        .collect();
565
566    let mut total_loglik = 0.0;
567    let mut total_n_iter = 0;
568    let mut all_converged = true;
569
570    let mut cause_specific_results = Vec::with_capacity(config.num_causes);
571
572    for c in 1..=config.num_causes as i32 {
573        let (beta, std_errors, loglik, converged, n_iter) = fit_cause_specific_cox(
574            &x,
575            n_obs,
576            n_vars,
577            &time,
578            &cause,
579            &wt,
580            c,
581            config.max_iter,
582            config.tol,
583        );
584
585        total_loglik += loglik;
586        total_n_iter = total_n_iter.max(n_iter);
587        all_converged = all_converged && converged;
588
589        let exp_eta: Vec<f64> = (0..n_obs)
590            .map(|i| {
591                let mut e = 0.0;
592                for j in 0..n_vars {
593                    e += x[i * n_vars + j] * beta[j];
594                }
595                e.clamp(-700.0, 700.0).exp()
596            })
597            .collect();
598
599        let (times, baseline, cumulative) =
600            compute_baseline_hazard(n_obs, &time, &cause, &wt, &exp_eta, c);
601
602        let hazard_ratios: Vec<f64> = beta.iter().map(|&b| b.exp()).collect();
603
604        cause_specific_results.push(CauseResult {
605            cause: c as usize,
606            coefficients: beta,
607            std_errors,
608            hazard_ratios,
609            baseline_hazard_times: times,
610            baseline_hazard: baseline,
611            cumulative_baseline_hazard: cumulative,
612        });
613    }
614
615    let subdistribution_results = cause_specific_results.clone();
616
617    let correlation_matrix = match config.correlation_structure {
618        CorrelationType::Independent => None,
619        CorrelationType::SharedFrailty | CorrelationType::CopulaBased => {
620            let mut corr = vec![vec![0.0; config.num_causes]; config.num_causes];
621            for i in 0..config.num_causes {
622                corr[i][i] = 1.0;
623            }
624            Some(corr)
625        }
626    };
627
628    let frailty_variance = match config.correlation_structure {
629        CorrelationType::SharedFrailty => Some(config.frailty_variance),
630        _ => None,
631    };
632
633    let n_params = n_vars * config.num_causes;
634    let aic = -2.0 * total_loglik + 2.0 * n_params as f64;
635    let bic = -2.0 * total_loglik + (n_params as f64) * (n_obs as f64).ln();
636
637    Ok(JointCompetingRisksResult {
638        cause_specific_results,
639        subdistribution_results,
640        correlation_matrix,
641        frailty_variance,
642        log_likelihood: total_loglik,
643        aic,
644        bic,
645        n_events_by_cause,
646        n_obs,
647        n_iter: total_n_iter,
648        converged: all_converged,
649    })
650}
651
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: default-style arguments pass validation and round-trip.
    #[test]
    fn test_config() {
        let config =
            JointCompetingRisksConfig::new(2, CorrelationType::Independent, 1.0, 100, 1e-6, true)
                .unwrap();
        assert_eq!(config.num_causes, 2);
    }

    /// Invalid arguments are rejected: fewer than two causes, and a
    /// non-positive frailty variance.
    #[test]
    fn test_config_validation() {
        assert!(
            JointCompetingRisksConfig::new(1, CorrelationType::Independent, 1.0, 100, 1e-6, true)
                .is_err()
        );
        assert!(
            JointCompetingRisksConfig::new(2, CorrelationType::Independent, -1.0, 100, 1e-6, true)
                .is_err()
        );
    }

    /// Smoke test of the full pipeline: 5 observations, 2 covariates,
    /// causes coded 0 (censored), 1, and 2. Only output shapes are checked.
    #[test]
    fn test_joint_competing_risks_basic() {
        let x = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.5, 0.5];
        let time = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let cause = vec![1, 2, 0, 1, 2];

        let config =
            JointCompetingRisksConfig::new(2, CorrelationType::Independent, 1.0, 100, 1e-5, true)
                .unwrap();

        let result = joint_competing_risks(x, 5, 2, time, cause, &config, None).unwrap();

        assert_eq!(result.cause_specific_results.len(), 2);
        assert_eq!(result.n_events_by_cause.len(), 2);
        assert_eq!(result.n_obs, 5);
    }
}