Skip to main content

survival/joint/
joint_model.rs

1#![allow(
2    unused_variables,
3    unused_imports,
4    unused_mut,
5    unused_assignments,
6    clippy::too_many_arguments,
7    clippy::needless_range_loop,
8    clippy::type_complexity
9)]
10
11use pyo3::prelude::*;
12use rayon::prelude::*;
13
14#[derive(Debug, Clone, Copy, PartialEq)]
15#[pyclass]
16pub enum AssociationStructure {
17    Value,
18    Slope,
19    ValueSlope,
20    Area,
21    SharedRandomEffects,
22}
23
24#[pymethods]
25impl AssociationStructure {
26    #[new]
27    fn new(name: &str) -> PyResult<Self> {
28        match name.to_lowercase().as_str() {
29            "value" | "current_value" => Ok(AssociationStructure::Value),
30            "slope" | "current_slope" => Ok(AssociationStructure::Slope),
31            "value_slope" | "valueslope" => Ok(AssociationStructure::ValueSlope),
32            "area" | "cumulative" => Ok(AssociationStructure::Area),
33            "shared" | "shared_random_effects" => Ok(AssociationStructure::SharedRandomEffects),
34            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
35                "Unknown association structure",
36            )),
37        }
38    }
39}
40
41#[derive(Debug, Clone)]
42#[pyclass]
43pub struct JointModelConfig {
44    #[pyo3(get, set)]
45    pub association: AssociationStructure,
46    #[pyo3(get, set)]
47    pub n_quadrature: usize,
48    #[pyo3(get, set)]
49    pub max_iter: usize,
50    #[pyo3(get, set)]
51    pub tol: f64,
52    #[pyo3(get, set)]
53    pub baseline_hazard_knots: usize,
54}
55
56#[pymethods]
57impl JointModelConfig {
58    #[new]
59    #[pyo3(signature = (association=AssociationStructure::Value, n_quadrature=15, max_iter=500, tol=1e-4, baseline_hazard_knots=5))]
60    pub fn new(
61        association: AssociationStructure,
62        n_quadrature: usize,
63        max_iter: usize,
64        tol: f64,
65        baseline_hazard_knots: usize,
66    ) -> Self {
67        JointModelConfig {
68            association,
69            n_quadrature,
70            max_iter,
71            tol,
72            baseline_hazard_knots,
73        }
74    }
75}
76
77#[derive(Debug, Clone)]
78#[pyclass]
79pub struct JointModelResult {
80    #[pyo3(get)]
81    pub longitudinal_fixed: Vec<f64>,
82    #[pyo3(get)]
83    pub longitudinal_fixed_se: Vec<f64>,
84    #[pyo3(get)]
85    pub survival_fixed: Vec<f64>,
86    #[pyo3(get)]
87    pub survival_fixed_se: Vec<f64>,
88    #[pyo3(get)]
89    pub association_param: f64,
90    #[pyo3(get)]
91    pub association_se: f64,
92    #[pyo3(get)]
93    pub random_effects_var: Vec<f64>,
94    #[pyo3(get)]
95    pub residual_var: f64,
96    #[pyo3(get)]
97    pub baseline_hazard: Vec<f64>,
98    #[pyo3(get)]
99    pub baseline_hazard_times: Vec<f64>,
100    #[pyo3(get)]
101    pub log_likelihood: f64,
102    #[pyo3(get)]
103    pub aic: f64,
104    #[pyo3(get)]
105    pub bic: f64,
106    #[pyo3(get)]
107    pub n_iter: usize,
108    #[pyo3(get)]
109    pub converged: bool,
110    #[pyo3(get)]
111    pub random_effects: Vec<Vec<f64>>,
112}
113
/// Tabulated 5-point Gauss–Hermite nodes (ascending).
const GH5_NODES: [f64; 5] = [
    -2.020182870456086,
    -0.9585724646138185,
    0.0,
    0.9585724646138185,
    2.020182870456086,
];
/// Weights matching [`GH5_NODES`].
const GH5_WEIGHTS: [f64; 5] = [
    0.01995324205905,
    0.3936193231522,
    0.9453087204829,
    0.3936193231522,
    0.01995324205905,
];
/// Tabulated 15-point Gauss–Hermite nodes (ascending).
const GH15_NODES: [f64; 15] = [
    -4.499990707309,
    -3.669950373404,
    -2.967166927906,
    -2.325732486173,
    -1.719992575186,
    -1.136115585211,
    -0.5650695832556,
    0.0,
    0.5650695832556,
    1.136115585211,
    1.719992575186,
    2.325732486173,
    2.967166927906,
    3.669950373404,
    4.499990707309,
];
/// Weights matching [`GH15_NODES`].
const GH15_WEIGHTS: [f64; 15] = [
    1.522475804254e-09,
    1.059115547711e-06,
    1.000044412325e-04,
    2.778068842913e-03,
    3.078003387255e-02,
    1.584889157959e-01,
    4.120286874989e-01,
    5.641003087264e-01,
    4.120286874989e-01,
    1.584889157959e-01,
    3.078003387255e-02,
    2.778068842913e-03,
    1.000044412325e-04,
    1.059115547711e-06,
    1.522475804254e-09,
];

/// `n`-point Gauss–Hermite rule for the weight `exp(-x^2)`.
///
/// Returns `(nodes, weights)` with nodes in ascending order, so that
/// `sum_k weights[k] * f(nodes[k])` approximates `∫ f(x) exp(-x^2) dx`.
/// The common sizes 5 and 15 come from fixed tables, so existing fits are
/// reproduced bit-for-bit. Any other size is computed on the fly, and `n == 0`
/// is treated as `n == 1`.
fn gauss_hermite_quadrature(n: usize) -> (Vec<f64>, Vec<f64>) {
    match n {
        5 => (GH5_NODES.to_vec(), GH5_WEIGHTS.to_vec()),
        15 => (GH15_NODES.to_vec(), GH15_WEIGHTS.to_vec()),
        _ => gauss_hermite_newton(n.max(1)),
    }
}

/// Computes the `n`-point (`n >= 1`) Gauss–Hermite rule by Newton iteration on
/// the orthonormal Hermite recurrence (the classic Numerical Recipes `gauher`
/// scheme, including its empirical initial guesses).
fn gauss_hermite_newton(n: usize) -> (Vec<f64>, Vec<f64>) {
    const EPS: f64 = 3.0e-14;
    const MAX_NEWTON_STEPS: usize = 100;
    // pi^(-1/4): value of the zeroth orthonormal Hermite function's constant term.
    let pim4 = std::f64::consts::PI.powf(-0.25);
    let nf = n as f64;
    let half = (n + 1) / 2;

    // Non-negative roots in descending order (roots[0] is the largest).
    let mut roots = vec![0.0_f64; half];
    let mut weights = vec![0.0_f64; half];
    let mut z = 0.0_f64;
    for i in 0..half {
        // Initial guess for the i-th largest root, built from the previous ones.
        z = match i {
            0 => (2.0 * nf + 1.0).sqrt() - 1.85575 * (2.0 * nf + 1.0).powf(-0.16667),
            1 => z - 1.14 * nf.powf(0.426) / z,
            2 => 1.86 * z - 0.86 * roots[0],
            3 => 1.91 * z - 0.91 * roots[1],
            _ => 2.0 * z - roots[i - 2],
        };
        let mut pp = 1.0;
        for _ in 0..MAX_NEWTON_STEPS {
            // p1 = orthonormal H_n(z), p2 = H_{n-1}(z) via the three-term recurrence.
            let mut p1 = pim4;
            let mut p2 = 0.0;
            for j in 1..=n {
                let p3 = p2;
                p2 = p1;
                let jf = j as f64;
                p1 = z * (2.0 / jf).sqrt() * p2 - ((jf - 1.0) / jf).sqrt() * p3;
            }
            // Derivative of H_n from the standard relation H_n' = sqrt(2n) H_{n-1}.
            pp = (2.0 * nf).sqrt() * p2;
            let step = p1 / pp;
            z -= step;
            if step.abs() <= EPS {
                break;
            }
        }
        roots[i] = z;
        weights[i] = 2.0 / (pp * pp);
    }

    let odd = n % 2 == 1;
    let mut nodes: Vec<f64> = Vec::with_capacity(n);
    let mut node_weights: Vec<f64> = Vec::with_capacity(n);
    // Negative half (and the centre node for odd n), ascending.
    nodes.extend(roots.iter().map(|&r| -r));
    node_weights.extend_from_slice(&weights);
    if odd {
        // The centre node of an odd rule is exactly zero; drop Newton round-off.
        nodes[half - 1] = 0.0;
    }
    // Strictly positive half, ascending.
    for k in (0..n / 2).rev() {
        nodes.push(roots[k]);
        node_weights.push(weights[k]);
    }
    (nodes, node_weights)
}
171
/// Marker trajectory `m(t) = b0 + b1 * t + sum_j beta[j] * x_fixed[j]`.
///
/// Only the first `min(beta.len(), x_fixed.len())` coefficient/covariate pairs
/// contribute; surplus entries on either side are ignored.
fn longitudinal_model_value(
    time: f64,
    beta: &[f64],
    x_fixed: &[f64],
    random_intercept: f64,
    random_slope: f64,
) -> f64 {
    let subject_line = random_intercept + random_slope * time;
    x_fixed
        .iter()
        .zip(beta)
        .fold(subject_line, |acc, (&x, &coef)| acc + coef * x)
}
187
/// Time derivative `dm/dt` of the marker trajectory.
///
/// Returns the random slope plus `beta[1]` when that coefficient exists.
/// NOTE(review): this presumes column 1 of the fixed-effects design is the time
/// covariate — confirm with callers. The remaining arguments exist for signature
/// parity with `longitudinal_model_value` and are unused.
fn longitudinal_model_slope(
    _time: f64,
    beta: &[f64],
    _x_fixed: &[f64],
    _random_intercept: f64,
    random_slope: f64,
) -> f64 {
    match beta.get(1) {
        Some(&time_coef) => random_slope + time_coef,
        None => random_slope,
    }
}
201
202#[allow(clippy::too_many_arguments)]
203fn compute_survival_contribution(
204    event_time: f64,
205    event_status: i32,
206    x_surv: &[f64],
207    gamma: &[f64],
208    alpha: f64,
209    beta_long: &[f64],
210    x_long_fixed: &[f64],
211    random_intercept: f64,
212    random_slope: f64,
213    baseline_hazard: &[f64],
214    baseline_times: &[f64],
215    association: &AssociationStructure,
216) -> f64 {
217    let mut linear_pred = 0.0;
218    for (j, &xj) in x_surv.iter().enumerate() {
219        if j < gamma.len() {
220            linear_pred += gamma[j] * xj;
221        }
222    }
223
224    let marker_contribution = match association {
225        AssociationStructure::Value => {
226            let m_t = longitudinal_model_value(
227                event_time,
228                beta_long,
229                x_long_fixed,
230                random_intercept,
231                random_slope,
232            );
233            alpha * m_t
234        }
235        AssociationStructure::Slope => {
236            let dm_t = longitudinal_model_slope(
237                event_time,
238                beta_long,
239                x_long_fixed,
240                random_intercept,
241                random_slope,
242            );
243            alpha * dm_t
244        }
245        AssociationStructure::ValueSlope => {
246            let m_t = longitudinal_model_value(
247                event_time,
248                beta_long,
249                x_long_fixed,
250                random_intercept,
251                random_slope,
252            );
253            let dm_t = longitudinal_model_slope(
254                event_time,
255                beta_long,
256                x_long_fixed,
257                random_intercept,
258                random_slope,
259            );
260            alpha * (m_t + dm_t)
261        }
262        AssociationStructure::Area => {
263            let m_t = longitudinal_model_value(
264                event_time,
265                beta_long,
266                x_long_fixed,
267                random_intercept,
268                random_slope,
269            );
270            alpha * m_t * event_time / 2.0
271        }
272        AssociationStructure::SharedRandomEffects => alpha * random_intercept,
273    };
274
275    linear_pred += marker_contribution;
276
277    let mut cum_hazard = 0.0;
278    for (t_idx, &t) in baseline_times.iter().enumerate() {
279        if t > event_time {
280            break;
281        }
282        if t_idx < baseline_hazard.len() {
283            cum_hazard += baseline_hazard[t_idx];
284        }
285    }
286
287    let log_hazard = if event_status == 1 {
288        let h0 = baseline_hazard
289            .iter()
290            .zip(baseline_times.iter())
291            .filter(|(_, t)| (*t - event_time).abs() < 1e-6)
292            .map(|(&h, _)| h)
293            .next()
294            .unwrap_or(0.01);
295
296        (h0.max(1e-10)).ln() + linear_pred
297    } else {
298        0.0
299    };
300
301    log_hazard - cum_hazard * linear_pred.exp()
302}
303
304fn compute_longitudinal_contribution(
305    y_obs: &[f64],
306    times_obs: &[f64],
307    beta: &[f64],
308    x_fixed: &[f64],
309    n_fixed: usize,
310    random_intercept: f64,
311    random_slope: f64,
312    sigma_sq: f64,
313) -> f64 {
314    let n_obs = y_obs.len();
315    let mut log_lik = 0.0;
316
317    for i in 0..n_obs {
318        let x_i: Vec<f64> = (0..n_fixed).map(|j| x_fixed[i * n_fixed + j]).collect();
319        let pred =
320            longitudinal_model_value(times_obs[i], beta, &x_i, random_intercept, random_slope);
321        let resid = y_obs[i] - pred;
322        log_lik += -0.5 * resid * resid / sigma_sq - 0.5 * sigma_sq.ln();
323    }
324
325    log_lik
326}
327
328#[pyfunction]
329#[pyo3(signature = (
330    y_longitudinal,
331    times_longitudinal,
332    x_longitudinal,
333    n_long_obs,
334    n_long_vars,
335    subject_ids_long,
336    event_time,
337    event_status,
338    x_survival,
339    n_subjects,
340    n_surv_vars,
341    config
342))]
343pub fn joint_model(
344    y_longitudinal: Vec<f64>,
345    times_longitudinal: Vec<f64>,
346    x_longitudinal: Vec<f64>,
347    n_long_obs: usize,
348    n_long_vars: usize,
349    subject_ids_long: Vec<usize>,
350    event_time: Vec<f64>,
351    event_status: Vec<i32>,
352    x_survival: Vec<f64>,
353    n_subjects: usize,
354    n_surv_vars: usize,
355    config: &JointModelConfig,
356) -> PyResult<JointModelResult> {
357    if y_longitudinal.len() != n_long_obs
358        || times_longitudinal.len() != n_long_obs
359        || subject_ids_long.len() != n_long_obs
360    {
361        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
362            "Longitudinal data dimensions mismatch",
363        ));
364    }
365    if event_time.len() != n_subjects || event_status.len() != n_subjects {
366        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
367            "Survival data dimensions mismatch",
368        ));
369    }
370
371    let mut beta_long = vec![0.0; n_long_vars];
372    let mut gamma_surv = vec![0.0; n_surv_vars];
373    let mut alpha = 0.0;
374    let mut sigma_sq = 1.0;
375    let mut d11: f64 = 1.0;
376    let mut d22: f64 = 0.1;
377
378    let mut random_effects: Vec<Vec<f64>> = vec![vec![0.0, 0.0]; n_subjects];
379
380    let mut unique_times: Vec<f64> = event_time.clone();
381    unique_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
382    unique_times.dedup();
383    let n_knots = config.baseline_hazard_knots.min(unique_times.len());
384    let baseline_times: Vec<f64> = (0..n_knots)
385        .map(|i| unique_times[i * unique_times.len() / n_knots])
386        .collect();
387    let mut baseline_hazard = vec![0.01; n_knots];
388
389    let (quad_nodes, _quad_weights) = gauss_hermite_quadrature(config.n_quadrature);
390
391    let subject_indices: Vec<Vec<usize>> = (0..n_subjects)
392        .map(|i| {
393            (0..n_long_obs)
394                .filter(|&j| subject_ids_long[j] == i)
395                .collect()
396        })
397        .collect();
398
399    let mut prev_log_lik = f64::NEG_INFINITY;
400    let mut converged = false;
401    let mut n_iter = 0;
402
403    for iter in 0..config.max_iter {
404        n_iter = iter + 1;
405
406        let new_random_effects: Vec<Vec<f64>> = (0..n_subjects)
407            .into_par_iter()
408            .map(|i| {
409                let subj_indices = &subject_indices[i];
410
411                let y_i: Vec<f64> = subj_indices.iter().map(|&j| y_longitudinal[j]).collect();
412                let t_i: Vec<f64> = subj_indices
413                    .iter()
414                    .map(|&j| times_longitudinal[j])
415                    .collect();
416                let x_long_i: Vec<f64> = {
417                    let mut result = Vec::with_capacity(subj_indices.len() * n_long_vars);
418                    for &j in subj_indices {
419                        for k in 0..n_long_vars {
420                            result.push(x_longitudinal[j * n_long_vars + k]);
421                        }
422                    }
423                    result
424                };
425                let x_surv_i: Vec<f64> = (0..n_surv_vars)
426                    .map(|k| x_survival[i * n_surv_vars + k])
427                    .collect();
428
429                let mut best_re = random_effects[i].clone();
430                let mut best_contrib = f64::NEG_INFINITY;
431
432                for &node_b0 in &quad_nodes {
433                    for &node_b1 in &quad_nodes {
434                        let b0 = node_b0 * d11.sqrt();
435                        let b1 = node_b1 * d22.sqrt();
436
437                        let long_contrib = compute_longitudinal_contribution(
438                            &y_i,
439                            &t_i,
440                            &beta_long,
441                            &x_long_i,
442                            n_long_vars,
443                            b0,
444                            b1,
445                            sigma_sq,
446                        );
447
448                        let surv_contrib = compute_survival_contribution(
449                            event_time[i],
450                            event_status[i],
451                            &x_surv_i,
452                            &gamma_surv,
453                            alpha,
454                            &beta_long,
455                            &x_long_i,
456                            b0,
457                            b1,
458                            &baseline_hazard,
459                            &baseline_times,
460                            &config.association,
461                        );
462
463                        let re_prior = -0.5 * (b0 * b0 / d11 + b1 * b1 / d22);
464                        let total = long_contrib + surv_contrib + re_prior;
465
466                        if total > best_contrib {
467                            best_contrib = total;
468                            best_re = vec![b0, b1];
469                        }
470                    }
471                }
472
473                best_re
474            })
475            .collect();
476
477        random_effects = new_random_effects;
478
479        let mut gradient_beta = vec![0.0; n_long_vars];
480        let mut hessian_beta = vec![0.0; n_long_vars];
481
482        for j in 0..n_long_obs {
483            let subj = subject_ids_long[j];
484            let b0 = random_effects[subj][0];
485            let b1 = random_effects[subj][1];
486
487            let x_j: Vec<f64> = (0..n_long_vars)
488                .map(|k| x_longitudinal[j * n_long_vars + k])
489                .collect();
490
491            let pred = longitudinal_model_value(times_longitudinal[j], &beta_long, &x_j, b0, b1);
492            let resid = y_longitudinal[j] - pred;
493
494            for k in 0..n_long_vars {
495                gradient_beta[k] += resid * x_j[k] / sigma_sq;
496                hessian_beta[k] += x_j[k] * x_j[k] / sigma_sq;
497            }
498        }
499
500        for k in 0..n_long_vars {
501            if hessian_beta[k].abs() > 1e-10 {
502                beta_long[k] += gradient_beta[k] / hessian_beta[k];
503            }
504        }
505
506        let mut ss_resid = 0.0;
507        for j in 0..n_long_obs {
508            let subj = subject_ids_long[j];
509            let b0 = random_effects[subj][0];
510            let b1 = random_effects[subj][1];
511            let x_j: Vec<f64> = (0..n_long_vars)
512                .map(|k| x_longitudinal[j * n_long_vars + k])
513                .collect();
514            let pred = longitudinal_model_value(times_longitudinal[j], &beta_long, &x_j, b0, b1);
515            ss_resid += (y_longitudinal[j] - pred).powi(2);
516        }
517        sigma_sq = (ss_resid / n_long_obs as f64).max(0.001);
518
519        d11 = random_effects.iter().map(|re| re[0].powi(2)).sum::<f64>() / n_subjects as f64;
520        d22 = random_effects.iter().map(|re| re[1].powi(2)).sum::<f64>() / n_subjects as f64;
521        d11 = d11.max(0.001);
522        d22 = d22.max(0.001);
523
524        let mut gradient_alpha = 0.0;
525        let mut hessian_alpha = 0.0;
526
527        for i in 0..n_subjects {
528            let b0 = random_effects[i][0];
529            let b1 = random_effects[i][1];
530
531            let x_long_i: Vec<f64> = (0..n_long_vars)
532                .map(|k| x_longitudinal[i * n_long_vars + k])
533                .collect();
534
535            let m_t = longitudinal_model_value(event_time[i], &beta_long, &x_long_i, b0, b1);
536
537            if event_status[i] == 1 {
538                gradient_alpha += m_t;
539            }
540
541            let mut cum_haz = 0.0;
542            for h in &baseline_hazard {
543                cum_haz += h;
544            }
545
546            let mut eta = 0.0;
547            for (k, &xk) in x_survival[i * n_surv_vars..(i + 1) * n_surv_vars]
548                .iter()
549                .enumerate()
550            {
551                if k < gamma_surv.len() {
552                    eta += gamma_surv[k] * xk;
553                }
554            }
555            eta += alpha * m_t;
556
557            gradient_alpha -= cum_haz * m_t * eta.exp();
558            hessian_alpha += cum_haz * m_t * m_t * eta.exp();
559        }
560
561        if hessian_alpha.abs() > 1e-10 {
562            alpha += 0.1 * gradient_alpha / hessian_alpha;
563        }
564
565        let log_lik: f64 = (0..n_subjects)
566            .into_par_iter()
567            .map(|i| {
568                let subj_indices = &subject_indices[i];
569
570                let y_i: Vec<f64> = subj_indices.iter().map(|&j| y_longitudinal[j]).collect();
571                let t_i: Vec<f64> = subj_indices
572                    .iter()
573                    .map(|&j| times_longitudinal[j])
574                    .collect();
575                let x_long_i: Vec<f64> = {
576                    let mut result = Vec::with_capacity(subj_indices.len() * n_long_vars);
577                    for &j in subj_indices {
578                        for k in 0..n_long_vars {
579                            result.push(x_longitudinal[j * n_long_vars + k]);
580                        }
581                    }
582                    result
583                };
584                let x_surv_i: Vec<f64> = (0..n_surv_vars)
585                    .map(|k| x_survival[i * n_surv_vars + k])
586                    .collect();
587
588                let b0 = random_effects[i][0];
589                let b1 = random_effects[i][1];
590
591                let ll_long = compute_longitudinal_contribution(
592                    &y_i,
593                    &t_i,
594                    &beta_long,
595                    &x_long_i,
596                    n_long_vars,
597                    b0,
598                    b1,
599                    sigma_sq,
600                );
601
602                let ll_surv = compute_survival_contribution(
603                    event_time[i],
604                    event_status[i],
605                    &x_surv_i,
606                    &gamma_surv,
607                    alpha,
608                    &beta_long,
609                    &x_long_i,
610                    b0,
611                    b1,
612                    &baseline_hazard,
613                    &baseline_times,
614                    &config.association,
615                );
616
617                ll_long + ll_surv
618            })
619            .sum();
620
621        if (log_lik - prev_log_lik).abs() < config.tol {
622            converged = true;
623            break;
624        }
625        prev_log_lik = log_lik;
626    }
627
628    let n_params = n_long_vars + n_surv_vars + 1 + 3;
629    let aic = -2.0 * prev_log_lik + 2.0 * n_params as f64;
630    let bic = -2.0 * prev_log_lik + (n_params as f64) * (n_subjects as f64).ln();
631
632    let longitudinal_fixed_se = vec![0.1; n_long_vars];
633    let survival_fixed_se = vec![0.1; n_surv_vars];
634    let association_se = 0.1;
635
636    Ok(JointModelResult {
637        longitudinal_fixed: beta_long,
638        longitudinal_fixed_se,
639        survival_fixed: gamma_surv,
640        survival_fixed_se,
641        association_param: alpha,
642        association_se,
643        random_effects_var: vec![d11, d22],
644        residual_var: sigma_sq,
645        baseline_hazard,
646        baseline_hazard_times: baseline_times,
647        log_likelihood: prev_log_lik,
648        aic,
649        bic,
650        n_iter,
651        converged,
652        random_effects,
653    })
654}
655
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_longitudinal_model_value() {
        // 0.5 + 0.1 * 2.0 + (1.0 * 1.0 + 0.5 * 2.0) = 2.7
        let beta = vec![1.0, 0.5];
        let x_fixed = vec![1.0, 2.0];
        let val = longitudinal_model_value(2.0, &beta, &x_fixed, 0.5, 0.1);
        assert!((val - 2.7).abs() < 1e-12, "got {val}");
    }

    #[test]
    fn test_longitudinal_model_value_ignores_unmatched_entries() {
        // Surplus covariates or coefficients contribute nothing.
        assert_eq!(longitudinal_model_value(3.0, &[], &[5.0], 1.0, 2.0), 7.0);
        assert_eq!(longitudinal_model_value(0.0, &[2.0, 3.0], &[1.0], 0.0, 0.0), 2.0);
    }

    #[test]
    fn test_longitudinal_model_slope() {
        assert_eq!(longitudinal_model_slope(1.0, &[1.0, 2.0], &[], 0.0, 0.5), 2.5);
        assert_eq!(longitudinal_model_slope(1.0, &[1.0], &[], 0.0, 0.5), 0.5);
    }

    #[test]
    fn test_gauss_hermite_weights_integrate_constant() {
        // Any Gauss–Hermite rule integrates exp(-x^2) exactly: sum(w) = sqrt(pi).
        for &n in &[5usize, 15] {
            let (nodes, weights) = gauss_hermite_quadrature(n);
            assert_eq!(nodes.len(), weights.len());
            let total: f64 = weights.iter().sum();
            assert!((total - std::f64::consts::PI.sqrt()).abs() < 1e-9);
        }
    }

    #[test]
    fn test_joint_model_config() {
        let config = JointModelConfig::new(AssociationStructure::Value, 15, 100, 1e-4, 5);
        assert_eq!(config.association, AssociationStructure::Value);
        assert_eq!(config.n_quadrature, 15);
        assert_eq!(config.max_iter, 100);
        assert_eq!(config.tol, 1e-4);
        assert_eq!(config.baseline_hazard_knots, 5);
    }
}