1use crate::error::{StatsError, StatsResult};
12
13use super::shared_frailty::SharedFrailtyModel;
14use super::types::{ClusterInfo, FrailtyConfig, FrailtyDistribution, FrailtyResult};
15
/// Output of [`NestedFrailtyModel::fit`].
///
/// Frailty vectors are ordered by ascending cluster identifier, skipping
/// identifiers that have no subjects.
#[derive(Debug, Clone)]
pub struct NestedFrailtyResult {
    /// Estimated regression coefficients, one per covariate column.
    pub coefficients: Vec<f64>,
    /// Final variance estimate of the inner-level frailties.
    pub inner_variance: f64,
    /// Final variance estimate of the outer-level frailties.
    pub outer_variance: f64,
    /// Final frailty estimate for each inner cluster.
    pub inner_frailty_estimates: Vec<f64>,
    /// Final frailty estimate for each outer cluster.
    pub outer_frailty_estimates: Vec<f64>,
    /// Log-likelihood value recorded at the end of every iteration.
    pub log_likelihood_history: Vec<f64>,
    /// `true` when the relative log-likelihood change fell below the
    /// configured tolerance before the iteration limit was reached.
    pub converged: bool,
    /// Number of iterations that were actually executed.
    pub iterations: usize,
    /// `(event time, cumulative baseline hazard)` pairs in ascending time order.
    pub baseline_hazard: Vec<(f64, f64)>,
}
42
43#[derive(Debug, Clone)]
56pub struct NestedFrailtyModel {
57 config: FrailtyConfig,
58}
59
60impl NestedFrailtyModel {
61 pub fn new(config: FrailtyConfig) -> Self {
63 Self { config }
64 }
65
66 pub fn fit(
78 &self,
79 times: &[f64],
80 events: &[bool],
81 covariates: &[&[f64]],
82 inner_clusters: &[usize],
83 outer_clusters: &[usize],
84 ) -> StatsResult<NestedFrailtyResult> {
85 let n = times.len();
86
87 if n == 0 {
89 return Err(StatsError::InvalidArgument(
90 "times must not be empty".into(),
91 ));
92 }
93 if events.len() != n
94 || covariates.len() != n
95 || inner_clusters.len() != n
96 || outer_clusters.len() != n
97 {
98 return Err(StatsError::DimensionMismatch(format!(
99 "All input arrays must have length {n}"
100 )));
101 }
102 let n_events = events.iter().filter(|&&e| e).count();
103 if n_events == 0 {
104 return Err(StatsError::InvalidArgument("No events observed".into()));
105 }
106 for &t in times {
107 if !t.is_finite() || t < 0.0 {
108 return Err(StatsError::InvalidArgument(format!(
109 "times must be finite and non-negative; got {t}"
110 )));
111 }
112 }
113
114 let inner_infos = build_cluster_infos(inner_clusters, events)?;
116 let outer_infos = build_cluster_infos(outer_clusters, events)?;
117
118 if inner_infos.len() < 2 {
119 return Err(StatsError::InvalidArgument(
120 "At least two inner clusters are required".into(),
121 ));
122 }
123 if outer_infos.len() < 2 {
124 return Err(StatsError::InvalidArgument(
125 "At least two outer clusters are required".into(),
126 ));
127 }
128
129 let mut subj_to_inner = vec![0_usize; n];
131 for (k, ci) in inner_infos.iter().enumerate() {
132 for &idx in &ci.subject_indices {
133 subj_to_inner[idx] = k;
134 }
135 }
136 let mut subj_to_outer = vec![0_usize; n];
137 for (k, ci) in outer_infos.iter().enumerate() {
138 for &idx in &ci.subject_indices {
139 subj_to_outer[idx] = k;
140 }
141 }
142
143 let inner_to_outer: Vec<usize> = inner_infos
145 .iter()
146 .map(|ci| {
147 if ci.subject_indices.is_empty() {
148 0
149 } else {
150 subj_to_outer[ci.subject_indices[0]]
152 }
153 })
154 .collect();
155
156 let p = if !covariates.is_empty() {
157 covariates[0].len()
158 } else {
159 0
160 };
161
162 let mut beta = vec![0.0_f64; p];
164 let mut theta_inner = self.config.initial_variance;
165 let mut theta_outer = self.config.initial_variance;
166 let mut inner_frailties = vec![1.0_f64; inner_infos.len()];
167 let mut outer_frailties = vec![1.0_f64; outer_infos.len()];
168 let mut ll_history: Vec<f64> = Vec::new();
169 let mut converged = false;
170 let mut iterations = 0_usize;
171
172 let mut order: Vec<usize> = (0..n).collect();
174 order.sort_by(|&a, &b| {
175 times[a]
176 .partial_cmp(×[b])
177 .unwrap_or(std::cmp::Ordering::Equal)
178 });
179
180 for iter in 0..self.config.max_iterations {
182 iterations = iter + 1;
183
184 let combined_frailties: Vec<f64> = (0..n)
186 .map(|j| outer_frailties[subj_to_outer[j]] * inner_frailties[subj_to_inner[j]])
187 .collect();
188
189 let exp_xb: Vec<f64> = (0..n)
191 .map(|j| {
192 let mut lin = 0.0_f64;
193 for col in 0..p {
194 lin += covariates[j][col] * beta[col];
195 }
196 lin.exp()
197 })
198 .collect();
199
200 let inner_cum_haz = compute_cluster_cumulative_hazard(
202 &inner_infos,
203 &exp_xb,
204 &combined_frailties,
205 times,
206 events,
207 &order,
208 );
209
210 update_frailties_gamma(
212 &mut inner_frailties,
213 &inner_infos,
214 &inner_cum_haz,
215 theta_inner,
216 );
217
218 let combined_frailties2: Vec<f64> = (0..n)
220 .map(|j| outer_frailties[subj_to_outer[j]] * inner_frailties[subj_to_inner[j]])
221 .collect();
222
223 let outer_cum_haz = compute_cluster_cumulative_hazard(
225 &outer_infos,
226 &exp_xb,
227 &combined_frailties2,
228 times,
229 events,
230 &order,
231 );
232
233 update_frailties_gamma(
235 &mut outer_frailties,
236 &outer_infos,
237 &outer_cum_haz,
238 theta_outer,
239 );
240
241 theta_inner = moment_variance(&inner_frailties).max(1e-8);
243 theta_outer = moment_variance(&outer_frailties).max(1e-8);
244
245 if p > 0 {
247 let combined: Vec<f64> = (0..n)
248 .map(|j| outer_frailties[subj_to_outer[j]] * inner_frailties[subj_to_inner[j]])
249 .collect();
250 update_beta_newton(&mut beta, times, events, covariates, &combined, &order);
251 }
252
253 let ll = nested_log_likelihood(
255 times,
256 events,
257 covariates,
258 &beta,
259 &inner_frailties,
260 &outer_frailties,
261 &subj_to_inner,
262 &subj_to_outer,
263 &order,
264 theta_inner,
265 theta_outer,
266 &inner_infos,
267 &outer_infos,
268 );
269 ll_history.push(ll);
270
271 if ll_history.len() >= 2 {
273 let prev = ll_history[ll_history.len() - 2];
274 let rel_change = if prev.abs() > 1e-12 {
275 (ll - prev).abs() / prev.abs()
276 } else {
277 (ll - prev).abs()
278 };
279 if rel_change < self.config.tolerance {
280 converged = true;
281 break;
282 }
283 }
284 }
285
286 let combined: Vec<f64> = (0..n)
288 .map(|j| outer_frailties[subj_to_outer[j]] * inner_frailties[subj_to_inner[j]])
289 .collect();
290 let exp_xb: Vec<f64> = (0..n)
291 .map(|j| {
292 let mut lin = 0.0_f64;
293 for col in 0..p {
294 lin += covariates[j][col] * beta[col];
295 }
296 lin.exp()
297 })
298 .collect();
299 let baseline_hazard = compute_baseline_hazard(times, events, &exp_xb, &combined, &order);
300
301 Ok(NestedFrailtyResult {
302 coefficients: beta,
303 inner_variance: theta_inner,
304 outer_variance: theta_outer,
305 inner_frailty_estimates: inner_frailties,
306 outer_frailty_estimates: outer_frailties,
307 log_likelihood_history: ll_history,
308 converged,
309 iterations,
310 baseline_hazard,
311 })
312 }
313}
314
315fn build_cluster_infos(clusters: &[usize], events: &[bool]) -> StatsResult<Vec<ClusterInfo>> {
320 let max_id = clusters.iter().copied().max().unwrap_or(0);
321 let mut buckets: Vec<Vec<usize>> = vec![Vec::new(); max_id + 1];
322 for (i, &c) in clusters.iter().enumerate() {
323 buckets[c].push(i);
324 }
325 Ok(buckets
326 .into_iter()
327 .enumerate()
328 .filter(|(_, indices)| !indices.is_empty())
329 .map(|(id, indices)| ClusterInfo::new(id, indices, events))
330 .collect())
331}
332
333fn compute_cluster_cumulative_hazard(
334 cluster_infos: &[ClusterInfo],
335 exp_xb: &[f64],
336 combined_frailties: &[f64],
337 times: &[f64],
338 events: &[bool],
339 order: &[usize],
340) -> Vec<f64> {
341 let n = times.len();
342
343 let mut risk_sum = 0.0_f64;
345 for j in 0..n {
346 risk_sum += combined_frailties[j] * exp_xb[j];
347 }
348
349 let mut cum_h0_at = vec![0.0_f64; n];
350 let mut cum_h0 = 0.0_f64;
351 let mut risk_ptr = 0_usize;
352 for &idx in order {
353 while risk_ptr < order.len() && times[order[risk_ptr]] < times[idx] - 1e-15 {
354 let rem = order[risk_ptr];
355 risk_sum -= combined_frailties[rem] * exp_xb[rem];
356 risk_ptr += 1;
357 }
358 if events[idx] && risk_sum > 1e-30 {
359 cum_h0 += 1.0 / risk_sum;
360 }
361 cum_h0_at[idx] = cum_h0;
362 }
363
364 cluster_infos
365 .iter()
366 .map(|ci| {
367 ci.subject_indices
368 .iter()
369 .map(|&j| exp_xb[j] * cum_h0_at[j])
370 .sum::<f64>()
371 })
372 .collect()
373}
374
375fn update_frailties_gamma(
376 frailties: &mut [f64],
377 cluster_infos: &[ClusterInfo],
378 cum_hazard: &[f64],
379 theta: f64,
380) {
381 let inv_theta = 1.0 / theta.max(1e-15);
382 for (k, ci) in cluster_infos.iter().enumerate() {
383 let d_i = ci.n_events as f64;
384 let h_i = cum_hazard[k];
385 frailties[k] = (d_i + inv_theta) / (h_i + inv_theta);
386 }
387}
388
/// Population variance (divisor `len`) of the frailty values.
///
/// Returns `1.0` for an empty slice.
fn moment_variance(frailties: &[f64]) -> f64 {
    if frailties.is_empty() {
        return 1.0;
    }
    let count = frailties.len() as f64;

    let mut total = 0.0_f64;
    for &value in frailties {
        total += value;
    }
    let centre = total / count;

    let mut squared_dev = 0.0_f64;
    for &value in frailties {
        squared_dev += (value - centre).powi(2);
    }
    squared_dev / count
}
397
/// Applies one damped, ridge-stabilised coefficient step in place.
///
/// The gradient is the partial-likelihood score with risk-set weights
/// `combined_frailties[j] * exp(x_j . beta)`. The curvature is not the exact
/// Hessian: every usable event contributes `-1` to each diagonal entry, a
/// ridge of `1e-3` is subtracted, and the resulting step is scaled by `0.3`.
///
/// `order` must list subject indices in ascending `times` order. Does nothing
/// when `beta` is empty.
fn update_beta_newton(
    beta: &mut [f64],
    times: &[f64],
    events: &[bool],
    covariates: &[&[f64]],
    combined_frailties: &[f64],
    order: &[usize],
) {
    const STEP_SIZE: f64 = 0.3;
    const RIDGE: f64 = 1e-3;

    let dim = beta.len();
    if dim == 0 {
        return;
    }
    let n_subjects = times.len();

    // Relative risk of every subject under the current coefficients.
    let mut rel_risk = Vec::with_capacity(n_subjects);
    for j in 0..n_subjects {
        let mut eta = 0.0_f64;
        for col in 0..dim {
            eta += covariates[j][col] * beta[col];
        }
        rel_risk.push(eta.exp());
    }

    // Risk-set totals, starting with everybody at risk.
    let mut weight_total = 0.0_f64;
    let mut weighted_x = vec![0.0_f64; dim];
    for j in 0..n_subjects {
        let w = combined_frailties[j] * rel_risk[j];
        weight_total += w;
        for (acc, &x) in weighted_x.iter_mut().zip(covariates[j]) {
            *acc += w * x;
        }
    }

    let mut score = vec![0.0_f64; dim];
    let mut curvature = vec![0.0_f64; dim];
    let mut leaving = 0_usize;
    for &subject in order {
        // Remove subjects whose time is strictly earlier than the current one.
        while leaving < order.len() && times[order[leaving]] < times[subject] - 1e-15 {
            let gone = order[leaving];
            let w = combined_frailties[gone] * rel_risk[gone];
            weight_total -= w;
            for (acc, &x) in weighted_x.iter_mut().zip(covariates[gone]) {
                *acc -= w * x;
            }
            leaving += 1;
        }
        if events[subject] && weight_total > 1e-30 {
            for col in 0..dim {
                let risk_mean = weighted_x[col] / weight_total;
                score[col] += covariates[subject][col] - risk_mean;
                // Unit curvature per event (approximation, see function docs).
                curvature[col] -= 1.0;
            }
        }
    }

    for ((coef, &g), &c) in beta.iter_mut().zip(&score).zip(&curvature) {
        let h = c - RIDGE;
        if h.abs() > 1e-30 {
            *coef += STEP_SIZE * g / (-h);
        }
    }
}
467
/// Builds the Breslow-type cumulative baseline hazard as
/// `(event time, cumulative hazard)` pairs in ascending time order.
///
/// Each observed event adds `1 / risk_sum`, where `risk_sum` is the total of
/// `combined_frailties[j] * exp_xb[j]` over the subjects still at risk.
/// Events are skipped when `risk_sum` is not above `1e-30`.
///
/// Event times closer than `1e-15` to the current entry are merged into that
/// entry. The merged entry holds the cumulative hazard after *all* of the
/// tied events, so the value reported at a time covers every event at that
/// time.
///
/// `order` must list subject indices in ascending `times` order.
fn compute_baseline_hazard(
    times: &[f64],
    events: &[bool],
    exp_xb: &[f64],
    combined_frailties: &[f64],
    order: &[usize],
) -> Vec<(f64, f64)> {
    let n = times.len();

    // Everybody is at risk before the first observed time.
    let mut risk_sum = 0.0_f64;
    for j in 0..n {
        risk_sum += combined_frailties[j] * exp_xb[j];
    }

    let mut baseline: Vec<(f64, f64)> = Vec::new();
    let mut cum_h0 = 0.0_f64;
    let mut risk_ptr = 0_usize;
    for &idx in order {
        // Drop subjects whose time is strictly earlier than the current one.
        while risk_ptr < order.len() && times[order[risk_ptr]] < times[idx] - 1e-15 {
            let rem = order[risk_ptr];
            risk_sum -= combined_frailties[rem] * exp_xb[rem];
            risk_ptr += 1;
        }
        if events[idx] && risk_sum > 1e-30 {
            cum_h0 += 1.0 / risk_sum;
            match baseline.last_mut() {
                // Tied with the current entry: overwrite its value so the
                // entry ends up with the hazard after the whole tie group.
                Some(last) if (times[idx] - last.0).abs() < 1e-15 => last.1 = cum_h0,
                _ => baseline.push((times[idx], cum_h0)),
            }
        }
    }
    baseline
}
498
/// Natural logarithm of the gamma function, evaluated with a nine-term
/// Lanczos series (`g = 7`).
///
/// No reflection step is applied, so the argument is expected to be positive.
fn lgamma(x: f64) -> f64 {
    const LANCZOS: [f64; 9] = [
        0.999_999_999_999_809_93,
        676.520_368_121_885_10,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_10,
        -176.615_029_162_140_60,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];

    let shifted = x - 1.0;
    let series = LANCZOS
        .iter()
        .enumerate()
        .skip(1)
        .fold(LANCZOS[0], |acc, (k, &coef)| {
            acc + coef / (shifted + (k - 1) as f64 + 1.0)
        });

    let base = shifted + 7.5;
    let half_log_two_pi = 0.5 * std::f64::consts::TAU.ln();
    half_log_two_pi + (shifted + 0.5) * base.ln() - base + series.ln()
}
519
520fn nested_log_likelihood(
521 times: &[f64],
522 events: &[bool],
523 covariates: &[&[f64]],
524 beta: &[f64],
525 inner_frailties: &[f64],
526 outer_frailties: &[f64],
527 subj_to_inner: &[usize],
528 subj_to_outer: &[usize],
529 order: &[usize],
530 theta_inner: f64,
531 theta_outer: f64,
532 inner_infos: &[ClusterInfo],
533 outer_infos: &[ClusterInfo],
534) -> f64 {
535 let n = times.len();
536 let p = beta.len();
537
538 let exp_xb: Vec<f64> = (0..n)
539 .map(|j| {
540 let mut lin = 0.0_f64;
541 for col in 0..p {
542 lin += covariates[j][col] * beta[col];
543 }
544 lin.exp()
545 })
546 .collect();
547
548 let combined: Vec<f64> = (0..n)
549 .map(|j| outer_frailties[subj_to_outer[j]] * inner_frailties[subj_to_inner[j]])
550 .collect();
551
552 let mut ll = 0.0_f64;
554 let mut s0 = 0.0_f64;
555 for j in 0..n {
556 s0 += combined[j] * exp_xb[j];
557 }
558
559 let mut risk_ptr = 0_usize;
560 for &idx in order {
561 while risk_ptr < order.len() && times[order[risk_ptr]] < times[idx] - 1e-15 {
562 let rem = order[risk_ptr];
563 s0 -= combined[rem] * exp_xb[rem];
564 risk_ptr += 1;
565 }
566 if events[idx] {
567 let u = combined[idx].max(1e-30);
568 let mut xb = 0.0_f64;
569 for col in 0..p {
570 xb += covariates[idx][col] * beta[col];
571 }
572 ll += u.ln() + xb - s0.max(1e-30).ln();
573 }
574 }
575
576 let inv_ti = 1.0 / theta_inner.max(1e-15);
578 for (k, _ci) in inner_infos.iter().enumerate() {
579 let u = inner_frailties[k].max(1e-30);
580 ll += (inv_ti - 1.0) * u.ln() - u * inv_ti - lgamma(inv_ti) + inv_ti * inv_ti.ln();
581 }
582
583 let inv_to = 1.0 / theta_outer.max(1e-15);
585 for (k, _ci) in outer_infos.iter().enumerate() {
586 let u = outer_frailties[k].max(1e-30);
587 ll += (inv_to - 1.0) * u.ln() - u * inv_to - lgamma(inv_to) + inv_to * inv_to.ln();
588 }
589
590 ll
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic nested data set.
    ///
    /// Returns `(times, events, covariates, inner_clusters, outer_clusters)`
    /// with `n_outer * n_inner_per_outer * n_per_inner` subjects. Inner
    /// cluster identifiers are unique across outer clusters.
    fn generate_nested_data(
        n_outer: usize,
        n_inner_per_outer: usize,
        n_per_inner: usize,
    ) -> (Vec<f64>, Vec<bool>, Vec<Vec<f64>>, Vec<usize>, Vec<usize>) {
        let mut times = Vec::new();
        let mut events = Vec::new();
        let mut covariates = Vec::new();
        let mut inner_clusters = Vec::new();
        let mut outer_clusters = Vec::new();

        let mut inner_id = 0_usize;
        for outer in 0..n_outer {
            let outer_effect = 1.0 + 0.3 * (outer as f64 * 1.5).sin();
            for inner_offset in 0..n_inner_per_outer {
                let inner_effect = 1.0 + 0.2 * (inner_id as f64 * 2.3).sin();
                for subj in 0..n_per_inner {
                    let x = ((inner_id * n_per_inner + subj) as f64 * 0.2).sin();
                    let rate = outer_effect * inner_effect * (0.3 * x).exp();
                    // Deterministic stand-in for a random draw in [0.5, 0.9].
                    let pseudo_rand = 0.5
                        + 0.4
                            * ((outer * 11 + inner_offset * 7 + subj * 3) as f64 * 1.618)
                                .sin()
                                .abs();
                    let t = pseudo_rand / rate.max(0.01);
                    let event = (outer + inner_offset + subj) % 3 != 0;

                    times.push(t.max(0.01));
                    events.push(event);
                    covariates.push(vec![x]);
                    inner_clusters.push(inner_id);
                    outer_clusters.push(outer);
                }
                inner_id += 1;
            }
        }

        (times, events, covariates, inner_clusters, outer_clusters)
    }

    /// A basic fit returns one frailty per cluster and positive variances.
    #[test]
    fn test_nested_frailty_basic() {
        let (times, events, cov_owned, inner_cl, outer_cl) = generate_nested_data(3, 3, 10);
        let covariates: Vec<&[f64]> = cov_owned.iter().map(|v| v.as_slice()).collect();

        let model = NestedFrailtyModel::new(FrailtyConfig {
            max_iterations: 50,
            ..FrailtyConfig::default()
        });
        let result = model
            .fit(&times, &events, &covariates, &inner_cl, &outer_cl)
            .expect("nested fit should succeed");

        assert_eq!(result.outer_frailty_estimates.len(), 3);
        assert_eq!(result.inner_frailty_estimates.len(), 9);
        assert!(result.inner_variance > 0.0);
        assert!(result.outer_variance > 0.0);
        assert!(result.iterations > 0);
    }

    /// Both variance components are reported as positive.
    #[test]
    fn test_nested_two_variance_components() {
        let (times, events, cov_owned, inner_cl, outer_cl) = generate_nested_data(4, 2, 8);
        let covariates: Vec<&[f64]> = cov_owned.iter().map(|v| v.as_slice()).collect();

        let model = NestedFrailtyModel::new(FrailtyConfig {
            max_iterations: 100,
            ..FrailtyConfig::default()
        });
        let result = model
            .fit(&times, &events, &covariates, &inner_cl, &outer_cl)
            .expect("nested fit should succeed");

        assert!(
            result.inner_variance > 0.0,
            "Inner variance should be positive"
        );
        assert!(
            result.outer_variance > 0.0,
            "Outer variance should be positive"
        );
    }

    /// Empty input is rejected with an error.
    #[test]
    fn test_nested_empty_data_error() {
        let model = NestedFrailtyModel::new(FrailtyConfig::default());
        let result = model.fit(&[], &[], &[], &[], &[]);
        assert!(result.is_err());
    }

    /// A single outer cluster is rejected with an error.
    #[test]
    fn test_nested_single_outer_cluster_error() {
        let model = NestedFrailtyModel::new(FrailtyConfig::default());
        let result = model.fit(
            &[1.0, 2.0, 3.0, 4.0],
            &[true, true, true, false],
            &[&[0.1][..], &[0.2][..], &[0.3][..], &[0.4][..]],
            &[0, 1, 2, 3],
            &[0, 0, 0, 0],
        );
        assert!(result.is_err());
    }

    /// The cumulative baseline hazard is non-empty and non-decreasing.
    #[test]
    fn test_nested_baseline_hazard() {
        let (times, events, cov_owned, inner_cl, outer_cl) = generate_nested_data(3, 2, 8);
        let covariates: Vec<&[f64]> = cov_owned.iter().map(|v| v.as_slice()).collect();

        let model = NestedFrailtyModel::new(FrailtyConfig {
            max_iterations: 30,
            ..FrailtyConfig::default()
        });
        let result = model
            .fit(&times, &events, &covariates, &inner_cl, &outer_cl)
            .expect("nested fit should succeed");

        assert!(!result.baseline_hazard.is_empty());
        for i in 1..result.baseline_hazard.len() {
            assert!(result.baseline_hazard[i].1 >= result.baseline_hazard[i - 1].1 - 1e-10);
        }
    }
}