1#![allow(
2 unused_variables,
3 unused_imports,
4 unused_mut,
5 unused_assignments,
6 clippy::too_many_arguments,
7 clippy::needless_range_loop,
8 clippy::type_complexity
9)]
10
11use pyo3::prelude::*;
12use rayon::prelude::*;
13
/// How the longitudinal marker enters the survival sub-model's linear predictor.
///
/// Constructed from Python by a case-insensitive name via `AssociationStructure(name)`;
/// the spellings accepted for each variant are listed on the variants below.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass]
pub enum AssociationStructure {
    /// Current marker value at the event time ("value", "current_value").
    Value,
    /// Marker slope ("slope", "current_slope").
    Slope,
    /// Marker value plus marker slope, sharing a single association parameter
    /// ("value_slope", "valueslope").
    ValueSlope,
    /// Cumulative / area-type effect of the marker trajectory ("area", "cumulative").
    Area,
    /// The subject's random intercept enters the survival predictor directly
    /// ("shared", "shared_random_effects").
    SharedRandomEffects,
}
23
24#[pymethods]
25impl AssociationStructure {
26 #[new]
27 fn new(name: &str) -> PyResult<Self> {
28 match name.to_lowercase().as_str() {
29 "value" | "current_value" => Ok(AssociationStructure::Value),
30 "slope" | "current_slope" => Ok(AssociationStructure::Slope),
31 "value_slope" | "valueslope" => Ok(AssociationStructure::ValueSlope),
32 "area" | "cumulative" => Ok(AssociationStructure::Area),
33 "shared" | "shared_random_effects" => Ok(AssociationStructure::SharedRandomEffects),
34 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
35 "Unknown association structure",
36 )),
37 }
38 }
39}
40
/// Fitting options for [`joint_model`]; every field is readable and writable from Python.
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointModelConfig {
    /// How the marker enters the survival linear predictor (Python default: `Value`).
    #[pyo3(get, set)]
    pub association: AssociationStructure,
    /// Requested number of Gauss–Hermite nodes per random effect (Python default: 15).
    /// Only two node sets are tabulated: values up to 5 select the 5-point rule and any
    /// larger value selects the 15-point rule.
    #[pyo3(get, set)]
    pub n_quadrature: usize,
    /// Maximum number of fitting iterations (Python default: 500).
    #[pyo3(get, set)]
    pub max_iter: usize,
    /// Convergence threshold on the absolute change of the log-likelihood between two
    /// successive iterations (Python default: 1e-4).
    #[pyo3(get, set)]
    pub tol: f64,
    /// Number of baseline-hazard knots (Python default: 5); capped at the number of
    /// distinct event times.
    #[pyo3(get, set)]
    pub baseline_hazard_knots: usize,
}
55
56#[pymethods]
57impl JointModelConfig {
58 #[new]
59 #[pyo3(signature = (association=AssociationStructure::Value, n_quadrature=15, max_iter=500, tol=1e-4, baseline_hazard_knots=5))]
60 pub fn new(
61 association: AssociationStructure,
62 n_quadrature: usize,
63 max_iter: usize,
64 tol: f64,
65 baseline_hazard_knots: usize,
66 ) -> Self {
67 JointModelConfig {
68 association,
69 n_quadrature,
70 max_iter,
71 tol,
72 baseline_hazard_knots,
73 }
74 }
75}
76
/// Output of [`joint_model`]: parameter estimates, fit statistics and per-subject
/// random effects. All fields are read-only from Python.
#[derive(Debug, Clone)]
#[pyclass]
pub struct JointModelResult {
    /// Fixed-effect coefficients of the longitudinal sub-model, one per longitudinal
    /// covariate column.
    #[pyo3(get)]
    pub longitudinal_fixed: Vec<f64>,
    /// Standard errors reported for `longitudinal_fixed`.
    /// NOTE(review): confirm in `joint_model` how these are obtained before relying on them.
    #[pyo3(get)]
    pub longitudinal_fixed_se: Vec<f64>,
    /// Coefficients of the survival covariates, one per survival covariate column.
    #[pyo3(get)]
    pub survival_fixed: Vec<f64>,
    /// Standard errors reported for `survival_fixed`.
    /// NOTE(review): confirm in `joint_model` how these are obtained before relying on them.
    #[pyo3(get)]
    pub survival_fixed_se: Vec<f64>,
    /// Association parameter `alpha` multiplying the marker term in the survival linear
    /// predictor.
    #[pyo3(get)]
    pub association_param: f64,
    /// Standard error reported for `association_param` (same caveat as the other SEs).
    #[pyo3(get)]
    pub association_se: f64,
    /// `[d11, d22]`: variances of the random intercept and of the random slope.
    #[pyo3(get)]
    pub random_effects_var: Vec<f64>,
    /// Residual variance `sigma²` of the longitudinal measurements.
    #[pyo3(get)]
    pub residual_var: f64,
    /// Baseline-hazard increments, one per knot in `baseline_hazard_times`.
    #[pyo3(get)]
    pub baseline_hazard: Vec<f64>,
    /// Knot times (a subset of the distinct event times) for `baseline_hazard`.
    #[pyo3(get)]
    pub baseline_hazard_times: Vec<f64>,
    /// Log-likelihood reported by the fit.
    #[pyo3(get)]
    pub log_likelihood: f64,
    /// `-2 * log_likelihood + 2 * k`, with `k` the number of counted parameters.
    #[pyo3(get)]
    pub aic: f64,
    /// `-2 * log_likelihood + k * ln(n_subjects)`.
    #[pyo3(get)]
    pub bic: f64,
    /// Number of iterations performed.
    #[pyo3(get)]
    pub n_iter: usize,
    /// Whether the log-likelihood change dropped below the configured tolerance.
    #[pyo3(get)]
    pub converged: bool,
    /// One `[random_intercept, random_slope]` pair per subject, in subject order.
    #[pyo3(get)]
    pub random_effects: Vec<Vec<f64>>,
}
113
/// Gauss–Hermite nodes and weights for the weight function `exp(-x²)`.
///
/// Only two rules are tabulated. Any `n <= 5` (including 0) yields the 5-point rule and any
/// larger `n` yields the 15-point rule, so the returned length is not necessarily `n`.
/// Nodes are in ascending order, and the weights of each rule sum to `√π`.
fn gauss_hermite_quadrature(n: usize) -> (Vec<f64>, Vec<f64>) {
    const NODES_5: [f64; 5] = [
        -2.020182870456086,
        -0.9585724646138185,
        0.0,
        0.9585724646138185,
        2.020182870456086,
    ];
    const WEIGHTS_5: [f64; 5] = [
        0.01995324205905,
        0.3936193231522,
        0.9453087204829,
        0.3936193231522,
        0.01995324205905,
    ];
    const NODES_15: [f64; 15] = [
        -4.499990707309,
        -3.669950373404,
        -2.967166927906,
        -2.325732486173,
        -1.719992575186,
        -1.136115585211,
        -0.5650695832556,
        0.0,
        0.5650695832556,
        1.136115585211,
        1.719992575186,
        2.325732486173,
        2.967166927906,
        3.669950373404,
        4.499990707309,
    ];
    const WEIGHTS_15: [f64; 15] = [
        1.522475804254e-09,
        1.059115547711e-06,
        1.000044412325e-04,
        2.778068842913e-03,
        3.078003387255e-02,
        1.584889157959e-01,
        4.120286874989e-01,
        5.641003087264e-01,
        4.120286874989e-01,
        1.584889157959e-01,
        3.078003387255e-02,
        2.778068842913e-03,
        1.000044412325e-04,
        1.059115547711e-06,
        1.522475804254e-09,
    ];

    match n {
        0..=5 => (NODES_5.to_vec(), WEIGHTS_5.to_vec()),
        _ => (NODES_15.to_vec(), WEIGHTS_15.to_vec()),
    }
}
171
/// Evaluates the subject-specific linear trajectory
/// `random_intercept + random_slope * time + Σ beta[j] * x_fixed[j]`.
///
/// Only the first `min(beta.len(), x_fixed.len())` coefficient/covariate pairs contribute;
/// surplus entries in either slice are ignored.
fn longitudinal_model_value(
    time: f64,
    beta: &[f64],
    x_fixed: &[f64],
    random_intercept: f64,
    random_slope: f64,
) -> f64 {
    let subject_line = random_intercept + random_slope * time;
    x_fixed
        .iter()
        .zip(beta.iter())
        .fold(subject_line, |acc, (&x, &b)| acc + b * x)
}
187
/// Slope of the marker trajectory: the subject's random slope plus `beta[1]` when present.
///
/// NOTE(review): this presumes `beta[1]` is the coefficient of the time column in the
/// longitudinal design — confirm against the caller's design matrix. The time, covariate
/// and intercept arguments are accepted for signature symmetry but unused.
fn longitudinal_model_slope(
    _time: f64,
    beta: &[f64],
    _x_fixed: &[f64],
    _random_intercept: f64,
    random_slope: f64,
) -> f64 {
    match beta.get(1) {
        Some(&time_coefficient) => random_slope + time_coefficient,
        None => random_slope,
    }
}
201
202#[allow(clippy::too_many_arguments)]
203fn compute_survival_contribution(
204 event_time: f64,
205 event_status: i32,
206 x_surv: &[f64],
207 gamma: &[f64],
208 alpha: f64,
209 beta_long: &[f64],
210 x_long_fixed: &[f64],
211 random_intercept: f64,
212 random_slope: f64,
213 baseline_hazard: &[f64],
214 baseline_times: &[f64],
215 association: &AssociationStructure,
216) -> f64 {
217 let mut linear_pred = 0.0;
218 for (j, &xj) in x_surv.iter().enumerate() {
219 if j < gamma.len() {
220 linear_pred += gamma[j] * xj;
221 }
222 }
223
224 let marker_contribution = match association {
225 AssociationStructure::Value => {
226 let m_t = longitudinal_model_value(
227 event_time,
228 beta_long,
229 x_long_fixed,
230 random_intercept,
231 random_slope,
232 );
233 alpha * m_t
234 }
235 AssociationStructure::Slope => {
236 let dm_t = longitudinal_model_slope(
237 event_time,
238 beta_long,
239 x_long_fixed,
240 random_intercept,
241 random_slope,
242 );
243 alpha * dm_t
244 }
245 AssociationStructure::ValueSlope => {
246 let m_t = longitudinal_model_value(
247 event_time,
248 beta_long,
249 x_long_fixed,
250 random_intercept,
251 random_slope,
252 );
253 let dm_t = longitudinal_model_slope(
254 event_time,
255 beta_long,
256 x_long_fixed,
257 random_intercept,
258 random_slope,
259 );
260 alpha * (m_t + dm_t)
261 }
262 AssociationStructure::Area => {
263 let m_t = longitudinal_model_value(
264 event_time,
265 beta_long,
266 x_long_fixed,
267 random_intercept,
268 random_slope,
269 );
270 alpha * m_t * event_time / 2.0
271 }
272 AssociationStructure::SharedRandomEffects => alpha * random_intercept,
273 };
274
275 linear_pred += marker_contribution;
276
277 let mut cum_hazard = 0.0;
278 for (t_idx, &t) in baseline_times.iter().enumerate() {
279 if t > event_time {
280 break;
281 }
282 if t_idx < baseline_hazard.len() {
283 cum_hazard += baseline_hazard[t_idx];
284 }
285 }
286
287 let log_hazard = if event_status == 1 {
288 let h0 = baseline_hazard
289 .iter()
290 .zip(baseline_times.iter())
291 .filter(|(_, t)| (*t - event_time).abs() < 1e-6)
292 .map(|(&h, _)| h)
293 .next()
294 .unwrap_or(0.01);
295
296 (h0.max(1e-10)).ln() + linear_pred
297 } else {
298 0.0
299 };
300
301 log_hazard - cum_hazard * linear_pred.exp()
302}
303
304fn compute_longitudinal_contribution(
305 y_obs: &[f64],
306 times_obs: &[f64],
307 beta: &[f64],
308 x_fixed: &[f64],
309 n_fixed: usize,
310 random_intercept: f64,
311 random_slope: f64,
312 sigma_sq: f64,
313) -> f64 {
314 let n_obs = y_obs.len();
315 let mut log_lik = 0.0;
316
317 for i in 0..n_obs {
318 let x_i: Vec<f64> = (0..n_fixed).map(|j| x_fixed[i * n_fixed + j]).collect();
319 let pred =
320 longitudinal_model_value(times_obs[i], beta, &x_i, random_intercept, random_slope);
321 let resid = y_obs[i] - pred;
322 log_lik += -0.5 * resid * resid / sigma_sq - 0.5 * sigma_sq.ln();
323 }
324
325 log_lik
326}
327
328#[pyfunction]
329#[pyo3(signature = (
330 y_longitudinal,
331 times_longitudinal,
332 x_longitudinal,
333 n_long_obs,
334 n_long_vars,
335 subject_ids_long,
336 event_time,
337 event_status,
338 x_survival,
339 n_subjects,
340 n_surv_vars,
341 config
342))]
343pub fn joint_model(
344 y_longitudinal: Vec<f64>,
345 times_longitudinal: Vec<f64>,
346 x_longitudinal: Vec<f64>,
347 n_long_obs: usize,
348 n_long_vars: usize,
349 subject_ids_long: Vec<usize>,
350 event_time: Vec<f64>,
351 event_status: Vec<i32>,
352 x_survival: Vec<f64>,
353 n_subjects: usize,
354 n_surv_vars: usize,
355 config: &JointModelConfig,
356) -> PyResult<JointModelResult> {
357 if y_longitudinal.len() != n_long_obs
358 || times_longitudinal.len() != n_long_obs
359 || subject_ids_long.len() != n_long_obs
360 {
361 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
362 "Longitudinal data dimensions mismatch",
363 ));
364 }
365 if event_time.len() != n_subjects || event_status.len() != n_subjects {
366 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
367 "Survival data dimensions mismatch",
368 ));
369 }
370
371 let mut beta_long = vec![0.0; n_long_vars];
372 let mut gamma_surv = vec![0.0; n_surv_vars];
373 let mut alpha = 0.0;
374 let mut sigma_sq = 1.0;
375 let mut d11: f64 = 1.0;
376 let mut d22: f64 = 0.1;
377
378 let mut random_effects: Vec<Vec<f64>> = vec![vec![0.0, 0.0]; n_subjects];
379
380 let mut unique_times: Vec<f64> = event_time.clone();
381 unique_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
382 unique_times.dedup();
383 let n_knots = config.baseline_hazard_knots.min(unique_times.len());
384 let baseline_times: Vec<f64> = (0..n_knots)
385 .map(|i| unique_times[i * unique_times.len() / n_knots])
386 .collect();
387 let mut baseline_hazard = vec![0.01; n_knots];
388
389 let (quad_nodes, _quad_weights) = gauss_hermite_quadrature(config.n_quadrature);
390
391 let subject_indices: Vec<Vec<usize>> = (0..n_subjects)
392 .map(|i| {
393 (0..n_long_obs)
394 .filter(|&j| subject_ids_long[j] == i)
395 .collect()
396 })
397 .collect();
398
399 let mut prev_log_lik = f64::NEG_INFINITY;
400 let mut converged = false;
401 let mut n_iter = 0;
402
403 for iter in 0..config.max_iter {
404 n_iter = iter + 1;
405
406 let new_random_effects: Vec<Vec<f64>> = (0..n_subjects)
407 .into_par_iter()
408 .map(|i| {
409 let subj_indices = &subject_indices[i];
410
411 let y_i: Vec<f64> = subj_indices.iter().map(|&j| y_longitudinal[j]).collect();
412 let t_i: Vec<f64> = subj_indices
413 .iter()
414 .map(|&j| times_longitudinal[j])
415 .collect();
416 let x_long_i: Vec<f64> = {
417 let mut result = Vec::with_capacity(subj_indices.len() * n_long_vars);
418 for &j in subj_indices {
419 for k in 0..n_long_vars {
420 result.push(x_longitudinal[j * n_long_vars + k]);
421 }
422 }
423 result
424 };
425 let x_surv_i: Vec<f64> = (0..n_surv_vars)
426 .map(|k| x_survival[i * n_surv_vars + k])
427 .collect();
428
429 let mut best_re = random_effects[i].clone();
430 let mut best_contrib = f64::NEG_INFINITY;
431
432 for &node_b0 in &quad_nodes {
433 for &node_b1 in &quad_nodes {
434 let b0 = node_b0 * d11.sqrt();
435 let b1 = node_b1 * d22.sqrt();
436
437 let long_contrib = compute_longitudinal_contribution(
438 &y_i,
439 &t_i,
440 &beta_long,
441 &x_long_i,
442 n_long_vars,
443 b0,
444 b1,
445 sigma_sq,
446 );
447
448 let surv_contrib = compute_survival_contribution(
449 event_time[i],
450 event_status[i],
451 &x_surv_i,
452 &gamma_surv,
453 alpha,
454 &beta_long,
455 &x_long_i,
456 b0,
457 b1,
458 &baseline_hazard,
459 &baseline_times,
460 &config.association,
461 );
462
463 let re_prior = -0.5 * (b0 * b0 / d11 + b1 * b1 / d22);
464 let total = long_contrib + surv_contrib + re_prior;
465
466 if total > best_contrib {
467 best_contrib = total;
468 best_re = vec![b0, b1];
469 }
470 }
471 }
472
473 best_re
474 })
475 .collect();
476
477 random_effects = new_random_effects;
478
479 let mut gradient_beta = vec![0.0; n_long_vars];
480 let mut hessian_beta = vec![0.0; n_long_vars];
481
482 for j in 0..n_long_obs {
483 let subj = subject_ids_long[j];
484 let b0 = random_effects[subj][0];
485 let b1 = random_effects[subj][1];
486
487 let x_j: Vec<f64> = (0..n_long_vars)
488 .map(|k| x_longitudinal[j * n_long_vars + k])
489 .collect();
490
491 let pred = longitudinal_model_value(times_longitudinal[j], &beta_long, &x_j, b0, b1);
492 let resid = y_longitudinal[j] - pred;
493
494 for k in 0..n_long_vars {
495 gradient_beta[k] += resid * x_j[k] / sigma_sq;
496 hessian_beta[k] += x_j[k] * x_j[k] / sigma_sq;
497 }
498 }
499
500 for k in 0..n_long_vars {
501 if hessian_beta[k].abs() > 1e-10 {
502 beta_long[k] += gradient_beta[k] / hessian_beta[k];
503 }
504 }
505
506 let mut ss_resid = 0.0;
507 for j in 0..n_long_obs {
508 let subj = subject_ids_long[j];
509 let b0 = random_effects[subj][0];
510 let b1 = random_effects[subj][1];
511 let x_j: Vec<f64> = (0..n_long_vars)
512 .map(|k| x_longitudinal[j * n_long_vars + k])
513 .collect();
514 let pred = longitudinal_model_value(times_longitudinal[j], &beta_long, &x_j, b0, b1);
515 ss_resid += (y_longitudinal[j] - pred).powi(2);
516 }
517 sigma_sq = (ss_resid / n_long_obs as f64).max(0.001);
518
519 d11 = random_effects.iter().map(|re| re[0].powi(2)).sum::<f64>() / n_subjects as f64;
520 d22 = random_effects.iter().map(|re| re[1].powi(2)).sum::<f64>() / n_subjects as f64;
521 d11 = d11.max(0.001);
522 d22 = d22.max(0.001);
523
524 let mut gradient_alpha = 0.0;
525 let mut hessian_alpha = 0.0;
526
527 for i in 0..n_subjects {
528 let b0 = random_effects[i][0];
529 let b1 = random_effects[i][1];
530
531 let x_long_i: Vec<f64> = (0..n_long_vars)
532 .map(|k| x_longitudinal[i * n_long_vars + k])
533 .collect();
534
535 let m_t = longitudinal_model_value(event_time[i], &beta_long, &x_long_i, b0, b1);
536
537 if event_status[i] == 1 {
538 gradient_alpha += m_t;
539 }
540
541 let mut cum_haz = 0.0;
542 for h in &baseline_hazard {
543 cum_haz += h;
544 }
545
546 let mut eta = 0.0;
547 for (k, &xk) in x_survival[i * n_surv_vars..(i + 1) * n_surv_vars]
548 .iter()
549 .enumerate()
550 {
551 if k < gamma_surv.len() {
552 eta += gamma_surv[k] * xk;
553 }
554 }
555 eta += alpha * m_t;
556
557 gradient_alpha -= cum_haz * m_t * eta.exp();
558 hessian_alpha += cum_haz * m_t * m_t * eta.exp();
559 }
560
561 if hessian_alpha.abs() > 1e-10 {
562 alpha += 0.1 * gradient_alpha / hessian_alpha;
563 }
564
565 let log_lik: f64 = (0..n_subjects)
566 .into_par_iter()
567 .map(|i| {
568 let subj_indices = &subject_indices[i];
569
570 let y_i: Vec<f64> = subj_indices.iter().map(|&j| y_longitudinal[j]).collect();
571 let t_i: Vec<f64> = subj_indices
572 .iter()
573 .map(|&j| times_longitudinal[j])
574 .collect();
575 let x_long_i: Vec<f64> = {
576 let mut result = Vec::with_capacity(subj_indices.len() * n_long_vars);
577 for &j in subj_indices {
578 for k in 0..n_long_vars {
579 result.push(x_longitudinal[j * n_long_vars + k]);
580 }
581 }
582 result
583 };
584 let x_surv_i: Vec<f64> = (0..n_surv_vars)
585 .map(|k| x_survival[i * n_surv_vars + k])
586 .collect();
587
588 let b0 = random_effects[i][0];
589 let b1 = random_effects[i][1];
590
591 let ll_long = compute_longitudinal_contribution(
592 &y_i,
593 &t_i,
594 &beta_long,
595 &x_long_i,
596 n_long_vars,
597 b0,
598 b1,
599 sigma_sq,
600 );
601
602 let ll_surv = compute_survival_contribution(
603 event_time[i],
604 event_status[i],
605 &x_surv_i,
606 &gamma_surv,
607 alpha,
608 &beta_long,
609 &x_long_i,
610 b0,
611 b1,
612 &baseline_hazard,
613 &baseline_times,
614 &config.association,
615 );
616
617 ll_long + ll_surv
618 })
619 .sum();
620
621 if (log_lik - prev_log_lik).abs() < config.tol {
622 converged = true;
623 break;
624 }
625 prev_log_lik = log_lik;
626 }
627
628 let n_params = n_long_vars + n_surv_vars + 1 + 3;
629 let aic = -2.0 * prev_log_lik + 2.0 * n_params as f64;
630 let bic = -2.0 * prev_log_lik + (n_params as f64) * (n_subjects as f64).ln();
631
632 let longitudinal_fixed_se = vec![0.1; n_long_vars];
633 let survival_fixed_se = vec![0.1; n_surv_vars];
634 let association_se = 0.1;
635
636 Ok(JointModelResult {
637 longitudinal_fixed: beta_long,
638 longitudinal_fixed_se,
639 survival_fixed: gamma_surv,
640 survival_fixed_se,
641 association_param: alpha,
642 association_se,
643 random_effects_var: vec![d11, d22],
644 residual_var: sigma_sq,
645 baseline_hazard,
646 baseline_hazard_times: baseline_times,
647 log_likelihood: prev_log_lik,
648 aic,
649 bic,
650 n_iter,
651 converged,
652 random_effects,
653 })
654}
655
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_longitudinal_model_value() {
        let beta = vec![1.0, 0.5];
        let x_fixed = vec![1.0, 2.0];
        let val = longitudinal_model_value(2.0, &beta, &x_fixed, 0.5, 0.1);
        // 0.5 + 0.1 * 2.0 + 1.0 * 1.0 + 0.5 * 2.0
        assert!((val - 2.7).abs() < 1e-12);
    }

    #[test]
    fn test_longitudinal_model_value_ignores_unmatched_entries() {
        // Surplus coefficients or covariates are ignored rather than causing a panic.
        assert_eq!(longitudinal_model_value(1.0, &[2.0, 3.0], &[1.0], 0.0, 0.0), 2.0);
        assert_eq!(longitudinal_model_value(1.0, &[2.0], &[1.0, 5.0], 0.0, 0.0), 2.0);
    }

    #[test]
    fn test_longitudinal_model_slope() {
        assert_eq!(longitudinal_model_slope(0.0, &[1.0, 2.5], &[], 0.0, 0.5), 3.0);
        assert_eq!(longitudinal_model_slope(0.0, &[1.0], &[], 0.0, 0.5), 0.5);
    }

    #[test]
    fn test_gauss_hermite_quadrature() {
        for &(n, expected_len) in &[(5usize, 5usize), (15, 15)] {
            let (nodes, weights) = gauss_hermite_quadrature(n);
            assert_eq!(nodes.len(), expected_len);
            assert_eq!(weights.len(), expected_len);
            // Weights for the exp(-x²) weight function sum to sqrt(pi).
            let total: f64 = weights.iter().sum();
            assert!((total - std::f64::consts::PI.sqrt()).abs() < 1e-9);
            // Nodes are symmetric about zero.
            for (a, b) in nodes.iter().zip(nodes.iter().rev()) {
                assert_eq!(*a, -*b);
            }
        }
    }

    #[test]
    fn test_joint_model_config() {
        let config = JointModelConfig::new(AssociationStructure::Value, 15, 100, 1e-4, 5);
        assert_eq!(config.association, AssociationStructure::Value);
        assert_eq!(config.n_quadrature, 15);
        assert_eq!(config.max_iter, 100);
        assert_eq!(config.tol, 1e-4);
        assert_eq!(config.baseline_hazard_knots, 5);
    }
}